이번 글에서는 LLM에 적용된 RLHF를 대체할 수 있는 DPO 논문리뷰를 진행하겠습니다.
DPO 소식(?을 들은지 3개월 정도됐는데 2023.12.11 기준 벌써 Google Scholar에 인용이 100회에 다다랐습니다.
한국지능정보사회진흥원 NIA와 AI-Hub가 공동주최하고 AI 스타트업 Upstage에서 관리하는 한국어 LLM 리더보드에서도
현재(2023.12.11 기준) 상위 랭커들은 DPO 알고리즘(변형 ver.)을 사용하고 있습니다.
과연 DPO 알고리즘은 무엇이고 어떻게 좋은 성능을 내는지 작성해보도록 하겠습니다.
✅ Key Idea
1. RLHF에서 사용되었던 PPO 목적함수를 조작(? 해서 별도의 Reward Model 학습과정 없이 LM policy를 Human Preference 데이터셋으로 직접 지도학습하여 최적화하는 방식
2. RLHF보다 간단하면서 시간 및 메모리를 최소 2배 이상 절약할 수 있다.
Related Work & Preliminaries
- GPT-1,2,3에 이르면서 대규모 언어모델들은 zero-shot으로 downstream task 해결능력을 보여주는 등 놀라운 능력을 보여주었지만 사실을 지어내거나, 편향적이면서 유해한 텍스트를 생성하거나 사용자의 지시를 따르지 않는 한계점이 있었습니다.
- 모델이 대규모 데이터에서 학습한 지식과 능력을 바탕으로 인간이 원하는 의도에 맞게 대답하도록 하기 위해서 강화학습(RL)을 사용해서 human preference에 맞게 모델을 조정하는 방식(RLHF)을 사용했습니다.
- RLHF 파이프라인은 크게 세 단계로 구성됩니다.
- Supervised fine-tuning(SFT)
- perference sampling and reward learning
- reinforcement-learning optimization
1. SFT phase
- 이 단계에서는 원하는 downstream task(e.g. 대화, insturction-following, 요약)에 맞는 추가 데이터셋으로 사전학습된 LM을 fine-tuning 하는 과정입니다.
- 일반적으로 많이 알고있는 `alpaca 데이터셋`으로 fine-tuning 하는 과정을 이 단계라고 생각하면 됩니다.
- 예를들면 1에서 instruction-following task에 대해서 fine-tuning을 했다고 하면
- 우리가 1에서 fine-tuning에 사용한 데이터셋은 아래와 같은 형태일겁니다.
{
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
### Response: 세 가지 팁은 아침식사를 꼭 챙기며, 충분한 수면을 취하고, 적극적으로 운동을 하는 것입니다.
},
{
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 반지름이 주어진 원의 넓이를 구합니다.
### Input: 반지름=4
### Response: 입력된 반지름 값(4)을 이용하여 원의 넓이는 16pi입니다.
}
- 이 과정이 끝나고나면 특정 donwstream task에 특화된 SFT 모델이 생성됩니다.
2. RM(Reward Modelling) phase
- 그 다음 단계는 Reward Modelling(이하 RM) 단계입니다.
- 이 단계는 프롬프트에 대해서 샘플링한 답변들에 대한 labeler의 평가데이터(comparison data라고 함)를 RM에게 학습시키는 단계입니다.
- 구체적으로 살펴보면
- 1에서 만든 SFT 모델에 대해서 동일한 프롬프트에 대해서 4~9개의 다른 결과를 추출합니다.
- 예를들면 아래와 같은 프롬프트에 대해서
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
- 다음과 같은 4가지 서로 다른 답변을 추출(샘플링)하는 것입니다.
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
### Response: 1. 규칙적인 운동은 신체 건강을 증진시키고 스트레스를 완화합니다.
2. 균형 잡힌 식단으로 영양소를 다양하게 섭취하세요.
3. 충분한 수면을 취하여 신체와 정신적인 회복을 도모하세요.
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
### Response: 1.규칙적인 운동으로 신체 활동 증가, 2. 균형 잡힌 식단으로 영양 공급, 3. 충분한 수면으로 신체 회복과 스트레스 관리.
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
### Response: 1. 규칙적인 운동은 필수, 2. 균형 잡힌 식단 유지, 3. 충분한 수면으로 신체 회복.
아래는 task를 설명하는 지침입니다.
요청을 적절하게 완료하는 응답을 작성합니다.
### Instruction: 건강을 유지하기 위한 세 가지 팁을 알려주세요.
### Response: 1. 생선과 견과류 섭취로 심혈관 건강 증진, 2. 미네랄 풍부한 녹황색 채소 섭취, 3. 정기적인 명상으로 스트레스 관리.
- 이렇게 뽑아낸 4개의 답변에 대해서 2개씩 쌍을 만들어서 human labeler가 선호의 정도를 평가합니다.
- 2개씩 쌍을 만들어서 평가하는 이유는 이 선호도를 모델링하기 위해서 Bradley-Terry(이하 BT) 모델을 사용했기 때문입니다.
- 여러 순위가 매겨진 답변으로 모델링하는 경우 Plackett-Luce 모델을 사용할 수도 있습니다.
- BT 모델은 다양한 선택 중에서 어떤 대상이 다른 대상에 비해 상대적으로 선호되는지를 모델링하는데 사용되는 통계적 모델 중 하나입니다.
- 기본수식은 다음과 같습니다. $i$와 $j$는 비교하는 두 대사을 나타내며 $P(i > j)$는 대상 $i$가 대상 $j$보다 선호될 확률을 나타냅니다.
- 이 모델을 우리의 task에 적용하면 텍스트 프롬프트 $x$가 주어졌을 때 생성된 문장 $y_1, y_2$에 대한 선호도를 비교하는 것이므로 다음과 같이 BT-model 식을 작성할 수 있습니다.
- $p^*$의 확률로 답변들이 샘플링 되었다고 가정할 때 maximum likelihood를 통해서 RM의 매개변수를 추정할 수 있습니다. 따라서 RM $r_\phi(x, y)$은 다음과 같은 손실함수를 가지고 학습됩니다.
- 이 때 학습되는 RM은 앞의 SFT 모델과는 별도의 모델로 reward는 단일 스칼라 값이어야 하므로 최종 레이어에 선형레이어를 추가해서 앞 단계의 SFT 모델 파라미터로 초기화한 후 위의 목적함수로 학습이 진행됩니다.
3. RL Fine-Tuning phase
- 마지막 단계는 앞의 두 단계에서 학습한 SFT 모델과 RM 모델을 이용해서 강화학습 알고리즘으로 fine-tuning 하는 과정을 나타냅니다.
- 위의 그림의 아래 부분 (Policy training)에 해당하는 부분입니다.
- RM의 답변 평가 결과를 바탕으로 Policy(여기서는 SFT 모델의 파라미터가 될 것)가 업데이트 되는 것입니다.
- 해당 최적화 과정은 다음과 같은 식으로 나타낼 수 있습니다.
- 식의 노란색 부분은 RM을 통해 받는 Reward를 나타내고
- 파란색 부분은 KL-penalty term 으로 Policy의 급격한 변화를 방지하기 위해서 파라미터의 변동폭을 제한해주는 penalty term 입니다.
- 이 말의 의미는 초기 policy 즉, 우리가 모델을 SFT의 가중치 값으로 초기화했기 때문에 SFT Policy에서 RM의 결과에 따라서 해당 policy가 업데이트가 될텐데 업데이트가 될 때 업데이트 되는 policy가 기존의 SFT policy로부터 너무 멀어지지 않도록 하는 것입니다.
- 이렇게 함으로써 단순히 reward를 높이기 위해 잘못된 방향으로 policy가 업데이트 되는 것을 방지하게 됩니다.
PPO (Proximal Policy Optimization)
- 흔히들 많이 알고계시는 방법이 InstructGPT에서 사용한 RLHF 방법론일텐데요.
- 일반적으로 RLHF approach에서는 학습된 RM의 value function으로 하여금 PPO 알고리즘을 통해 Policy 최적화를 진행하게 됩니다.
- PPO 알고리즘에 대해서 간단하게만 짚고 넘어가자면 PPO 알고리즘은 강화학습의 알고리즘 중 하나입니다.
- 강화학습(Reinforcement Learning)은 agent가 environment(환경)과 상호작용하면서 주어진 state(상태)$s_t$에서 누적 보상 합(return)을 최대로 하는 action(행동)$a_t$를 선택하는 최적의 policy(정책)을 스스로 찾아가는 학습방법입니다.
- 그리고 그러한 알고리즘이 여러가지가 있지만 그 중 PPO가 하나인 것입니다.
- InstructGPT를 위해서 만들어진 알고리즘은 아니고 원래 있던 알고리즘인데 InstructGPT에서 해당 PPO 알고리즘을 LLM에 결합시켰습니다.
- PPO의 특징이라고 한다면 Policy Gradient Method라는 건데요.
- Q-learning 계열의 value-based와는 대조적으로 가치함수를 사용하지 않고 policy를 직접 업데이트 하는 방식의 알고리즘이라는 것입니다.
- 그리고 이러한 Policy Gradient Method의 알고리즘 중 하나인 TRPO(Trust Region Policy Optimization)에서 더 심플한 버전으로 발전된 알고리즘이 PPO입니다.
- PPO는 agent(LLM)가 policy(text 시퀀스 생성)을 RM(human preference에 기반한 reward 체계)을 기반으로 업데이트 하는 방법중 하나라는 정도로만 알고가도 괜찮을 거 같습니다.
- PPO에 대한 자세한 설명은 추후에 게시글로 업로드 하도록 하겠습니다.
- 지금까지 RLHF를 이용해서 LM을 fine-tuning 하는 과정을 살펴보았습니다.
- 처음에 설명했듯이 사람들이 선호하는 답변을 생성하기 위해 이러한 과정을 도입했는데요.
- 이러한 방식을 적용했더니 적용하지 않은 GPT-3 답변보다 적용한 InstructGPT의 답변이 인간의 선호도에 더 걸맞는 것을 위의 그림에서 확인할 수 있습니다.
Direct Preference Optimization (DPO)
Key Idea
- RLHF를 이용해서 LM을 fine-tuning 하는 과정은 사람들이 선호하는 답변을 생성할 수 있게 되었지만 efficient하지는 않습니다.
- 왜냐하면 SFT 모델도 만들고, RM 모델도 만들고, 이를 바탕으로 PPO 알고리즘을 적용시켜서 또 다시 학습시켜야 하기 때문입니다.
- 이러한 복잡한 과정을 줄인 것이 DPO 알고리즘이라고 할 수 있습니다.
- Reward를 학습하고 이를 바탕으로 RL을 통해서 모델을 최적화하는 기존의 RLHF의 방식과 다르게
- DPO는 Reward 모델링 단계를 우회해서 모델을 직접적으로 최적화합니다.
- 아래에서 자세하게 식과 함께 설명하겠지만
- 저자들은 Reward func.과 Optimal Policy 사이의 관계를 분석을 통해 찾아내서 Reward func.의 loss를 Policy의 loss func.에 통합시켰습니다.
- 마치 비유를 하자면
- $ z = 2x + 3y $를 계산하기 위해서 $ 2x $와 $ 3y $를 각각 계산(1)하고 두 값을 더해야(2)했다면
- $ y = x + 2 $라는 관계를 찾아서 단지 $ z = 2x + 3(x+2) = 5x + 6 $ 만 계산해서 우리가 원하는 값을 찾을 수 있게 된 것이라고 생각할 수 있을 거 같습니다.
- 이렇게 change-of-variables 전략을 통해서 Reward를 모델링하는 explicit한 과정을 생략하고 human preference 데이터로 directly LM을 최적화할 수 있게 되었습니다.
- 결국 policy network 자체가 Reward Modeling 단계를 내포하고 있다는 것으로 이해할 수 있습니다. (implicit)
Deriving the DPO objective
- 본격적으로 DPO의 objective를 수식과 함께 유도해보겠습니다.
- DPO의 수식을 유도하기 위해서 위에서 언급했던 RL fine-tuning objective func.에서 시작합니다.
- 위의 길고 험난한(? 과정을 통해서 최종적으로 DPO의 objective를 구했습니다.
- 이렇게 하므로써 explicit한 reward modeling 단계를 skip하고 LM을 direct optimization할 수 있는 objective를 얻은 것입니다.
- 식 유도 과정을 살펴보면서 key idea 라고 생각하는 부분을 파란색 박스로 표시하였습니다.
- optimal solution에서 $r(x,y)$로 reparameterization 한 부분과 BT(Bradley-Terry)를 적용한 부분이 제가 생각했을 때 DPO 식을 유도하는 Main Point 라고 생각합니다.
What does the DPO update do?
- 그렇다면 우리가 유도한 objective가 과연 우리의 의도에 맞게 동작하는지 살펴보도록 하겠습니다.
- DPO 동작 메커니즘을 이해하기 위해서 DPO objective의 $\theta$에 대한gradient 값을 계산해봅니다.
- Likelihood term, Implicit Reward term은 지칭상의 편의를 위해서 임의로 설정한 이름입니다.
- $\hat{r_{\theta}}(x, y)$ 부분을 보았을 때 reward가 policy $\pi_{\theta}$와 $\pi_{ref}$에 의해 implicit하게 포함된 것을 확인할 수 있으며
- $L_DPO$의 기울기는 선호하는 문장($y_w$)이 생성될 가능성은 높이고, 선호하지 않는 문장($y_l$)이 생성될 가능성은 낮추는 방향으로 업데이트가 진행된다는 것도 확인할 수 있습니다.
- 또한 앞쪽의 Implicit Reward term 부분은 선호하지 않는 문장을 얼마나 더 높게 평가하는지에 따라서 가중치를 매기는 것으로 해석할 수 있습니다.
- 저자들은 Implicit Reward term이 중요하다고 주장하고 있는데요. 이 부분(Implicit Reward term)이 없을 경우, 선호하는 문장에 대한 확률만 극대화되어서 LM의 성능이 degenerate 된다고 합니다.
- 아래의 그림과 같이 한 단어 또는 문구만 반복하는 것을 확인할 수 있습니다.
- 이 부분이 원래 RLHF의 목적함수에서 KL penalty 역할을 한다고 생각할 수도 있을 거 같습니다.
DPO outline
- 지금까지의 DPO의 파이프라인을 정리해보자면
- 모든 프롬프트 $x$에 대해서 답변 $y_1, y_2$를 샘플링해서 선호도를 평가해서 선호도 데이터셋 $\mathcal{D}= \{x^{(i)}, y_w^{(i)}, y_l^{(i)}\}_{i=1}^{N}$을 준비하고
- DPO의 loss를 최소화하는 방식으로 LM을 optimize 하면됩니다.
- DPO loss의 Pytorch Code는 다음과 같습니다.
Theoretical Analysis of DPO
- 추후 필요시 추가하도록 하겠습니다.
Experiments
- 추후 필요시 추가하도록 하겠습니다.
🤗 Review
- DPO를 소개받았을 때, 간단(? 한 수식으로 2 step을 1 step으로 압축(? 한 부분이 매우 인상깊었습니다.
- 논문을 읽어보면서 추가로 수식을 더 분석하고 다른 알고리즘을 적용한다면 DPO를 더 발전시킬 수 있을 거 같다는 느낌을 받았습니다.
- 그래서 .. 저에게 DPO 알고리즘의 개선점을 도출하고 알고리즘을 수정하라는 새로운 과제가 주어졌습니다.
- 과제를 수행하는 과정도 추후에 업로드 하도록 하겠습니다!
긴 글 읽어주셔서 감사합니다.
언제든 잘못된 부분에 대한 피드백은 환영입니다 :)