This software project accompanies the research paper, DiffuCoder: Understanding and Improving Masked Diffusion Models for Code Generation.
MLX support on Apple Silicon is in progress. We will make necessary updates to the repository once it is available.
- June 4, 2025. MLX support in progress. To preview or contribute, please check out this PR started by @Goekdeniz-Guelmez.
- June 4, 2025. Update inference usage/examples/demo.
- June 2, 2025. Models are available on Hugging Face.
- June 1, 2025. Code is available.
Scaling up masked diffusion models (MDMs), diffusion LLMs (dLLMs) such as LLaDA and Dream have achieved performance on par with similarly sized autoregressive (AR) LLMs across many benchmarks. Recent commercial-scale dLLMs like Mercury and Gemini Diffusion further demonstrate that diffusion-based code generators can rival top AR code models on programming tasks while offering faster text generation.
However, the generation pattern and post-training strategies of dLLMs remain under-explored. In this work, we investigate the following questions:
- How does the generation pattern of dLLMs differ from AR models?
- How does dLLM modeling differ across data modalities, such as code vs. math?
- How diverse can dLLMs be, and how should post-training be designed?
We train DiffuCoder using the adaptation approach of DiffuLLaMA and introduce a new metric, the autoregressiveness (AR-ness) score, to quantify how causal the generation order of a dLLM is (a toy sketch follows the findings below). The key findings are listed below.
- dLLMs still exhibit a left-to-right bias due to the nature of text, but, unlike AR models, they can also break this strict order.
- After pre-training, we show that code tasks induce less global AR-ness than math.
- In dLLMs, changing the sampling temperature not only affects the sampled tokens (as in AR models) but also alters the generation order itself.
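As a toy illustration of the AR-ness idea (a simplification for intuition only; the paper defines the local and global AR-ness metrics precisely), the sketch below scores a recorded unmasking order by how often a decoding step extends the sequence strictly left to right:

```python
def local_ar_ness(unmask_order: list[int]) -> float:
    """Toy AR-ness: the fraction of decoding steps that fill the position
    immediately to the right of the rightmost already-decoded position.

    unmask_order[k] is the token position decoded at step k. A strictly
    left-to-right decoder scores 1.0; scattered orders score lower.
    """
    hits, frontier = 0, -1
    for pos in unmask_order:
        if pos == frontier + 1:
            hits += 1
        frontier = max(frontier, pos)
    return hits / len(unmask_order)

print(local_ar_ness([0, 1, 2, 3]))  # 1.0: fully left-to-right
print(local_ar_ness([0, 2, 1, 3]))  # 0.5: half the steps extend the frontier causally
```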
For more interesting findings, please refer to our original paper!
We propose Coupled-GRPO, a post-training method to improve DiffuCoder's performance.
In diffusion LLMs, the per-timestep loss is computed only over the tokens that are masked at that timestep, so sampling a few timesteps gives a noisy, incomplete estimate of each token's probability. Coupled-GRPO improves this estimate with a coupled sampling scheme:
- For each training example, we select $\lambda$ pairs of timesteps $(t, \hat{t})$ such that $t + \hat{t} = T$.
- We apply two complementary token masks: each mask hides part of the tokens, and together they cover the entire set of target tokens.
- As a result, every token is unmasked in exactly one of the two forward passes.
This ensures that:
- Every token's log-probability is computed at least once, providing a non-zero learning signal for all tokens.
- The probability estimates are more accurate, since each token is evaluated in a realistic partially-masked context (rather than always being fully masked).
- The scheme effectively uses $2\lambda$ times more sampling passes than the baseline (we choose $\lambda=1$), improving estimation with modest computational overhead.
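To make the scheme concrete, here is a minimal sketch of how the coupled masks could be constructed, assuming a linear masking schedule (roughly $t/T$ of the tokens are masked at timestep $t$); the function and its interface are illustrative, not the repository's actual API:

```python
import torch

def coupled_masks(num_tokens: int, T: int):
    """Sample one coupled timestep pair (t, T - t) plus complementary masks.

    True means "masked". The two masks partition the target tokens, so every
    token is masked (and thus scored) in exactly one of the two forward passes,
    and visible as context in the other.
    """
    t = int(torch.randint(1, T, (1,)))               # timestep t in [1, T - 1]
    num_masked = max(1, round(num_tokens * t / T))   # linear schedule (assumed)
    perm = torch.randperm(num_tokens)
    mask_t = torch.zeros(num_tokens, dtype=torch.bool)
    mask_t[perm[:num_masked]] = True                 # tokens hidden at timestep t
    mask_t_hat = ~mask_t                             # complement, timestep T - t
    return t, T - t, mask_t, mask_t_hat

t, t_hat, m1, m2 = coupled_masks(num_tokens=8, T=100)
assert (m1 ^ m2).all()  # each token is masked in exactly one of the two passes
```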
In this repository, we release our implementation of Coupled-GRPO, built upon open-r1.
```
├── run.sh                            # start training
├── setup.py                          # modified open-r1/setup.py
├── src/open_r1/                      # our code based on open-r1
│   ├── configs.py                    # with diffusion related params
│   ├── coupled_grpo.py               # inherits trl GRPOTrainer
│   ├── grpo.py                       # main training script
│   ├── rewards.py                    # rewrite code reward and code_format reward
│   ├── utils/code_providers.py       # rewrite pass rate extraction for E2B
├── recipes/process_data.py           # prepare grpo training data
├── recipes/config_coupled_code.yaml  # training config
├── tests/test_code_reward.py         # test sandbox execution for code
```
Clone the source code of Open-R1 with `git clone https://github.com/huggingface/open-r1`. Merge and replace files between ours and Open-R1's (including `setup.py`).
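If helpful, a hypothetical Python helper for the overlay step (file list taken from the repository tree above; this is an illustration, not a script shipped with the repo):

```python
import shutil
from pathlib import Path

OURS, THEIRS = Path("."), Path("open-r1")  # this repo and the Open-R1 checkout

# Files from the tree above that replace their Open-R1 counterparts.
for rel in [
    "setup.py",
    "src/open_r1/configs.py",
    "src/open_r1/coupled_grpo.py",
    "src/open_r1/grpo.py",
    "src/open_r1/rewards.py",
    "src/open_r1/utils/code_providers.py",
]:
    dst = THEIRS / rel
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(OURS / rel, dst)  # overwrite Open-R1's copy with ours
```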
Set up the environment and dependencies following Open-R1:
```bash
env=openr1
conda create -n $env python=3.11 -y -c anaconda
conda activate $env
pip install vllm==0.8.4
pip install setuptools
pip install flash-attn==2.8.0.post1 --no-build-isolation
pip install -e ".[code]"
```
Prepare a code sandbox at E2B. Export your E2B token to the `E2B_API_KEY` environment variable. Log in to wandb and export your `WANDB_ENTITY`.
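For example, a quick (illustrative) sanity check that both variables are visible before launching training:

```python
import os

# Fail fast if either required variable is missing from the environment.
for var in ("E2B_API_KEY", "WANDB_ENTITY"):
    if not os.environ.get(var):
        raise RuntimeError(f"{var} is not set; export it before running run.sh")
```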
We prepare a hard split of GRPO training data based on AceCode-89k.
```bash
cd recipes
python process_data.py --dataset_path "TIGER-Lab/AceCode-89K" --output_path "./acecode_hard.jsonl" --difficulty "hard"
cd ..
bash run.sh
# in `run.sh`, we start the e2b server locally, but you can also run it on CPU clusters.
```
The DiffuCoder models (Base, Instruct, and cpGRPO) are now available on Hugging Face.
Change `TOKEN_PER_STEP` to trade off between performance and speed.
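For example, with `max_new_tokens=256`, the scripts below run `steps = 256 // TOKEN_PER_STEP` diffusion steps, so a larger `TOKEN_PER_STEP` decodes more tokens per step and finishes in proportionally fewer steps, usually at some cost in quality:

```python
max_new_tokens = 256
for token_per_step in (1, 2, 4):
    steps = max_new_tokens // token_per_step
    print(f"TOKEN_PER_STEP={token_per_step} -> {steps} diffusion steps")
```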
Usage for Base model:
```python
import torch
from transformers import AutoModel, AutoTokenizer
model_path = "apple/DiffuCoder-7B-Base"
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to("cuda").eval()
prompt = """
from typing import List
def has_close_elements(numbers: List[float], threshold: float) -> bool:
\"\"\"
Check if in given list of numbers, are any two numbers closer to each other than given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
\"\"\"
"""
TOKEN_PER_STEP = 1 # diffusion timesteps * TOKEN_PER_STEP = total new tokens
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to(device="cuda")
attention_mask = inputs.attention_mask.to(device="cuda")
output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=256,
    output_history=True,
    return_dict_in_generate=True,
    steps=256 // TOKEN_PER_STEP,  # one diffusion step per TOKEN_PER_STEP new tokens
    temperature=0.2,
    top_p=0.95,
    alg="entropy",  # unmask tokens in order of prediction entropy (most confident first)
    alg_temp=0.,  # temperature for the unmasking order; 0 keeps the order deterministic
)
generations = [
    tokenizer.decode(g[len(p):].tolist())
    for p, g in zip(input_ids, output.sequences)
]
print(generations[0].split(tokenizer.eos_token)[0])
```
Output:
```python
    # Sort the list to make it easier to find close elements
    numbers.sort()
    # Iterate through the list, checking each adjacent pair
    for i in range(len(numbers) - 1):
        # If the difference between the current and next element is less than the threshold, return True
        if numbers[i + 1] - numbers[i] < threshold:
            return True
    # If no such pair is found, return False
    return False
```
Given an example input from the HumanEval test, the output of DiffuCoder-Base is a direct completion of the code snippet.
Usage for Instruct model:
```python
import torch
from transformers import AutoModel, AutoTokenizer
model_path = "apple/DiffuCoder-7B-cpGRPO"
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to("cuda").eval()
query = "Write a function to find the shared elements from the given two lists."
prompt = f"""<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
{query.strip()}
<|im_end|>
<|im_start|>assistant
""" ## following the template of qwen; you can also use apply_chat_template function
TOKEN_PER_STEP = 1 # diffusion timesteps * TOKEN_PER_STEP = total new tokens
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to(device="cuda")
attention_mask = inputs.attention_mask.to(device="cuda")
output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=256,
    output_history=True,
    return_dict_in_generate=True,
    steps=256 // TOKEN_PER_STEP,
    temperature=0.4,
    top_p=0.95,
    alg="entropy",
    alg_temp=0.,
)
generations = [
    tokenizer.decode(g[len(p):].tolist())
    for p, g in zip(input_ids, output.sequences)
]
print(generations[0].split('<|dlm_pad|>')[0])
```
Output:
Here is the code to solve this problem:
```python
def shared_elements(list1, list2):
return [value for value in list1 if value in list2]
```<|im_end|>
Given an example input from the MBPP test, the output of DiffuCoder-cpGRPO is a chat-based response.
🚀 Start the demo and enter any prompt you want!
```bash
python inference_demo.py
```
The diffusion inference algorithm is based on Dream-7B. The code evaluation is based on Qwen2.5-Coder.
We sincerely appreciate the following works that DiffuCoder builds on:
- Our pre-training/mid-training/instruction-tuning data are from OpenCoder.
- Our instruction tuning code is based on LLaMA-Factory.
- Our coupled-GRPO is built upon Open-R1 and d1.
- Our evaluation is built upon Dream and Qwen2.5-Coder.
```bibtex
@article{gong2025diffucoder,
  title={DiffuCoder: Understanding and Improving Masked Diffusion Models for Code Generation},
  author={Shansan Gong and Ruixiang Zhang and Huangjie Zheng and Jiatao Gu and Navdeep Jaitly and Lingpeng Kong and Yizhe Zhang},
  year={2025},
  eprint={2506.20639},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2506.20639},
}
```