143 changes: 143 additions & 0 deletions recipe/dapo/run_dapo_qwen2.5_32b_tis.sh
@@ -0,0 +1,143 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4


# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Note: server mode (agent loop) does not return rollout_log_probs yet,
# so server mode is currently not supported for TIS.

# To turn on TIS, set the following parameters (see the numeric sketch after this script):
# 1. actor_rollout_ref.rollout.calculate_log_probs=True
# 2. actor_rollout_ref.actor.tis_imp_ratio_cap > 0 (the value can be tuned)

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
algorithm.filter_groups.metric=${filter_groups_metric} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger='["console","wandb"]' \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.actor.tis_imp_ratio_cap=2.0 # requires calculate_log_probs=True above; the cap value can be tuned
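For intuition on the tis_imp_ratio_cap=2.0 setting above: the TIS correction multiplies each token's policy-gradient loss by exp(old_log_prob - rollout_log_prob), clamped at the cap (see the change to verl/trainer/ppo/core_algos.py below). A minimal sketch with made-up values:

    import torch

    old_log_prob = torch.tensor([-1.00, -0.50, -2.00])      # log-probs recomputed by the actor forward pass
    rollout_log_prob = torch.tensor([-1.00, -1.60, -0.80])   # log-probs recorded by the vLLM rollout
    cap = 2.0                                                 # tis_imp_ratio_cap in this recipe

    ratio = torch.exp(old_log_prob - rollout_log_prob)       # tensor([1.0000, 3.0042, 0.3012])
    weight = torch.clamp(ratio, max=cap)                      # tensor([1.0000, 2.0000, 0.3012])
    # Tokens the rollout engine under-sampled are up-weighted, but never by more than the cap;
    # over-sampled tokens are down-weighted as usual.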
4 changes: 4 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -71,6 +71,10 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0

# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# The truncation threshold C for Truncated Importance Sampling (-1 disables TIS)
tis_imp_ratio_cap: -1

Do you think there should be some joint checking on the configuration? Like, if the user specifies tis_imp_ratio_cap, calculate_log_probs must be True?

Collaborator
Agreed. Please verify in the config

Contributor Author
Hi, we have already addressed this issue in verl/workers/actor/dp_actor.py as follows:

      if self.config.tis_imp_ratio_cap > 0:
          assert "rollout_log_probs" in data.batch.keys(), (
              "Truncated Importance Sampling (TIS) requires to configure "
              "`actor_rollout_ref.rollout.calculate_log_probs=True` "
              "and is not currently supported in Server mode (agent loop)."
          )
          select_keys.append("rollout_log_probs")

I see, thanks! Looks good to me now.
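A config-level guard along the lines suggested above could look roughly like this (an illustrative sketch only; the function name and call site are assumptions and not part of this PR, which relies on the runtime assert shown above):

    # Illustrative sketch of a joint config check (names and placement are hypothetical).
    def validate_tis_config(actor_cfg, rollout_cfg) -> None:
        """Fail fast if TIS is enabled without rollout log-probs."""
        if getattr(actor_cfg, "tis_imp_ratio_cap", -1) > 0 and not getattr(rollout_cfg, "calculate_log_probs", False):
            raise ValueError(
                "actor.tis_imp_ratio_cap > 0 requires actor_rollout_ref.rollout.calculate_log_probs=True "
                "(and TIS is not yet supported in server mode / agent loop)."
            )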


# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

3 changes: 2 additions & 1 deletion verl/trainer/config/rollout/rollout.yaml
@@ -159,7 +159,8 @@ multi_turn:
format: hermes

# support logging rollout prob for debugging purpose
calculate_log_probs: False
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
calculate_log_probs: False

# [Experimental] agent loop based rollout configs
agent:
25 changes: 25 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -820,6 +820,7 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@@ -838,6 +839,10 @@
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config (verl.trainer.config.ActorConfig):
Config for the actor; `tis_imp_ratio_cap` is read from it when TIS is enabled.
rollout_log_probs (torch.Tensor, optional):
Log probabilities of the actions under the rollout (behavior) policy, shape (batch_size, response_length).
"""

assert config is not None
@@ -884,6 +889,13 @@
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
pg_losses = pg_losses * tis_imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@@ -1270,6 +1282,19 @@ def compute_value_loss(


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""
The expectation of the k1 and k3 estimators equals the expected value of the KL divergence,
but their expected gradient is not the expected gradient of the KL divergence.
The k2 estimator, on the other hand, gives the right gradient estimator,
so we use a straight-through trick: the value comes from the forward estimator, the gradient from k2.
"""
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
backward_score = 0.5 * (logprob - ref_logprob).square()

return backward_score - backward_score.detach() + forward_score.detach()


def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
See more description in http://joschu.net/blog/kl-approx.html
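The straight-through construction in kl_penalty above can be sanity-checked in isolation: the returned tensor takes its value from the forward estimator but its gradient from the k2 term. A minimal sketch, assuming a k3-style forward estimator (illustrative, not code from this PR):

    import torch

    def k3(logprob, ref_logprob):
        # k3 estimator from http://joschu.net/blog/kl-approx.html: exp(r) - r - 1 with r = ref - logp
        log_ratio = ref_logprob - logprob
        return torch.exp(log_ratio) - log_ratio - 1

    logprob = torch.tensor([-1.2, -0.3], requires_grad=True)
    ref_logprob = torch.tensor([-1.0, -0.7])

    forward_score = k3(logprob, ref_logprob)                    # value reported as the penalty
    backward_score = 0.5 * (logprob - ref_logprob).square()     # k2: per-sample gradient estimates grad KL
    st = backward_score - backward_score.detach() + forward_score.detach()

    assert torch.allclose(st, forward_score)                    # forward pass equals the k3 value
    st.sum().backward()
    assert torch.allclose(logprob.grad, (logprob - ref_logprob).detach())  # gradient equals the k2 gradient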
4 changes: 4 additions & 0 deletions verl/workers/actor/dp_actor.py
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.tis_imp_ratio_cap > 0 and "rollout_log_probs" in data.batch.keys():
select_keys.append("rollout_log_probs")
Collaborator
Server mode (agent loop) doesn't return rollout_log_probs for now.

Contributor Author
Hi, I have added a check here before adding rollout_log_probs.


has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -405,6 +407,7 @@
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None

A very minor problem: what if the user specifies tis_imp_ratio_cap but does not set calculate_log_probs=True? I would suggest directly checking whether "rollout_log_probs" is in model_inputs to avoid any risk of a KeyError (note this is already very deep in the verl codebase, so it could be hard for the user to debug).
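A defensive variant along those lines could look roughly like this (an illustrative sketch that reuses the surrounding names; not code from this PR):

    # Illustrative: fall back to None instead of raising KeyError when rollout
    # log-probs were not collected (e.g. calculate_log_probs left at False).
    rollout_log_probs = (
        model_inputs.get("rollout_log_probs")
        if self.config.tis_imp_ratio_cap > 0
        else None
    )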

advantages = model_inputs["advantages"]

entropy_coeff = self.config.entropy_coeff
@@ -435,6 +438,7 @@
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
)

if entropy_coeff != 0: