[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
@@ -0,0 +1,143 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4

# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Please note that server mode (agent loop) does not return rollout_log_probs for now,
# so server mode is currently not supported for TIS.

# To turn on TIS, you need to set the following parameters:
# 1. rollout.calculate_log_probs=True
# 2. rollout.imp_ratio_cap > 0 (the value can be tuned)

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    -- python3 -m recipe.dapo.main_dapo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=True \
    trainer.test_freq=5 \
    trainer.save_freq=5 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    rollout.calculate_log_probs=True \
    +rollout.imp_ratio_cap=2.0 # remember to turn on calculate_log_probs=True first, and set imp_ratio_cap > 0. The value can be tuned.
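
For readers new to TIS: the last two overrides above are the only additions relative to the plain DAPO recipe. Below is a minimal, illustrative PyTorch sketch of the correction they enable. The tensor names mirror the fields used in this PR (old_log_probs recomputed by the FSDP training backend, rollout_log_probs recorded by the vLLM rollout), but the helper itself is an assumption for illustration, not the exact verl implementation.

import torch

def tis_weight(old_log_prob: torch.Tensor,
               rollout_log_prob: torch.Tensor,
               imp_ratio_cap: float) -> torch.Tensor:
    # old_log_prob:     per-token log-prob recomputed by the training backend (FSDP)
    # rollout_log_prob: per-token log-prob returned by the vLLM rollout
    #                   (only available when rollout.calculate_log_probs=True)
    ratio = torch.exp(old_log_prob - rollout_log_prob)
    # Truncate the ratio so a large rollout/training mismatch cannot blow up the gradient.
    return torch.clamp(ratio, max=imp_ratio_cap).detach()

Each per-token policy-gradient loss term is then multiplied by this weight before masking and aggregation, which is why the recipe only needs the rollout log-probs and a cap value to be configured.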
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto): | |
| ] | ||
| if self.config.use_kl_loss: | ||
| select_keys.append("ref_log_prob") | ||
| if self.config.imp_ratio_cap > 0 and "rollout_log_probs" in data.batch.keys(): | ||
| select_keys.append("rollout_log_probs") | ||
|
Reviewer: Server mode (agent loop) doesn't return rollout_log_probs for now.

Author: Hi, I have added a check here before adding rollout_log_probs.
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
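
The exchange above points at a practical constraint: with imp_ratio_cap > 0 the trainer needs rollout_log_probs in the batch, and server mode (agent loop) does not provide it yet. A hedged sketch of a more defensive version of that check (hypothetical helper and logging, not the code in this diff):

import logging

def select_rollout_log_probs(select_keys: list, imp_ratio_cap: float, batch_keys) -> list:
    # Only request rollout log-probs when TIS is enabled and the rollout produced them.
    if imp_ratio_cap > 0:
        if "rollout_log_probs" in batch_keys:
            select_keys.append("rollout_log_probs")
        else:
            # e.g. server mode (agent loop): fall back to plain PPO instead of a KeyError.
            logging.warning("imp_ratio_cap > 0 but rollout_log_probs is missing; TIS will be skipped.")
    return select_keys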
|
|
@@ -405,6 +407,7 @@ def update_policy(self, data: DataProto): | |
| model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} | ||
| response_mask = model_inputs["response_mask"] | ||
| old_log_prob = model_inputs["old_log_probs"] | ||
| rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.imp_ratio_cap > 0 else None | ||
| advantages = model_inputs["advantages"] | ||
|
|
||
| entropy_coeff = self.config.entropy_coeff | ||
|
|
@@ -435,6 +438,8 @@ def update_policy(self, data: DataProto): | |
| response_mask=response_mask, | ||
| loss_agg_mode=loss_agg_mode, | ||
| config=self.config, | ||
| rollout_log_probs=rollout_log_probs, | ||
| imp_ratio_cap=self.config.imp_ratio_cap, | ||
|
||
| ) | ||
|
|
||
| if entropy_coeff != 0: | ||
|
|
||
Reviewer: This function is deprecated; you should change verl.trainer.ppo.core_algos.compute_policy_loss_vanilla instead.

Author: I see, I have included the change in verl.trainer.ppo.core_algos.compute_policy_loss_vanilla as well. Let me delete the change in compute_policy_loss.

Author: I have fixed it now :)
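
Since the discussion settles on moving the change into verl.trainer.ppo.core_algos.compute_policy_loss_vanilla, here is a rough sketch of how the two new keyword arguments could thread into a vanilla clipped-PPO loss. The signature, default values, and token-mean aggregation are assumptions for illustration, not the exact function in this PR:

import torch

def compute_policy_loss_with_tis(old_log_prob, log_prob, advantages, response_mask,
                                 rollout_log_probs=None, imp_ratio_cap=0.0,
                                 clip_ratio_low=0.2, clip_ratio_high=0.28):
    # Standard clipped surrogate objective.
    ratio = torch.exp(log_prob - old_log_prob)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    pg_losses = torch.maximum(pg_losses1, pg_losses2)

    # Truncated importance sampling: reweight each token by the (capped) mismatch
    # between the training backend and the rollout engine, when available.
    if imp_ratio_cap > 0 and rollout_log_probs is not None:
        tis_weight = torch.clamp(torch.exp(old_log_prob - rollout_log_probs), max=imp_ratio_cap)
        pg_losses = pg_losses * tis_weight.detach()

    # Token-mean aggregation over the response tokens.
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)

Passing rollout_log_probs=None (as update_policy does when imp_ratio_cap <= 0, or when the rollout did not report log-probs) leaves the loss identical to the existing behaviour, which keeps the change backward compatible.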