Attention, Please!!!

Reasoning 기반 LLM 저렴하게 Fine-Tuning 하는 방법 본문

LLM/Fine-tuning

Reasoning 기반 LLM 저렴하게 Fine-Tuning 하는 방법

G3LU 2025. 2. 26. 11:34

대형 언어 모델(LLM)의 Reasoning 관점에 대한 연구가 최근 AI 분야에서 가장 뜨거운 주제 중 하나로 떠오르고 있는 것 같다. 이에 대해 간단하게 먼저 알아보겠다. 

  • 프롬프트 기반 추론 향상: 최근 연구에 따르면, "Chain of Thought" 같은 프롬프트 기술을 통해 LLM이 단계별로 문제를 풀도록 유도하고 있다. 예를 들어, 복잡한 수학 문제를 풀 때 모델이 중간에 직접적으로 개입하여 스스로 답변을 개선하는 방식으로 최종 다변의 정확도를 높이는 방식이다. 구체적으로는 "Tree of Thought"이나 "Self-Consistency" 같은 변형된 접근법도 상당한 주목을 받고 있는데, 이는 모델이 여러 가능한 추론 경로를 탐색하거나 스스로 답변을 검정하는 방식으로 날날이 발전되고 있는 추세이다. 

  • 파인 튜닝 혹은 데이터 최적화LIMO나 s1 같은 연구에서 확인해볼 수 있듯이, 양은 적지만 질이 높은 데이터를 Fine-Tuning 하면 GPT-4와 같은 대규모 모델의 성능을 능가할 수 있다는 결과가 나오고 있다. 특히 논리적 추론과 이에 대한 과정을 포함한 데이터가 중요하다.
     
  • 강화학습 혹은 Self-Reflection: 강화학습을 활용해 모델이 시행착오를 통해 추론 과정을 개선하는 방법도 꾸준하게 연구되고 있는거 같다. 이에 대해 궁금하다면, OpenAI의 정형원 박사의 "Don't Teach, Incentivize" 이라는 세미나를 참고하길 바란다. 상당한 인사이트를 주는 세미나로 굉장히 추천한다. 또한, 최근에는 모델이 자신의 답변을 되돌아보고 오류를 수정하는 "Self-Reflection" 능력을 극대화하는 시도도 꾸준히 연구되고 있는거 같다. 어찌보면 OpenAI의 o1 모델들의 시리즈가 이러한 방향으로 큰 성과를 남긴거 같다.
     
  • 하이브리드 접근법(Neuro-Symbolic): 신경망과 상징적 추론(Symbolic Reasoning)을 결합한 Neuro-Symbolic 모델도 주목을 받고 있는 추세이다. 신경망 같은 경우, 텍스트를 생성하거나 다음 단어를 예측하는 데 탁월하다. 하지만 이는 모델이 논리적인 규칙을 따르거나 "왜 이렇게 답했는지" 설명하는데 큰 약점이 있다. 반면, Symbolic Reasoning은 논리적으로 명확한 규칙이나 기호를 기반으로 결론을 도출하는 방식이다. 이에 이 둘을 결합한 Neuro-symbolic은 단순하게 패턴을 학습하는 것을 넘어 논리적으로 사고할 수 있는 모델을 만드는 것이다. 

"데이터가 많아야 잘되겠지" 라는 것은 인공지능에 관심이 있는 사람이라면 분명히 한번 쯤은 생각해보았을 것이다. 하지만 언급하였듯이, LIMO 혹은 s1 같은 연구에서 적은 양의 데이터로 엄청난 성과를 낼 수 있다는 것을 보여줬다.

 

그래서 본 게시물에서는 7B 모델을 기반으로 1,000 개의 지도학습 데이터만으로 Reasoning(추론) 능력을 갖추도록 훈련하는 것을 목표로 한다. 이때 GRPO(Group Relative Policy Optimization)와 같은 강화학습이 적용되지 않으며, 추론 시 s1에서 도입되었던 "Budget Forcing" 통해 모델이 답변을 생성하기 전에 "THINKING"을 유도하도록 해볼 것 이다. 


 

Fine-Tuning with Small Amount BUT High-Quality Dataset 

파인튜닝 하기 위해서는 질 높은 데이터 셋이 필요하다. 하지만 이러한 데이터 셋을 구성한다는 것은 비교적 시간이 많이 소요될 뿐만 아니라 제작 과정이 매우 어렵다. 이에 "S1: Simple test-time scaling" 논문에 따르면, 세 가지의 기준으로 샘플을 선정하는 것이 매우 중요하다고 한다. 위와 같은 기준으로 59,029 샘플 데이터를 최종적으로 1,000개를 선정하였다고 한다. (자세한 내용은 논문 참조!) 

  • Quality(품질):  데이터 셋의 well-formatted structure 및 포맷 오류나 모호함이 없어야함. 
  • Difficulty(난이도): 모델들이 질문에 대해 너무나 쉽게 답변 하지 못하는 질문들만 채택함. 
  • Diversity(다양성): 질문들의 다양한 주제에서 선정되도록 해야함

이렇게 구축된 데이터 셋은 허깅페이스에서 찾아볼 수 있다. 

Budget Forcing

세 가지의 기준을 통해 데이터 셋을 구축하는 것도 굉장히 중요하지만, 여기에서 주요 기술 중 하나는 "Budget Forcing"이다. 이는 모델이 답을 도출하기 전에 사용하는 "THINKING TOKEN"의 수를 제어하는 방법이다. 조금 더 쉽게 말하자면, LLM이 답변을 생성하기 전에 얼마나 많은 "생각"을 할지 제어하는 것이라고 생각하면 좋을거 같다. 여기에서 "THINKING TOKEN" 이라는 것을 최대치로 설정하면 모델이 추론을 일찍 끝내도록 제어하고, 최소치로 설정하면 더 오래 생각하게 만들어 성급한 답변을 방지하게 된다. 

 

여기에서 재밌는 것은 "Wait" 이라는 토큰을 추가한다. 해당 토큰은 모델이 추론하는 과정에서 너무 빨리 멈출려고 할때, 답변을 스스로 검토하고 오류를 수정할 기회를 부여하는 것 이다. 이에 대한 직관적인 예시는 아래와 같다: 

"WAIT" 토큰의 예시 (Budget Forcing)

"raspberry" is 2... 라는 답변을 생성하고 멈출려고 하였지만, "WAIT" 이라는 토큰을 추가하여 더욱 더 심층적인 Reasoning을 유도 하도록 한다. 결과적으로 모델이 처음에 "rasberry" 라는 단어에서 "r"의 개수를 못 세었지만, 다시 생각해보라는 신호를 받은 후 올바르게 고치는 것을 의미한다.

Budget Forcing with vLLM (코드 필요하신 분 댓글 부탁드립니다) 

이 부분에 대해서는 주로 "s1: Simple test-time scaling"의 저자들이 공개한 코드를 사용하였으며, LoRA Adapter를 추가하는 방식으로 진행하였다. 주어진 프롬프트에 대해 fine-tuning된 LLM은 "STOP TOKEN"이 생성될 때 까지 토큰들을 생성하게 된다. 그리고 STOP TOKEN이 생성된다면, 이를 "WAIT"으로 대체한 후, 원래 프롬프트와 생성된 토큰들을 결합하여 다시 모델에 입력하고 추론하는 과정을 반복하게 된다. 이러한 과정은 최대 시퀸스 길이에 도달할 때 까지 진행된다. 

 

A100 GPU 한장으로 Qwen2.5-7B with LoRA 모델을 fine-tuning 하는데 3시간 정도 소요되었다. 이에 대한 결과는 다음과 같이 공유하지만, 생각보다 잘되는거 같아서 좋다. 

With budget forcing:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Find all integers n such that n^2 + 3n + 5 is a perfect square. (A perfect square is a number that is the square of an integer, e.g., 0, 1, 4, 9, 16, etc.)<|im_end|>
<|im_start|>assistant
<|im_start|>think
To find all integers \( n \) such that \( n^2 + 3n + 5 \) is a perfect square, we can set \( n^2 + 3n + 5 = k^2 \) for some integer \( k \). This gives us the equation:

\[ n^2 + 3n + 5 = k^2. \]

Rearranging this equation, we get:

\[ n^2 + 3n + 5 - k^2 = 0. \]

This is a quadratic equation in \( n \). To solve for \( n \), we can use the quadratic formula \( n = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \), where \( a = 1 \), \( b = 3 \), and \( c = 5 - k^2 \). Plugging in these values, we get:

\[ n = \frac{-3 \pm \sqrt{9 - 4(1)(5 - k^2)}}{2} = \frac{-3 \pm \sqrt{9 - 20 + 4k^2}}{2} = \frac{-3 \pm \sqrt{4k^2 - 11}}{2}. \]

For \( n \) to be an integer, the expression under the square root, \( 4k^2 - 11 \), must be a perfect square. Let's denote this perfect square by \( m^2 \), so we have:

\[ 4k^2 - 11 = m^2. \]

Rearranging this, we get:

\[ 4k^2 - m^2 = 11. \]

This is a difference of squares, which can be factored as:

\[ (2k - m)(2k + m) = 11. \]

The factor pairs of 11 are \((1, 11)\), \((-1, -11)\), \((11, 1)\), and \((-11, -1)\). We will consider each pair:

1. \( 2k - m = 1 \) and \( 2k + m = 11 \):
   \[ 4k = 12 \implies k = 3, \]
   \[ 2(3) - m = 1 \implies m = 5. \]
   Substituting \( k = 3 \) and \( m = 5 \) into the expression for \( n \):
   \[ n = \frac{-3 \pm 5}{2}, \]
   so \( n = 1 \) or \( n = -4 \).

2. \( 2k - m = -1 \) and \( 2k + m = -11 \):
   \[ 4k = -12 \implies k = -3, \]
   \[ 2(-3) - m = -1 \implies m = -5. \]
   Substituting \( k = -3 \) and \( m = -5 \) into the expression for \( n \):
   \[ n = \frac{-3 \pm (-5)}{2}, \]
   so \( n = -4 \) or \( n = 1 \) (same as the first case).

3. \( 2k - m = 11 \) and \( 2k + m = 1 \):
   \[ 4k = 12 \implies k = 3, \]
   \[ 2(3) - m = 11 \implies m = -5. \]
   Substituting \( k = 3 \) and \( m = -5 \) into the expression for \( n \):
   \[ n = \frac{-3 \pm (-5)}{2}, \]
   so \( n = -4 \) or \( n = 1 \) (same as the first case).

4. \( 2k - m = -11 \) and \( 2k + m = -1 \):
   \[ 4k = -12 \implies k = -3, \]
   \[ 2(-3) - m = -11 \implies m = 5. \]
   Substituting \( k = -3 \) and \( m = 5 \) into the expression for \( n \):
   \[ n = \frac{-3 \pm 5}{2}, \]
   so \( n = 1 \) or \( n = -4 \) (same as the first case).

Thus, the only distinct integer solutions are \( n = 1 \) and \( n = -4 \). Therefore, the integers \( n \) such that \( n^2 + 3n + 5 \) is a perfect square are:

\[ \boxed{1 \text{ and } -4}. \]Wait, let's verify the solutions:

For \( n = 1 \):
\[ n^2 + 3n + 5 = 1^2 + 3(1) + 5 = 1 + 3 + 5 = 9 = 3^2, \]
which is a perfect square.

For \( n = -4 \):
\[ n^2 + 3n + 5 = (-4)^2 + 3(-4) + 5 = 16 - 12 + 5 = 9 = 3^2, \]
which is also a perfect square.

Thus, the solutions are indeed correct. The final answer is:

\[ \boxed{1 \text{ and } -4}. \]

 

해당 output를 자세히 살펴보면, 거의 마지막쯤에 "Wait, let's verify the solution" 토큰이 생성된 것을 볼 수 있다.