Merged

Changes from 4 commits
3 changes: 3 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -71,6 +71,9 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0

# Whether to apply truncated importance sampling (-1 disables importance sampling)
imp_ratio_cap: 5
Collaborator: Same here

Contributor Author: fixed :)


# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

3 changes: 2 additions & 1 deletion verl/trainer/config/rollout/rollout.yaml
@@ -159,7 +159,8 @@ multi_turn:
format: hermes

# support logging rollout prob for debugging purpose
calculate_log_probs: False
# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
calculate_log_probs: True
Collaborator: Could you set the default value to False and add a running script for TIS?

Contributor Author (@yaof20, Aug 7, 2025): I fixed it now and added a running script :)


# [Experimental] agent loop based rollout configs
agent:
31 changes: 31 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -747,6 +747,8 @@ def compute_policy_loss(
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode: str = "token-mean",
rollout_log_probs=None,
imp_ratio_cap=-1,
):
Collaborator: This function is deprecated; you should change verl.trainer.ppo.core_algos.compute_policy_loss_vanilla instead.

Contributor Author: I see, I have included the change in verl.trainer.ppo.core_algos.compute_policy_loss_vanilla as well. Let me delete the change in compute_policy_loss.

Contributor Author: I have fixed it now :)

"""
Compute the clipped policy objective and related metrics for PPO.
@@ -807,6 +809,13 @@
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling
imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
pg_losses = pg_losses * imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@@ -820,6 +829,8 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs=None,
imp_ratio_cap=-1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@@ -884,6 +895,13 @@
)

pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

if imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling
imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
pg_losses = pg_losses * imp_ratio

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
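
Both policy-loss variants above apply the same correction: the per-token PPO losses are reweighted by a capped ratio between the training engine's recomputed log-probs (old_log_prob) and the log-probs reported by the rollout engine (rollout_log_probs). A minimal standalone sketch of that step, with illustrative tensor values (not part of the diff):

import torch

# Per-token log-probs recomputed by the training engine (old policy).
old_log_prob = torch.tensor([[-1.2, -0.8, -2.0]])
# Per-token log-probs reported by the rollout/inference engine.
rollout_log_probs = torch.tensor([[-1.0, -0.9, -2.5]])
# Per-token PPO losses after clipping (placeholder values).
pg_losses = torch.tensor([[0.3, -0.1, 0.5]])

imp_ratio_cap = 5.0  # a value <= 0 disables the correction

if imp_ratio_cap > 0 and rollout_log_probs is not None:
    # The ratio corrects for the mismatch between the two engines' policies;
    # clamping (truncation) bounds the variance of the correction.
    imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
    imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
    pg_losses = pg_losses * imp_ratio

With imp_ratio_cap <= 0 or missing rollout log-probs, the branch is skipped and the losses are left untouched, matching the default behavior.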
@@ -1270,6 +1288,19 @@ def compute_value_loss(


def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""
The expectation of the k1 and k3 estimators is the expected value of the KL,
but their expected gradient is not the expected gradient of the KL!
The k2 estimator, on the other hand, gives the right gradient estimator,
so we use a straight-through trick here.
"""
forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
backward_score = 0.5 * (logprob - ref_logprob).square()

return backward_score - backward_score.detach() + forward_score.detach()


def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
See more description in http://joschu.net/blog/kl-approx.html
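The straight-through construction in kl_penalty above returns a tensor whose forward value comes from the k1/k3 score while its gradient comes from the k2 score. A small illustrative sketch, using the k1 estimator (logprob - ref_logprob) as a stand-in for kl_penalty_forward; the values are hypothetical:

import torch

logprob = torch.tensor([-1.0, -2.0], requires_grad=True)
ref_logprob = torch.tensor([-1.5, -1.5])

# Stand-in for kl_penalty_forward: the k1 estimator.
forward_score = logprob - ref_logprob
# The k2 estimator, whose gradient is the one we want to backpropagate.
backward_score = 0.5 * (logprob - ref_logprob).square()

# Straight-through: value of forward_score, gradient of backward_score.
score = backward_score - backward_score.detach() + forward_score.detach()

score.sum().backward()
print(score)         # same values as forward_score: tensor([ 0.5000, -0.5000], ...)
print(logprob.grad)  # gradient of the k2 term: tensor([ 0.5000, -0.5000])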
5 changes: 5 additions & 0 deletions verl/workers/actor/dp_actor.py
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.imp_ratio_cap > 0:
select_keys.append("rollout_log_probs")
Collaborator: Server mode (agent loop) doesn't return rollout_log_probs for now.

Contributor Author: Hi, I have added a check here before adding rollout_log_probs.


has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -405,6 +407,7 @@
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.imp_ratio_cap > 0 else None
advantages = model_inputs["advantages"]

entropy_coeff = self.config.entropy_coeff
@@ -435,6 +438,8 @@
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
imp_ratio_cap=self.config.imp_ratio_cap,
Collaborator: No need to pass it, as it's already included in the config.

)

if entropy_coeff != 0: