Merged
143 changes: 143 additions & 0 deletions recipe/dapo/run_dapo_qwen2.5_32b_tis.sh
@@ -0,0 +1,143 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4


# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Please note that server mode (agent loop) does not return rollout_log_probs for now,
# so server mode is currently not supported for TIS.

# To turn on TIS, you need to set the following parameters:
# 1. rollout.calculate_log_probs=True
# 2. rollout.imp_ratio_cap > 0 (the value can be tuned)

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
algorithm.filter_groups.metric=${filter_groups_metric} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger='["console","wandb"]' \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
rollout.calculate_log_probs=True \
+rollout.imp_ratio_cap=2.0 # remember to turn on calculate_log_probs=True first, and set imp_ratio_cap > 0. The value can be tuned.
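For reference, the correction these two overrides enable (a sketch inferred from the core_algos.py changes further down; $C$ denotes `imp_ratio_cap`) multiplies each token's clipped PPO loss by a capped ratio of the trainer's old-policy probability to the rollout engine's probability:

$$
w_t = \min\!\left(\frac{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}{\pi_{\mathrm{rollout}}(a_t \mid s_t)},\; C\right),
\qquad
\tilde{\mathcal{L}}_t = w_t \,\mathcal{L}^{\mathrm{clip}}_t .
$$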
3 changes: 3 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -71,6 +71,9 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0

# Cap for truncated importance sampling (TIS); set > 0 to enable, -1 to disable importance sampling
imp_ratio_cap: -1

# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

3 changes: 2 additions & 1 deletion verl/trainer/config/rollout/rollout.yaml
@@ -159,7 +159,8 @@ multi_turn:
format: hermes

# support logging rollout prob for debugging purpose
calculate_log_probs: False
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
calculate_log_probs: False

# [Experimental] agent loop based rollout configs
agent:
31 changes: 31 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -747,6 +747,8 @@ def compute_policy_loss(
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode: str = "token-mean",
rollout_log_probs=None,
imp_ratio_cap=-1,
):
Collaborator:
this function is deprecated and you should change verl.trainer.ppo.core_algos.compute_policy_loss_vanilla instead

Contributor Author:
I see, I have included the change in verl.trainer.ppo.core_algos.compute_policy_loss_vanilla as well.
Let me delete the change in compute_policy_loss.

Contributor Author:
I have fixed it now :)

"""
Compute the clipped policy objective and related metrics for PPO.
@@ -807,6 +809,13 @@ def compute_policy_loss(
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
pg_losses = pg_losses * imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@@ -820,6 +829,8 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs=None,
imp_ratio_cap=-1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@@ -884,6 +895,13 @@ def compute_policy_loss_vanilla(
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
pg_losses = pg_losses * imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
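To make the effect of the cap concrete, here is a minimal sketch with toy tensors (values chosen purely for illustration; the names mirror the diff above):

```python
import torch

# Toy per-token log-probs: the trainer's recomputed values vs. what the rollout engine reported.
old_log_prob = torch.tensor([[-1.0, -2.0, -0.5]])
rollout_log_probs = torch.tensor([[-1.0, -3.5, -0.6]])
imp_ratio_cap = 2.0

imp_ratio = torch.exp(old_log_prob - rollout_log_probs)  # tensor([[1.0000, 4.4817, 1.1052]])
imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)    # tensor([[1.0000, 2.0000, 1.1052]])

# Each token's clipped PPO loss is scaled by the bounded ratio, so a token whose
# trainer probability far exceeds the rollout probability contributes at most
# imp_ratio_cap times its clipped loss to the gradient.
pg_losses = torch.tensor([[0.3, 0.1, 0.2]]) * imp_ratio
```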
@@ -1270,6 +1288,19 @@ def compute_value_loss(


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""
The expectation of the k1 and k3 estimators equals the expected value of the KL divergence,
but their expected gradient is not the expected gradient of the KL!
The k2 estimator, on the other hand, gives the right gradient estimator,
so we use a straight-through trick here.
"""
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
backward_score = 0.5 * (logprob - ref_logprob).square()

return backward_score - backward_score.detach() + forward_score.detach()
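A small sanity check of the straight-through construction (toy values; the k1 difference `logprob - ref_logprob` stands in for the forward score, which in the real code comes from `kl_penalty_forward`): the returned value equals the forward estimator, while the gradient is that of the k2 term.

```python
import torch

logprob = torch.tensor([-1.2, -0.4], requires_grad=True)
ref_logprob = torch.tensor([-1.0, -0.9])

forward_score = logprob - ref_logprob                    # stand-in for kl_penalty_forward(...)
backward_score = 0.5 * (logprob - ref_logprob).square()  # k2 estimator
score = backward_score - backward_score.detach() + forward_score.detach()

score.sum().backward()
print(score.detach())  # tensor([-0.2000,  0.5000]) -> value of the forward estimator
print(logprob.grad)    # tensor([-0.2000,  0.5000]) -> gradient of the k2 estimator
```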


def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
See more description in http://joschu.net/blog/kl-approx.html
5 changes: 5 additions & 0 deletions verl/workers/actor/dp_actor.py
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.imp_ratio_cap > 0 and "rollout_log_probs" in data.batch.keys():
select_keys.append("rollout_log_probs")
Collaborator:
Server mode (agent loop) does not return rollout_log_probs for now.

Contributor Author:
Hi, I have added a check here before adding rollout_log_probs.


has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -405,6 +407,7 @@ def update_policy(self, data: DataProto):
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.imp_ratio_cap > 0 else None
advantages = model_inputs["advantages"]

entropy_coeff = self.config.entropy_coeff
@@ -435,6 +438,8 @@
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
imp_ratio_cap=self.config.imp_ratio_cap,
Collaborator:
no need to pass it as it's included in the config already

)

if entropy_coeff != 0: