

DIBS - Dynamic Inference Budget Scheduling

Paper | Github

Table of Contents

  - Introduction
  - Main Takeaways
  - Environment Setup
  - Running Experiments
  - Acknowledgement

Introduction

Chain-of-thought (CoT) reasoning in large language models (LLMs) can be formalized as a latent variable problem, where the model needs to generate intermediate reasoning steps. While prior approaches such as iterative reward-ranked fine-tuning (RAFT) have relied on such formulations, they typically apply uniform inference budgets across prompts, which fails to account for variability in difficulty and convergence behavior. This work identifies the main bottleneck in CoT training as inefficient stochastic gradient estimation due to static sampling strategies. We propose GVM-RAFT, a prompt-specific Dynamic Sample Allocation Strategy designed to minimize stochastic gradient variance under a computational budget constraint. The method dynamically allocates computational resources by monitoring prompt acceptance rates and stochastic gradient norms, ensuring that the resulting gradient variance is minimized. Our theoretical analysis shows that the proposed dynamic sampling strategy leads to accelerated convergence guarantees under suitable conditions. Experiments on mathematical reasoning show that GVM-RAFT achieves a 2-4 $\times$ speedup and considerable accuracy improvements over vanilla RAFT. The proposed dynamic sampling strategy is general and can be incorporated into other reinforcement learning algorithms, such as GRPO, leading to similar improvements in convergence and test accuracy.
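
For intuition, consider a textbook budget-constrained variance-minimization problem; this is an illustrative derivation, not necessarily the exact objective optimized in the paper. If prompt $i$ contributes a stochastic-gradient term with per-sample variance $\sigma_i^2$ and receives $n_i$ of a total budget of $B$ samples, then

$$
\min_{n_1,\dots,n_m} \sum_{i=1}^{m} \frac{\sigma_i^2}{n_i}
\quad \text{s.t.} \quad \sum_{i=1}^{m} n_i = B
\quad\Longrightarrow\quad
n_i^{\star} = B \cdot \frac{\sigma_i}{\sum_{j=1}^{m} \sigma_j},
$$

which follows from setting the derivative of the Lagrangian $\sum_i \sigma_i^2 / n_i + \lambda \left( \sum_i n_i - B \right)$ with respect to $n_i$ to zero, giving $n_i = \sigma_i / \sqrt{\lambda}$. In words, prompts whose gradient contributions are noisier (e.g., harder prompts with lower acceptance rates) should receive proportionally more of the sampling budget.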

Main Takeaways

  1. We revisit the EM framework and RAFT in the context of CoT reasoning, and identify that a major limitation of current approaches lies in inefficient stochastic gradient estimation caused by uniform and static sampling strategies (i.e., best-of-n sampling), which fail to account for prompt-specific difficulty and convergence behavior.
  2. Motivated by the goal of minimizing the variance of the stochastic gradient, we propose a dynamic sampling strategy that adaptively allocates computational resources based on prompt hardness and gradient norms. Our approach provides both intuitive theoretical insight and rigorous convergence guarantees, establishing a principled framework for efficient on-policy sampling under computational budget constraints. A minimal allocation sketch in code is given after this list.
  3. We apply our method to both the RAFT++ and GRPO algorithms with real-world experiments on mathematical reasoning tasks. Our results demonstrate that the proposed approach achieves a 2-4 $\times$ speedup in convergence and also considerably improves the final test accuracy.
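
As a concrete, deliberately simplified illustration of takeaway 2, the sketch below allocates a fixed sampling budget across prompts in proportion to a per-prompt variance proxy, following the Neyman-style rule derived above. This is not the repository's implementation, and the particular proxy combining gradient norms and acceptance rates is an assumption for demonstration only.

    # Minimal sketch (not the repository's implementation): proportional
    # (Neyman-style) allocation of a fixed sampling budget across prompts.
    import numpy as np

    def allocate_budget(variance_proxy, total_budget, min_samples=1):
        """Allocate `total_budget` samples with n_i proportional to sigma_i.

        variance_proxy : non-negative per-prompt noise estimates sigma_i
        total_budget   : total number of responses to sample this round
        min_samples    : floor so that no prompt is starved entirely
        """
        proxy = np.maximum(np.asarray(variance_proxy, dtype=float), 1e-8)
        raw = total_budget * proxy / proxy.sum()          # n_i proportional to sigma_i
        alloc = np.maximum(np.floor(raw).astype(int), min_samples)
        leftover = total_budget - alloc.sum()             # distribute any remainder
        if leftover > 0:
            order = np.argsort(-(raw - np.floor(raw)))    # largest fractional parts first
            alloc[order[:leftover]] += 1
        return alloc

    # Hypothetical monitored quantities (illustrative values only).
    grad_norms = np.array([0.8, 2.5, 1.1, 0.2])           # per-prompt gradient norms
    accept_rates = np.array([0.9, 0.1, 0.5, 0.8])         # per-prompt acceptance rates
    proxy = grad_norms / np.sqrt(accept_rates)            # assumed combination, not the paper's formula
    print(allocate_budget(proxy, total_budget=64))

Harder prompts (large gradient norm, low acceptance rate) end up with more of the budget, which matches the qualitative behavior described in the introduction.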

Environment Setup

  1. Create a new environment.
    python -m venv ~/.python/gvm
    source ~/.python/gvm/bin/activate
    # You can also use conda:
    # conda create -n gvm python=3.10
    # conda activate gvm
  2. Install dependencies (an optional sanity check follows the commands below).
    pip install pip --upgrade
    pip install uv
    python -m uv pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
    python -m uv pip install flash-attn --no-build-isolation
    git clone https://github.com/RLHFlow/GVM.git
    cd GVM/
    python -m uv pip install -e .
    python -m uv pip install vllm==0.6.3
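
If the installation succeeded, a quick sanity check (assuming a CUDA-capable machine; versions are those pinned above) is to confirm that the key packages import:

    python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
    python -c "import flash_attn; print(flash_attn.__version__)"
    python -c "import vllm; print(vllm.__version__)"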

Running Experiments

  1. Prepare the training and test datasets.
    python runs/data_preprocess/math_dataset.py
    python runs/data_preprocess/numina_process.py
  2. Start the training loop.
    bash runs/scripts/run_em.sh
    bash runs/scripts/run_raft.sh
    bash runs/scripts/run_grpo.sh

Acknowledgement

We are grateful to verl for providing the awesome codebase!
