Chain-of-thought (CoT) reasoning in large language models (LLMs) can be formalized as a latent variable problem, where the model needs to generate intermediate reasoning steps. While prior approaches such as iterative reward-ranked fine-tuning (RAFT) have relied on such formulations, they typically apply uniform inference budgets across prompts, which fails to account for variability in difficulty and convergence behavior. This work identifies the main bottleneck in CoT training as inefficient stochastic gradient estimation due to static sampling strategies. We propose GVM-RAFT, a prompt-specific Dynamic Sample Allocation Strategy designed to minimize stochastic gradient variance under a computational budget constraint. The method dynamically allocates computational resources by monitoring prompt acceptance rates and stochastic gradient norms, ensuring that the resulting gradient variance is minimized. Our theoretical analysis shows that the proposed dynamic sampling strategy leads to accelerated convergence guarantees under suitable conditions. Experiments on mathematical reasoning show that GVM-RAFT achieves a 2-4$\times$ speedup in convergence rate and also considerably improves final test accuracy.
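To give a rough sense of the objective behind the allocation rule (this is a schematic simplification; the exact estimator and constraint set used in the paper differ), suppose each prompt $i$ contributes gradient variance roughly $\sigma_i^2 / n_i$ when $n_i$ responses are sampled for it. Minimizing the total variance under a fixed rollout budget $N$ then gives a Neyman-style allocation:

$$
\min_{n_1,\dots,n_M}\ \sum_{i=1}^{M} \frac{\sigma_i^2}{n_i}
\quad \text{s.t.}\quad \sum_{i=1}^{M} n_i = N
\;\;\Longrightarrow\;\;
n_i^\star = N\,\frac{\sigma_i}{\sum_{j=1}^{M}\sigma_j},
$$

where $\sigma_i$ is a per-prompt variance proxy estimated online from the prompt's acceptance rate and stochastic gradient norm.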
Main Takeaways
- We revisit the EM framework and RAFT in the context of CoT reasoning, and identify that a major limitation of current approaches lies in inefficient stochastic gradient estimation caused by uniform and static sampling strategies (i.e., best-of-n sampling), which fail to account for prompt-specific difficulty and convergence behavior.
- Motivated by the goal of minimizing the variance of the stochastic gradient, we propose a dynamic sampling strategy that adaptively allocates computational resources based on prompt hardness and gradient norms. Our approach provides both intuitive theoretical insight and rigorous convergence guarantees, establishing a principled framework for efficient on-policy sampling under computational budget constraints (a toy sketch of such an allocation rule follows this list).
- We apply our method to both RAFT++ and GRPO algorithms with real-world experiments on mathematical reasoning tasks. Our results demonstrate that the proposed approach achieves a 2-4$\times$ speedup in convergence rate and also considerably improves the final test accuracy.
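As a concrete illustration, the snippet below is a minimal, hypothetical Python sketch (not the repository's implementation; the function name, the specific variance proxy, and the rounding scheme are assumptions made for illustration) of how per-prompt rollout budgets could be derived from monitored acceptance rates and gradient norms:

```python
import numpy as np

def allocate_samples(accept_rates, grad_norms, total_budget, min_samples=1):
    """Hypothetical sketch of prompt-specific sample allocation.

    Gives a larger share of the rollout budget to prompts whose
    gradient-variance proxy is large, i.e. hard prompts with low
    acceptance rates and large stochastic gradient norms.
    """
    eps = 1e-6
    p = np.clip(np.asarray(accept_rates, dtype=float), eps, 1.0)
    g = np.asarray(grad_norms, dtype=float)

    # Per-prompt variance proxy (illustrative choice): rare acceptances and
    # large gradients both inflate the variance of the gradient estimate.
    proxy = g * np.sqrt((1.0 - p) / p)

    # Neyman-style allocation proportional to the proxy, with a floor;
    # rounding means the sum may differ slightly from total_budget.
    weights = proxy / (proxy.sum() + eps)
    n = np.maximum(min_samples, np.floor(weights * total_budget)).astype(int)
    return n

# Example: three prompts of increasing difficulty share a budget of 64 rollouts.
print(allocate_samples([0.9, 0.3, 0.05], [1.0, 1.5, 2.0], total_budget=64))
```

Prompts with low acceptance rates and large gradient norms receive more rollouts, reflecting the intuition that they contribute the most gradient variance per sample; in the dynamic setting such an allocation would be recomputed as the monitored statistics change during training.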
- Create a new environment.
```sh
python -m venv ~/.python/gvm
source ~/.python/gvm/bin/activate

# You can also use conda
# conda create -n gvm python==3.10
# conda activate gvm
```
- Install dependencies
```sh
pip install pip --upgrade
pip install uv
python -m uv pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
python -m uv pip install flash-attn --no-build-isolation
git clone https://github.com/RLHFlow/GVM.git
cd GVM/
python -m uv pip install -e .
python -m uv pip install vllm==0.6.3
```
- Prepare the training and test datasets.
```sh
python runs/data_preprocess/math_dataset.py
python runs/data_preprocess/numina_process.py
```
- Start the training loop.
```sh
bash runs/scripts/run_em.sh
bash runs/scripts/run_raft.sh
bash runs/scripts/run_grpo.sh
```
We greatly thank verl for providing the awesome codebase!