[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
@@ -0,0 +1,143 @@ (new file)

```bash
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameters
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4

# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Please note that server mode (agent loop) does not return rollout_log_probs for now,
# so server mode is currently not supported for TIS.

# To turn on TIS, you need to set the following parameters:
# 1. rollout.calculate_log_probs=True
# 2. rollout.imp_ratio_cap > 0 (the value can be tuned)

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    -- python3 -m recipe.dapo.main_dapo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=True \
    trainer.test_freq=5 \
    trainer.save_freq=5 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    rollout.calculate_log_probs=True \
    +rollout.imp_ratio_cap=2.0 # remember to turn on calculate_log_probs=True first, and set imp_ratio_cap > 0. The value can be tuned.
```
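For intuition about what these two switches do together: calculate_log_probs=True makes the vLLM rollout also return the log-probabilities it actually sampled with, and imp_ratio_cap bounds the per-token importance ratio between the training-time policy (the old_log_probs recomputed on the FSDP side) and the rollout policy. A minimal PyTorch sketch of that truncated weight, following the TIS write-up linked above (the function and variable names here are illustrative, not part of this PR):

```python
import torch

def truncated_is_weights(old_log_prob: torch.Tensor,
                         rollout_log_prob: torch.Tensor,
                         imp_ratio_cap: float) -> torch.Tensor:
    """Per-token truncated importance-sampling weights (illustrative sketch).

    old_log_prob:     log-probs recomputed by the training backend (FSDP)
    rollout_log_prob: log-probs reported by the rollout engine (vLLM)
    imp_ratio_cap:    upper bound C on the ratio pi_train / pi_rollout
    """
    ratio = torch.exp(old_log_prob - rollout_log_prob)     # pi_train / pi_rollout per token
    return torch.clamp(ratio, max=imp_ratio_cap).detach()  # truncate at C, stop gradients

# Toy example: an over-sampled token is down-weighted, while a heavily
# under-sampled token is capped at C = 2.0 instead of exploding.
old_lp = torch.tensor([-1.0, -2.0, -0.5])
rollout_lp = torch.tensor([-0.5, -2.0, -2.0])
print(truncated_is_weights(old_lp, rollout_lp, imp_ratio_cap=2.0))
# tensor([0.6065, 1.0000, 2.0000])
```

These weights then multiply the per-token policy-gradient loss before aggregation (token-mean in this recipe), which is what the rollout_log_probs plumbing in the actor diff below enables.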
verl/workers/actor/dp_actor.py

```diff
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
         ]
         if self.config.use_kl_loss:
             select_keys.append("ref_log_prob")
+        if self.config.tis_imp_ratio_cap > 0 and "rollout_log_probs" in data.batch.keys():
+            select_keys.append("rollout_log_probs")

         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
```

Reviewer: Server mode (agent loop) does not return rollout_log_probs for now.

Author: Hi, I have added a check here before adding rollout_log_probs to select_keys.
@@ -405,6 +407,7 @@ def update_policy(self, data: DataProto): | |
| model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} | ||
| response_mask = model_inputs["response_mask"] | ||
| old_log_prob = model_inputs["old_log_probs"] | ||
| rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A very minor problem: what if the user specify |
||
| advantages = model_inputs["advantages"] | ||
|
|
||
| entropy_coeff = self.config.entropy_coeff | ||
|
|
@@ -435,6 +438,7 @@ def update_policy(self, data: DataProto): | |
| response_mask=response_mask, | ||
| loss_agg_mode=loss_agg_mode, | ||
| config=self.config, | ||
| rollout_log_probs=rollout_log_probs, | ||
| ) | ||
|
|
||
| if entropy_coeff != 0: | ||
|
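Forwarding rollout_log_probs into the loss function is what allows the correction to be applied where the clipped surrogate is computed. Below is a hedged sketch of how a PPO-style token loss could use it; this mirrors the idea only and is not verl's actual compute_policy_loss signature or implementation:

```python
import torch

def ppo_loss_with_tis(log_prob, old_log_prob, advantages, response_mask,
                      rollout_log_probs=None, tis_imp_ratio_cap=0.0,
                      clip_low=0.2, clip_high=0.28):
    """Clipped PPO surrogate with an optional truncated-IS correction (sketch only)."""
    # Standard PPO ratio between the current policy and the old training-time policy.
    ratio = torch.exp(log_prob - old_log_prob)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
    pg_losses = torch.maximum(pg_losses1, pg_losses2)

    # Rollout-training mismatch correction: weight each token by
    # min(pi_old / pi_rollout, cap), with gradients stopped through the weight.
    if rollout_log_probs is not None and tis_imp_ratio_cap > 0:
        tis_weight = torch.exp(old_log_prob - rollout_log_probs).clamp(max=tis_imp_ratio_cap).detach()
        pg_losses = pg_losses * tis_weight

    # "token-mean" aggregation over valid response tokens only.
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)
```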
Reviewer: Do you think there should be some joint checking on the configuration? For example, if the user specifies tis_imp_ratio_cap, then calculate_log_probs must be True.

Reviewer: Agreed. Please verify this in the config.

Author: Hi, we have already addressed this issue in verl/workers/actor/dp_actor.py as follows:

Reviewer: I see, thanks! Looks good to me now.
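The dp_actor.py snippet the author refers to is not reproduced in this excerpt. Purely as a hypothetical illustration of such a joint check (the function name and config attribute paths below are invented, not verl's API), it could look like:

```python
def validate_tis_config(config) -> None:
    """Hypothetical joint check: TIS requires rollout log-probs to be collected."""
    tis_cap = getattr(config.actor_rollout_ref.actor, "tis_imp_ratio_cap", 0)
    calc_lp = getattr(config.actor_rollout_ref.rollout, "calculate_log_probs", False)
    if tis_cap > 0 and not calc_lp:
        raise ValueError(
            "tis_imp_ratio_cap > 0 requires rollout.calculate_log_probs=True, "
            "otherwise rollout_log_probs will not be present for the TIS correction."
        )
```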