Skip to content

Commit f149c46

Browse files
yaof20zdhNarsilgemini-code-assist[bot]LiyuanLucasLiu
authored andcommitted
[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling (volcengine#2953)
Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` > Demonstrate the high-level design if this PR is complex, and list the specific changes. > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
1 parent 3b76b7c commit f149c46

File tree

13 files changed

+208
-7
lines changed

13 files changed

+208
-7
lines changed

docs/algo/grpo.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Instead of adding KL penalty in the reward, GRPO regularizes by directly adding
4646

4747
- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.
4848

49-
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
49+
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" in the end (e.g., 'k1+' and 'k3+') would apply straight through to employ k2 for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
5050

5151
## Advanced Extensions
5252

docs/algo/ppo.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Options to use KL loss for KL divergence control:
5959

6060
- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.
6161

62-
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
62+
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" in the end (e.g., 'k1+' and 'k3+') would apply straight through to employ k2 for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
6363

6464
Options to use KL penalty in the reward:
6565

docs/examples/config.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ Actor/Rollout/Reference Policy
118118
clip_ratio: 0.2
119119
entropy_coeff: 0.0
120120
use_kl_loss: False # True for GRPO
121+
tis_imp_ratio_cap: -1 # set to positive values for Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` as True)
121122
use_torch_compile: True # False to disable torch compile
122123
kl_loss_coef: 0.001 # for grpo
123124
kl_loss_type: low_var_kl # for grpo
@@ -188,6 +189,7 @@ Actor/Rollout/Reference Policy
188189
attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla
189190
190191
n: 1 # for each prompt, sample n responses (i.e. num sample times). set it to values > 1 for grpo, rloo
192+
calculate_log_probs: False # set to True for computing log probs via rollouts
191193
val_kwargs:
192194
# sampling parameters for validation
193195
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
@@ -289,7 +291,7 @@ Actor/Rollout/Reference Policy
289291

290292
- ``actor_rollout_ref.actor.kl_loss_coef``: The coefficient of kl loss. Default is 0.001.
291293

292-
- ``actor_rollout_ref.actor.kl_loss_type``: Support ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. How to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
294+
- ``actor_rollout_ref.actor.kl_loss_type``: Support ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. Appending ``+`` in the end (e.g., ``k1+`` and ``k3+``) would use straight-through to employ ``k2`` for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
293295

294296
- ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor
295297

examples/grpo_trainer/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Instead of adding KL penalty in the reward, GRPO regularizes by directly adding
4444

4545
- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.
4646

47-
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
47+
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" in the end (e.g., 'k1+' and 'k3+') would apply straight through to employ k2 for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
4848

4949
## Advanced Extensions
5050

examples/ppo_trainer/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Options to use KL loss for KL divergence control:
5757

5858
- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.
5959

60-
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
60+
- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" in the end (e.g., 'k1+' and 'k3+') would apply straight through to employ k2 for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
6161

6262
Options to use KL penalty in the reward:
6363

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#!/usr/bin/env bash
2+
set -xeuo pipefail
3+
4+
project_name='DAPO'
5+
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
6+
7+
adv_estimator=grpo
8+
9+
use_kl_in_reward=False
10+
kl_coef=0.0
11+
use_kl_loss=False
12+
kl_loss_coef=0.0
13+
tis_imp_ratio_cap=2.0
14+
15+
clip_ratio_low=0.2
16+
clip_ratio_high=0.28
17+
18+
max_prompt_length=$((1024 * 2))
19+
max_response_length=$((1024 * 20))
20+
enable_overlong_buffer=True
21+
overlong_buffer_len=$((1024 * 4))
22+
overlong_penalty_factor=1.0
23+
24+
loss_agg_mode="token-mean"
25+
26+
enable_filter_groups=True
27+
filter_groups_metric=acc
28+
max_num_gen_batches=10
29+
train_prompt_bsz=512
30+
gen_prompt_bsz=$((train_prompt_bsz * 3))
31+
n_resp_per_prompt=16
32+
train_prompt_mini_bsz=32
33+
34+
# Ray
35+
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
36+
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
37+
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
38+
NNODES=${NNODES:-16}
39+
# Paths
40+
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
41+
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
42+
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
43+
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
44+
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
45+
46+
# Algorithm
47+
temperature=1.0
48+
top_p=1.0
49+
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
50+
val_top_p=0.7
51+
52+
# Performance Related Parameter
53+
sp_size=8
54+
use_dynamic_bsz=True
55+
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
56+
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
57+
offload=True
58+
gen_tp=4
59+
60+
61+
# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
62+
63+
# Please note that server mode(agent loop) hasn't return rollout_log_probs for now.
64+
# so currently, server mode is not supported for TIS.
65+
66+
# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
67+
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
68+
# actor_rollout_ref.rollout.calculate_log_probs=True
69+
70+
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
71+
--working-dir "${WORKING_DIR}" \
72+
-- python3 -m recipe.dapo.main_dapo \
73+
data.train_files="${TRAIN_FILE}" \
74+
data.val_files="${TEST_FILE}" \
75+
data.prompt_key=prompt \
76+
data.truncation='left' \
77+
data.max_prompt_length=${max_prompt_length} \
78+
data.max_response_length=${max_response_length} \
79+
data.gen_batch_size=${gen_prompt_bsz} \
80+
data.train_batch_size=${train_prompt_bsz} \
81+
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
82+
algorithm.adv_estimator=${adv_estimator} \
83+
algorithm.use_kl_in_reward=${use_kl_in_reward} \
84+
algorithm.kl_ctrl.kl_coef=${kl_coef} \
85+
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
86+
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
87+
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
88+
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
89+
actor_rollout_ref.actor.clip_ratio_c=10.0 \
90+
algorithm.filter_groups.enable=${enable_filter_groups} \
91+
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
92+
algorithm.filter_groups.metric=${filter_groups_metric} \
93+
actor_rollout_ref.model.use_remove_padding=True \
94+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
95+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
96+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
97+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
98+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
99+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
100+
actor_rollout_ref.model.path="${MODEL_PATH}" \
101+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
102+
actor_rollout_ref.actor.optim.lr=1e-6 \
103+
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
104+
actor_rollout_ref.actor.optim.weight_decay=0.1 \
105+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
106+
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
107+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
108+
actor_rollout_ref.actor.entropy_coeff=0 \
109+
actor_rollout_ref.actor.grad_clip=1.0 \
110+
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
111+
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
112+
actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
113+
actor_rollout_ref.rollout.calculate_log_probs=True \
114+
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
115+
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
116+
actor_rollout_ref.rollout.enable_chunked_prefill=True \
117+
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
118+
actor_rollout_ref.rollout.temperature=${temperature} \
119+
actor_rollout_ref.rollout.top_p=${top_p} \
120+
actor_rollout_ref.rollout.top_k="${top_k}" \
121+
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
122+
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
123+
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
124+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
125+
actor_rollout_ref.rollout.val_kwargs.n=1 \
126+
actor_rollout_ref.rollout.name=vllm \
127+
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
128+
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
129+
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
130+
reward_model.reward_manager=dapo \
131+
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
132+
reward_model.overlong_buffer.len=${overlong_buffer_len} \
133+
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
134+
trainer.logger='["console","wandb"]' \
135+
trainer.project_name="${project_name}" \
136+
trainer.experiment_name="${exp_name}" \
137+
trainer.n_gpus_per_node=8 \
138+
trainer.nnodes="${NNODES}" \
139+
trainer.val_before_train=True \
140+
trainer.test_freq=5 \
141+
trainer.save_freq=5 \
142+
trainer.total_epochs=1 \
143+
trainer.default_local_dir="${CKPTS_DIR}" \
144+
trainer.resume_mode=auto

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ actor_rollout_ref:
2626
clip_ratio_c: 3.0
2727
loss_agg_mode: token-mean
2828
entropy_coeff: 0
29+
tis_imp_ratio_cap: -1
2930
use_kl_loss: false
3031
use_torch_compile: true
3132
kl_loss_coef: 0.001

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ actor_rollout_ref:
2626
clip_ratio_c: 3.0
2727
loss_agg_mode: token-mean
2828
entropy_coeff: 0
29+
tis_imp_ratio_cap: -1
2930
use_kl_loss: false
3031
use_torch_compile: true
3132
kl_loss_coef: 0.001

verl/trainer/config/actor/actor.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ loss_agg_mode: token-mean
7171
# Entropy regularization coefficient in PPO loss
7272
entropy_coeff: 0
7373

74+
# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
75+
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
76+
tis_imp_ratio_cap: -1
77+
7478
# Whether to use KL loss instead of KL reward penalty. True for GRPO
7579
use_kl_loss: false
7680

verl/trainer/config/rollout/rollout.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ multi_turn:
174174
format: hermes
175175

176176
# support logging rollout prob for debugging purpose
177-
calculate_log_probs: False
177+
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
178+
calculate_log_probs: False
178179

179180
# [Experimental] agent loop based rollout configs
180181
agent:

0 commit comments

Comments
 (0)