2 changes: 1 addition & 1 deletion docs/algo/grpo.md
@@ -46,7 +46,7 @@ Instead of adding KL penalty in the reward, GRPO regularizes by directly adding

- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.

- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that keeps the chosen estimator for the KL value but uses k2 for unbiased gradient estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

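For intuition, a minimal standalone sketch (illustrative tensors only, not the verl API) of what the "+" variants do: the reported KL value follows the chosen estimator (k3 here), while gradients flow through the k2 term, mirroring `kl_penalty()` in `verl/trainer/ppo/core_algos.py`:

import torch

# Straight-through "k3+" combination: forward value equals k3, backward gradient equals k2's.
logprob = torch.randn(8, requires_grad=True)   # actor log probs (illustrative)
ref_logprob = torch.randn(8)                   # reference policy log probs (illustrative)

delta = ref_logprob - logprob
k3 = torch.exp(delta) - delta - 1              # low_var_kl (k3) value estimate
k2 = 0.5 * (logprob - ref_logprob).square()    # k2 term, whose gradient is unbiased

kl_est = k2 - k2.detach() + k3.detach()        # value == k3, gradient == k2's
kl_est.mean().backward()                       # logprob.grad == (logprob - ref_logprob) / 8
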
## Advanced Extensions

2 changes: 1 addition & 1 deletion docs/algo/ppo.md
@@ -59,7 +59,7 @@ Options to use KL loss for KL divergence control:

- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.

- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that keeps the chosen estimator for the KL value but uses k2 for unbiased gradient estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

Options to use KL penalty in the reward:

4 changes: 3 additions & 1 deletion docs/examples/config.rst
@@ -118,6 +118,7 @@ Actor/Rollout/Reference Policy
clip_ratio: 0.2
entropy_coeff: 0.0
use_kl_loss: False # True for GRPO
tis_imp_ratio_cap: -1 # set to a positive value to enable Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` to True)
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
@@ -185,6 +186,7 @@ Actor/Rollout/Reference Policy
sglang: {}

n: 1 # for each prompt, sample n responses (i.e. num sample times). set it to values > 1 for grpo, rloo
calculate_log_probs: False # set to True for computing log probs via rollouts
val_kwargs:
# sampling parameters for validation
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
@@ -286,7 +288,7 @@ Actor/Rollout/Reference Policy

- ``actor_rollout_ref.actor.kl_loss_coef``: The coefficient of kl loss. Default is 0.001.

- ``actor_rollout_ref.actor.kl_loss_type``: Support ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. How to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
- ``actor_rollout_ref.actor.kl_loss_type``: How to calculate the KL divergence between the actor and the reference policy. Supports ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. Appending ``+`` at the end (e.g., ``k1+`` and ``k3+``) applies a straight-through trick that keeps the chosen estimator for the KL value but uses ``k2`` for unbiased gradient estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). For specific options, refer to `kl_penalty()` in `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py>`_ . See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

- ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor

2 changes: 1 addition & 1 deletion examples/grpo_trainer/README.md
@@ -44,7 +44,7 @@ Instead of adding KL penalty in the reward, GRPO regularizes by directly adding

- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.

- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that keeps the chosen estimator for the KL value but uses k2 for unbiased gradient estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

## Advanced Extensions

2 changes: 1 addition & 1 deletion examples/ppo_trainer/README.md
@@ -57,7 +57,7 @@ Options to use KL loss for KL divergence control:

- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001.

- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html
- `actor_rollout_ref.actor.kl_loss_type`: How to calculate the KL divergence between the actor and the reference policy. Supports kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" at the end (e.g., 'k1+' and 'k3+') applies a straight-through trick that keeps the chosen estimator for the KL value but uses k2 for unbiased gradient estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

Options to use KL penalty in the reward:

144 changes: 144 additions & 0 deletions recipe/dapo/run_dapo_qwen2.5_32b_tis.sh
@@ -0,0 +1,144 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
tis_imp_ratio_cap=2.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4


# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

# Please note that server mode (agent loop) does not return rollout_log_probs for now,
# so server mode is currently not supported for TIS.

# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
# actor_rollout_ref.rollout.calculate_log_probs=True

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
algorithm.filter_groups.metric=${filter_groups_metric} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger='["console","wandb"]' \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -26,6 +26,7 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -26,6 +26,7 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
4 changes: 4 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -71,6 +71,10 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0

# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# The truncation value C of Truncated Importance Sampling (-1 to disable TIS)
tis_imp_ratio_cap: -1

Do you think there should be some joint checking on the configuration? Like, if the user specifies tis_imp_ratio_cap, calculate_log_probs must be True?

Collaborator
Agreed. Please verify in the config.

Contributor Author
Hi, we have already addressed this issue in verl/workers/actor/dp_actor.py as follows:

      if self.config.tis_imp_ratio_cap > 0:
          assert "rollout_log_probs" in data.batch.keys(), (
              "Truncated Importance Sampling (TIS) requires to configure "
              "`actor_rollout_ref.rollout.calculate_log_probs=True` "
              "and is not currently supported in Server mode (agent loop)."
          )
          select_keys.append("rollout_log_probs")

I see, thanks! Looks good to me now.


# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

3 changes: 2 additions & 1 deletion verl/trainer/config/rollout/rollout.yaml
@@ -171,7 +171,8 @@ multi_turn:
format: hermes

# support logging rollout prob for debugging purpose
calculate_log_probs: False
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
calculate_log_probs: False

# [Experimental] agent loop based rollout configs
agent:
40 changes: 39 additions & 1 deletion verl/trainer/ppo/core_algos.py
@@ -820,6 +820,7 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@@ -838,6 +839,10 @@
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config: `(verl.trainer.config.ActorConfig)`:
config for the actor.
rollout_log_probs: `(torch.Tensor)`:
log probabilities of actions under the rollout policy, shape (batch_size, response_length).
"""

assert config is not None
@@ -884,6 +889,13 @@
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
    # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
    tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
    tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
    pg_losses = pg_losses * tis_imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
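
As a standalone numeric illustration of the truncation above (hypothetical tensors; `cap` plays the role of `config.tis_imp_ratio_cap`):

import torch

# Hypothetical per-token log probs from the training engine (old policy) and the rollout engine.
old_log_prob = torch.tensor([-1.0, -2.0, -0.5])
rollout_log_probs = torch.tensor([-1.5, -2.0, -2.5])
pg_losses = torch.tensor([0.3, -0.1, 0.2])
cap = 2.0  # e.g., actor_rollout_ref.actor.tis_imp_ratio_cap=2.0

tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)  # ~[1.65, 1.00, 7.39]
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=cap)          # ~[1.65, 1.00, 2.00]
pg_losses = pg_losses * tis_imp_ratio                        # mismatched tokens are reweighted, but capped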
@@ -1270,6 +1282,32 @@ def compute_value_loss(


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob. Optionally uses a straight-through trick to bind the
    k2 gradient to the chosen KL penalty estimator for unbiased KL gradient estimation.
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob:
        ref_logprob:

    Returns:
        kl_estimate
    """
    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
    if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
        return forward_score

    """
    The expectation of the k1 and k3 estimators is the expected value of the KL, but their expected gradient is
    not the expected gradient of the KL. The k2 estimator, on the other hand, gives the right gradient estimator,
    so we use a straight-through trick here if the kl_penalty method ends with '+', e.g., k3+.
    """
    backward_score = 0.5 * (logprob - ref_logprob).square()

    return backward_score - backward_score.detach() + forward_score.detach()


def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
See more description in http://joschu.net/blog/kl-approx.html
@@ -1279,7 +1317,7 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
ref_logprob:

Returns:

kl_estimate
"""
if kl_penalty in ("kl", "k1"):
return logprob - ref_logprob
10 changes: 10 additions & 0 deletions verl/workers/actor/dp_actor.py
@@ -377,6 +377,13 @@ def update_policy(self, data: DataProto):
]
if self.config.use_kl_loss:
    select_keys.append("ref_log_prob")
if self.config.tis_imp_ratio_cap > 0:
    assert "rollout_log_probs" in data.batch.keys(), (
        "Truncated Importance Sampling (TIS) requires to configure "
        "`actor_rollout_ref.rollout.calculate_log_probs=True` "
        "and is not currently supported in Server mode (agent loop)."
    )
    select_keys.append("rollout_log_probs")
Collaborator
Server mode (agent loop) doesn't return rollout_log_probs for now.

Contributor Author
Hi, I have added a check here before adding rollout_log_probs.


has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -408,6 +415,8 @@ def update_policy(self, data: DataProto):
micro_batch_metrics = {}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None

A very minor problem: what if the user specifies tis_imp_ratio_cap but does not set calculate_log_probs = True? I would suggest directly checking whether "rollout_log_probs" is in model_inputs to avoid any risk of a KeyError (note this is already very deep in the verl codebase, so it could be hard for the user to debug).

advantages = model_inputs["advantages"]

entropy_coeff = self.config.entropy_coeff
@@ -443,6 +452,7 @@ def update_policy(self, data: DataProto):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
)

if entropy_coeff != 0:
1 change: 1 addition & 0 deletions verl/workers/config/actor.py
@@ -101,6 +101,7 @@ class ActorConfig(BaseConfig):
clip_ratio_c: float = 3.0
loss_agg_mode: str = "token-mean"
entropy_coeff: float = 0
tis_imp_ratio_cap: float = -1
use_kl_loss: bool = False
use_torch_compile: bool = True
kl_loss_coef: float = 0.001