1. 연구 동기 (Motivation & Contribution)
•
Data augmentation을 통해 수학 task 성능을 향상
•
기존 방식: Rejection Tuning
◦
데이터셋에 pair가 있을 때, 에 대한 여러 개의 응답을 sampling → 해당 응답의 최종 답을 gt 답과 비교하여, 맞는 것만 남김 → 이런 식으로 answer augmentation
•
하지만 기존 방식에는 쿼리 난이도 측면에서의 bias 발생
◦
쉬운 쿼리를 100번 샘플링하면, 최종 답이 gt 답과 맞는 correct 응답이 90개 정도 나옴 → 90개의 augmented answer 확보
◦
어려운 쿼리를 100번 샘플링하면, correct 응답이 20개 정도 나옴 → 20개의 augmented answer 확보
◦
Augment 이후 쉬운 쿼리의 비율이 훨씬 증가 → Bias 발생
•
논문에서는 이를 해결하기 위한 DART (Difficulty-Aware Rejection Tuning) 제시
◦
Uniform: 쉬운 쿼리든 어려운 쿼리는 똑같이 개의 correct 응답을 확보
◦
Prop2Diff: 어려운 쿼리면 더 많은 correct 응답을 확보
•
쿼리의 난이도는, 여러 번 샘플링 했을 때의 틀린 응답의 비율로 계산 → 틀린 응답이 더 많이 샘플링 될수록 더 어려운 쿼리로 판정
•
샘플링에 open model 사용
2. 배경 소개 (Background & Related Works)
•
Rejection Sampling
◦
위에서 설명한대로, 데이터셋의 쿼리에 대해 여러 개의 응답을 샘플링 하고, 최종 답이 gt 답과 일치하는 것들만을 남김
•
Rejection Tuning
◦
Rejection Sampling으로 augment한 데이터를 사용해 모델을 fine-tuning
3. 방법론 (Methodology)
•
Uniform
◦
각 쿼리에 대해 개의 correct 응답을 수집
•
Prop2Diff
◦
어려운 쿼리일수록 더 많은(최대 개) correct 응답을 수집
•
두 방식을 수학 데이터셋 GSM8K, MATH에 적용하여 augmented dataset 확보
•
샘플링 최대 횟수에 제한을 둠 → 어려운 쿼리면 샘플링을 아주 많이 해도 개의 correct 응답을 얻지 못할 수 있음 → 너무 많이 샘플링하는 경우를 방지
•
샘플링에는 open model인 DeepSeekMath-7B 사용
4. 실험 결과 및 분석 (Experiments & Results)
•
모델
◦
Llama-3-8B & 70B
◦
Mistral-7B
◦
DeepSeekMath-7B-RL
•
데이터셋
◦
Training: GSM8K, MATH
◦
Test: GSM8K, MATH, CollegeMath, DeepMind-Mathematics, OlympiadBench-Math, TheoremQA
•
GSM8K에서는 큰 향상 X → 너무 쉬운 데이터셋이라, 난이도 bias가 크지 않음 (다 쉬운 쿼리로 구성) → DART를 통한 성능 향상 감소
•
DeepSeekMath-7B에서는 성능 향상이 적음 → 이미 수학 데이터로 많이 학습된 수학 특화 모델이라, answer augmentation으로는 큰 향상을 노리기 힘듦 → Query augmentation이 필요
•
학습 데이터 사이즈 증가에 따른 성능 향상 → Mistral과 Llama에서, 기존 방식(Vanilla Rejection Tuning)보다 꾸준히 성능 높음 → Better scaling
•
상술한대로, DeepSeekMath-7B에서는 큰 향상 X
5. 결론 (Conclusion)
•
기존 answer augmentation에서의 난이도 측면에서 발생하는 bias를 해결
•
어려운 쿼리일수록 더 많이 샘플링하여, 최종적으로 쉬운 쿼리와 어려운 쿼리의 비율을 유지 or 어려운 쿼리의 비율을 늘림
•
어려운 쿼리는 학습에 더 도움됨 → Efficient dataset 확보
•
한계점
◦
수학 task에 한정되는 method → Rejection Tuning을 할 때, 샘플링한 응답의 최종 답과 gt 답을 비교하는 방식을 통해 필터링 여부를 결정 → 다른 task에서는 활용하기 힘듦
◦
쿼리의 난이도를 측정할 때 fail rate(샘플링 시 틀린 응답의 비율)을 사용 → Sub-optimal한 방식
◦
샘플링 횟수가 증가함 → Cost 증가
참고자료
•
Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning, Wang et al., ACL 2024
•
Direct Preference Optimization: Your Language Model is Secretly a Reward Model, Rafailov et al., NeurIPS 2023
•
DART-Math: Difficulty-Aware Rejection Tuning for Mathematical Problem-Solving, Tong et al., NeurIPS 2024
•
MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models, Yu et al., ICLR 2024