Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning
1. 연구 동기 (Motivation & Contribution)
•
SLM(Small Language Model)의 수학 task 성능을 높이기 위한 방법
◦
Knowledge Distillation: Advanced LM(ex: GPT)의 output을 활용. SLM이 Advanced LM의 behavior를 따라하도록 학습. 그러나 cost가 너무 크다는 단점 존재.
◦
Self-Training: LM이 스스로 생성한 데이터를 이용하여 학습.
•
논문에서는 Self-Training에 DPO를 더하여, 기존 ST보다 높은 성능을 낼 수 있는 방법을 제시.
2. 배경 소개 (Background & Related Works)
•
Math word problem solving: 라는 수학 문제에 대해, 라는 rationale(풀이)를 뱉도록 함. 응답이 옳은지 틀린지는 풀이의 최종 정답 를 통해 비교.
•
Self-Training (ST)
◦
Labeled data 로 teacher model을 학습
◦
학습시킨 teacher model로 unlabeled data 를 annotate → Pseudo-labeled data 확보
◦
로 student model을 학습 → Student model이 teacher model보다 성능이 높아지도록
•
Direct Preference Optimization (DPO)
◦
기존 RLHF: 선호 데이터로 먼저 reward model을 학습 → Reward를 최대화하는 방향으로 LM을 학습
◦
DPO: 선호 데이터로 LM을 직접 학습
▪
Input 에 대한 reference model의 응답 를 샘플링
▪
둘 중 옳은 응답을 , 틀린 응답을 로 하여 선호 데이터셋 구축
▪
Target model이 input 에 대해 를 잘 뽑도록, 은 잘 못 뽑도록 DPO 학습
3. 방법론 (Methodology)
•
DPO-ST
1.
로 supervised fine-tuning
•
수학 task에 대한 labeled dataset 로 를 sft 하여 을 얻음
2.
선호 데이터셋 구축
•
Unlabeled dataset 의 input에 대한 의 응답을 샘플링
•
해당 응답들의 최종 정답을 ground-truth 답과 비교하여 선호 데이터셋 구축 (답이 맞을 경우 , 틀릴 경우 )
•
Unlabeled dataset인데 ground-truth 답은 어떻게 아는지? → 뒤에서 설명
3.
DPO 학습
•
위에서 구축한 선호 데이터셋으로 을 학습하여 를 얻음
4.
를 pseudo-labeling하여 확보
•
의 input에 대한 의 응답을 샘플링 → 원래 question만 있고 그에 대한 rationale은 없었던 에, 샘플링한 응답을 추가하는 pseudo-labeling을 함
5.
을 augment
•
중에서, 위에서 샘플링한 응답의 최종 정답이 실제 ground-truth 답과 맞는 데이터만을 골라 에 더해줌
1.
로 supervised fine-tuning
…
위의 과정을 반복한다.
•
핵심 아이디어: 기존 ST에서는 sft를 한 모델을 사용하여 을 pseudo-labeling 하였는데, DPO-ST에서는 해당 모델을 한 번 더 DPO 학습 시킨 후에 을 pseudo-labeling 한다 → 더 나은 pseudo-label을 확보하겠다는 아이디어
4. 실험 결과 및 분석 (Experiments & Results)
•
모델
◦
Flan-T5-Base (250M) & Large (780M)
◦
Llama-1-7B & 2-7B & 3-8B
•
데이터셋
◦
Training: GSM8K로부터 과 를 모두 얻음. Unlabeled dataset이라고는 하지만, 실제로는 GSM8K에서 답 부분을 지워서 를 확보한 것. 그렇기 때문에 DPO-ST의 2단계와 4단계에서 실제 ground-truth 답과 비교할 수 있었던 것.
◦
Test: GSM8K, MultiArith, ASDiv, SVAMP
•
일반 SFT & ST보다 좋은 성능
•
사이즈가 더 큰 모델(Flan-T5-Large)에서, iteration 진행에 따른 성능 향상이 더 큼 → Large model에서 더 큰 성능 향상을 노릴 수 있다는 potential 제시
•
기존 방식들과 비교했을 때, annotation 과정에서 closed model(GPT)을 사용하지 않았다는 장점 + 성능 높음
•
Closed model에 비하면 성능은 당연히 낮음. 하지만, knowledge distillation을 사용하지 않은 open model에 비해서 성능이 훨씬 높음.
•
Knowledge distilation을 사용한 open model보다는 성능이 낮지만, DPO-ST는 cost가 훨씬 낮다는 장점 존재.
5. 결론 (Conclusion)
•
ST의 중간 과정에 DPO를 추가함으로써, 더 나은 pseudo-labeled data를 확보하는 DPO-ST
•
Knowledge distillation 사용 없이 낮은 cost로 SLM 학습
•
한계점
◦
Unlabeled data를 사용했다고는 하지만, 실제로는 답이 있는 데이터를 사용하였음 → 실제 unlabeled data는 활용하기 힘듦
◦
수학 task에 한정됨 → 최종 정답을 실제 gt 답과 비교하는 방식을 통해 선호 데이터 을 가리는데, 이는 수학 task에서만 가능한 방법
◦
Pseudo-labeled data가 부정확할 수 있음 → 최종 정답은 맞지만 풀이는 틀리는 경우는 걸러낼 수 없음
DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving
1. 연구 동기 (Motivation & Contribution)
•
Data augmentation을 통해 수학 task 성능을 향상
•
기존 방식: Rejection Tuning
◦
데이터셋에 pair가 있을 때, 에 대한 여러 개의 응답을 sampling → 해당 응답의 최종 답을 gt 답과 비교하여, 맞는 것만 남김 → 이런 식으로 answer augmentation
•
하지만 기존 방식에는 쿼리 난이도 측면에서의 bias 발생
◦
쉬운 쿼리를 100번 샘플링하면, 최종 답이 gt 답과 맞는 correct 응답이 90개 정도 나옴 → 90개의 augmented answer 확보
◦
어려운 쿼리를 100번 샘플링하면, correct 응답이 20개 정도 나옴 → 20개의 augmented answer 확보
◦
Augment 이후 쉬운 쿼리의 비율이 훨씬 증가 → Bias 발생
•
논문에서는 이를 해결하기 위한 DART (Difficulty-Aware Rejection Tuning) 제시
◦
Uniform: 쉬운 쿼리든 어려운 쿼리는 똑같이 개의 correct 응답을 확보
◦
Prop2Diff: 어려운 쿼리면 더 많은 correct 응답을 확보
•
쿼리의 난이도는, 여러 번 샘플링 했을 때의 틀린 응답의 비율로 계산 → 틀린 응답이 더 많이 샘플링 될수록 더 어려운 쿼리로 판정
•
샘플링에 open model 사용
2. 배경 소개 (Background & Related Works)
•
Rejection Sampling
◦
위에서 설명한대로, 데이터셋의 쿼리에 대해 여러 개의 응답을 샘플링 하고, 최종 답이 gt 답과 일치하는 것들만을 남김
•
Rejection Tuning
◦
Rejection Sampling으로 augment한 데이터를 사용해 모델을 fine-tuning
3. 방법론 (Methodology)
•
Uniform
◦
각 쿼리에 대해 개의 correct 응답을 수집
•
Prop2Diff
◦
어려운 쿼리일수록 더 많은(최대 개) correct 응답을 수집
•
두 방식을 수학 데이터셋 GSM8K, MATH에 적용하여 augmented dataset 확보
•
샘플링 최대 횟수에 제한을 둠 → 어려운 쿼리면 샘플링을 아주 많이 해도 개의 correct 응답을 얻지 못할 수 있음 → 너무 많이 샘플링하는 경우를 방지
•
샘플링에는 open model인 DeepSeekMath-7B 사용
4. 실험 결과 및 분석 (Experiments & Results)
•
모델
◦
Llama-3-8B & 70B
◦
Mistral-7B
◦
DeepSeekMath-7B-RL
•
데이터셋
◦
Training: GSM8K, MATH
◦
Test: GSM8K, MATH, CollegeMath, DeepMind-Mathematics, OlympiadBench-Math, TheoremQA
•
GSM8K에서는 큰 향상 X → 너무 쉬운 데이터셋이라, 난이도 bias가 크지 않음 (다 쉬운 쿼리로 구성) → DART를 통한 성능 향상 감소
•
DeepSeekMath-7B에서는 성능 향상이 적음 → 이미 수학 데이터로 많이 학습된 수학 특화 모델이라, answer augmentation으로는 큰 향상을 노리기 힘듦 → Query augmentation이 필요
•
학습 데이터 사이즈 증가에 따른 성능 향상 → Mistral과 Llama에서, 기존 방식(Vanilla Rejection Tuning)보다 꾸준히 성능 높음 → Better scaling
•
상술한대로, DeepSeekMath-7B에서는 큰 향상 X
5. 결론 (Conclusion)
•
기존 answer augmentation에서의 난이도 측면에서 발생하는 bias를 해결
•
어려운 쿼리일수록 더 많이 샘플링하여, 최종적으로 쉬운 쿼리와 어려운 쿼리의 비율을 유지 or 어려운 쿼리의 비율을 늘림
•
어려운 쿼리는 학습에 더 도움됨 → Efficient dataset 확보
•
한계점
◦
수학 task에 한정되는 method → Rejection Tuning을 할 때, 샘플링한 응답의 최종 답과 gt 답을 비교하는 방식을 통해 필터링 여부를 결정 → 다른 task에서는 활용하기 힘듦
◦
쿼리의 난이도를 측정할 때 fail rate(샘플링 시 틀린 응답의 비율)을 사용 → Sub-optimal한 방식
◦
샘플링 횟수가 증가함 → Cost 증가
참고자료
•
Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning, Wang et al., ACL 2024
•
Direct Preference Optimization: Your Language Model is Secretly a Reward Model, Rafailov et al., NeurIPS 2023
•
DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving, Tong et al., NeurIPS 2024
•
MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models, Yu et al., ICLR 2024