Search

Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning

분류
Paper Review
세부 분류
Momentum
게시일
2025/02/13 02:01
발표일
2025/01/23
작성자
작성 완료
작성 완료

Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning

1. 연구 동기 (Motivation & Contribution)

SLM(Small Language Model)의 수학 task 성능을 높이기 위한 방법
Knowledge Distillation: Advanced LM(ex: GPT)의 output을 활용. SLM이 Advanced LM의 behavior를 따라하도록 학습. 그러나 cost가 너무 크다는 단점 존재.
Self-Training: LM이 스스로 생성한 데이터를 이용하여 학습.
논문에서는 Self-Training에 DPO를 더하여, 기존 ST보다 높은 성능을 낼 수 있는 방법을 제시.

2. 배경 소개 (Background & Related Works)

Math word problem solving: xx라는 수학 문제에 대해, yy라는 rationale(풀이)를 뱉도록 함. 응답이 옳은지 틀린지는 풀이의 최종 정답 aa를 통해 비교.
Self-Training (ST)
Labeled data LL로 teacher model을 학습
학습시킨 teacher model로 unlabeled data UU를 annotate → Pseudo-labeled data SS 확보
LSL \cup S로 student model을 학습 → Student model이 teacher model보다 성능이 높아지도록
Direct Preference Optimization (DPO)
기존 RLHF: 선호 데이터로 먼저 reward model을 학습 → Reward를 최대화하는 방향으로 LM을 학습
DPO: 선호 데이터로 LM을 직접 학습
Input xx에 대한 reference model의 응답 y1,y2y_1, y_2를 샘플링
둘 중 옳은 응답을 ywy_w, 틀린 응답을 yly_l로 하여 선호 데이터셋 구축
Target model이 input xx에 대해 ywy_w를 잘 뽑도록, yly_l은 잘 못 뽑도록 DPO 학습

3. 방법론 (Methodology)

DPO-ST
1.
LL로 supervised fine-tuning
수학 task에 대한 labeled dataset LLfθf_\theta를 sft 하여 fθf_{\theta^\prime}을 얻음
2.
선호 데이터셋 구축
Unlabeled dataset UU의 input에 대한 fθf_{\theta^\prime}의 응답을 샘플링
해당 응답들의 최종 정답을 ground-truth 답과 비교하여 선호 데이터셋 구축 (답이 맞을 경우 ywy_w, 틀릴 경우 yly_l)
Unlabeled dataset인데 ground-truth 답은 어떻게 아는지? → 뒤에서 설명
3.
DPO 학습
위에서 구축한 선호 데이터셋으로 fθf_{\theta^\prime}을 학습하여 fθdf_{\theta^d}를 얻음
4.
UU를 pseudo-labeling하여 SS 확보
UU의 input에 대한 fθdf_{\theta^d}의 응답을 샘플링 → 원래 question만 있고 그에 대한 rationale은 없었던 UU에, 샘플링한 응답을 추가하는 pseudo-labeling을 함
5.
LL을 augment
SS 중에서, 위에서 샘플링한 응답의 최종 정답이 실제 ground-truth 답과 맞는 데이터만을 골라 LL에 더해줌
1.
LL로 supervised fine-tuning
위의 과정을 반복한다.
핵심 아이디어: 기존 ST에서는 sft를 한 모델을 사용하여 LL을 pseudo-labeling 하였는데, DPO-ST에서는 해당 모델을 한 번 더 DPO 학습 시킨 후에 LL을 pseudo-labeling 한다 → 더 나은 pseudo-label을 확보하겠다는 아이디어

4. 실험 결과 및 분석 (Experiments & Results)

모델
Flan-T5-Base (250M) & Large (780M)
Llama-1-7B & 2-7B & 3-8B
데이터셋
Training: GSM8K로부터 LLUU를 모두 얻음. Unlabeled dataset이라고는 하지만, 실제로는 GSM8K에서 답 부분을 지워서 UU를 확보한 것. 그렇기 때문에 DPO-ST의 2단계와 4단계에서 실제 ground-truth 답과 비교할 수 있었던 것.
Test: GSM8K, MultiArith, ASDiv, SVAMP
일반 SFT & ST보다 좋은 성능
사이즈가 더 큰 모델(Flan-T5-Large)에서, iteration 진행에 따른 성능 향상이 더 큼 → Large model에서 더 큰 성능 향상을 노릴 수 있다는 potential 제시
기존 방식들과 비교했을 때, annotation 과정에서 closed model(GPT)을 사용하지 않았다는 장점 + 성능 높음
Closed model에 비하면 성능은 당연히 낮음. 하지만, knowledge distillation을 사용하지 않은 open model에 비해서 성능이 훨씬 높음.
Knowledge distilation을 사용한 open model보다는 성능이 낮지만, DPO-ST는 cost가 훨씬 낮다는 장점 존재.

5. 결론 (Conclusion)

ST의 중간 과정에 DPO를 추가함으로써, 더 나은 pseudo-labeled data를 확보하는 DPO-ST
Knowledge distillation 사용 없이 낮은 cost로 SLM 학습
한계점
Unlabeled data를 사용했다고는 하지만, 실제로는 답이 있는 데이터를 사용하였음 → 실제 unlabeled data는 활용하기 힘듦
수학 task에 한정됨 → 최종 정답을 실제 gt 답과 비교하는 방식을 통해 선호 데이터 yw,yly_w, y_l을 가리는데, 이는 수학 task에서만 가능한 방법
Pseudo-labeled data가 부정확할 수 있음 → 최종 정답은 맞지만 풀이는 틀리는 경우는 걸러낼 수 없음

DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving

1. 연구 동기 (Motivation & Contribution)

Data augmentation을 통해 수학 task 성능을 향상
기존 방식: Rejection Tuning
데이터셋에 (x,y)(x, y) pair가 있을 때, xx에 대한 여러 개의 응답을 sampling → 해당 응답의 최종 답을 gt 답과 비교하여, 맞는 것만 남김 → (x,y),(x,y1),(x,y2),...(x, y), (x, y_1), (x, y_2), ... 이런 식으로 answer augmentation
하지만 기존 방식에는 쿼리 난이도 측면에서의 bias 발생
쉬운 쿼리를 100번 샘플링하면, 최종 답이 gt 답과 맞는 correct 응답이 90개 정도 나옴 → 90개의 augmented answer 확보
어려운 쿼리를 100번 샘플링하면, correct 응답이 20개 정도 나옴 → 20개의 augmented answer 확보
Augment 이후 쉬운 쿼리의 비율이 훨씬 증가 → Bias 발생
논문에서는 이를 해결하기 위한 DART (Difficulty-Aware Rejection Tuning) 제시
Uniform: 쉬운 쿼리든 어려운 쿼리는 똑같이 kk개의 correct 응답을 확보
Prop2Diff: 어려운 쿼리면 더 많은 correct 응답을 확보
쿼리의 난이도는, 여러 번 샘플링 했을 때의 틀린 응답의 비율로 계산 → 틀린 응답이 더 많이 샘플링 될수록 더 어려운 쿼리로 판정
샘플링에 open model 사용

2. 배경 소개 (Background & Related Works)

Rejection Sampling
위에서 설명한대로, 데이터셋의 쿼리에 대해 여러 개의 응답을 샘플링 하고, 최종 답이 gt 답과 일치하는 것들만을 남김
Rejection Tuning
Rejection Sampling으로 augment한 데이터를 사용해 모델을 fine-tuning

3. 방법론 (Methodology)

Uniform
각 쿼리에 대해 kuk_u개의 correct 응답을 수집
Prop2Diff
어려운 쿼리일수록 더 많은(최대 kpk_p개) correct 응답을 수집
두 방식을 수학 데이터셋 GSM8K, MATH에 적용하여 augmented dataset 확보
샘플링 최대 횟수에 제한을 둠 → 어려운 쿼리면 샘플링을 아주 많이 해도 kk개의 correct 응답을 얻지 못할 수 있음 → 너무 많이 샘플링하는 경우를 방지
샘플링에는 open model인 DeepSeekMath-7B 사용

4. 실험 결과 및 분석 (Experiments & Results)

모델
Llama-3-8B & 70B
Mistral-7B
DeepSeekMath-7B-RL
데이터셋
Training: GSM8K, MATH
Test: GSM8K, MATH, CollegeMath, DeepMind-Mathematics, OlympiadBench-Math, TheoremQA
GSM8K에서는 큰 향상 X → 너무 쉬운 데이터셋이라, 난이도 bias가 크지 않음 (다 쉬운 쿼리로 구성) → DART를 통한 성능 향상 감소
DeepSeekMath-7B에서는 성능 향상이 적음 → 이미 수학 데이터로 많이 학습된 수학 특화 모델이라, answer augmentation으로는 큰 향상을 노리기 힘듦 → Query augmentation이 필요
학습 데이터 사이즈 증가에 따른 성능 향상 → Mistral과 Llama에서, 기존 방식(Vanilla Rejection Tuning)보다 꾸준히 성능 높음 → Better scaling
상술한대로, DeepSeekMath-7B에서는 큰 향상 X

5. 결론 (Conclusion)

기존 answer augmentation에서의 난이도 측면에서 발생하는 bias를 해결
어려운 쿼리일수록 더 많이 샘플링하여, 최종적으로 쉬운 쿼리와 어려운 쿼리의 비율을 유지 or 어려운 쿼리의 비율을 늘림
어려운 쿼리는 학습에 더 도움됨 → Efficient dataset 확보
한계점
수학 task에 한정되는 method → Rejection Tuning을 할 때, 샘플링한 응답의 최종 답과 gt 답을 비교하는 방식을 통해 필터링 여부를 결정 → 다른 task에서는 활용하기 힘듦
쿼리의 난이도를 측정할 때 fail rate(샘플링 시 틀린 응답의 비율)을 사용 → Sub-optimal한 방식
샘플링 횟수가 증가함 → Cost 증가

참고자료

Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning, Wang et al., ACL 2024
Direct Preference Optimization: Your Language Model is Secretly a Reward Model, Rafailov et al., NeurIPS 2023
DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving, Tong et al., NeurIPS 2024
MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models, Yu et al., ICLR 2024