From 8048f3630ff7b5eb3daf713d5fe67d5edb9bff3b Mon Sep 17 00:00:00 2001 From: Frederick Robinson Date: Wed, 23 Jul 2025 00:10:37 -0700 Subject: [PATCH 1/3] don't special-case `compute_policy_loss` --- verl/trainer/ppo/core_algos.py | 41 ++++++++++++++------------- verl/workers/actor/dp_actor.py | 42 +++++++--------------------- verl/workers/actor/megatron_actor.py | 39 +++++++------------------- 3 files changed, 41 insertions(+), 81 deletions(-) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 5f02675817b..df31f947a7b 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -719,17 +719,15 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str return loss -def compute_policy_loss( - old_log_prob, - log_prob, - advantages, - response_mask, - cliprange=None, - cliprange_low=None, - cliprange_high=None, - clip_ratio_c=3.0, +@register_policy_loss("vanilla") +def compute_policy_loss_vanilla( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, loss_agg_mode: str = "token-mean", -): + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the clipped policy objective and related metrics for PPO. @@ -745,19 +743,22 @@ def compute_policy_loss( Advantage estimates for each action, shape (batch_size, response_length). response_mask (torch.Tensor): Mask indicating which tokens to include in the loss, shape (batch_size, response_length). - cliprange (float, optional): - Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. - Defaults to None (must be provided). - cliprange_low (float, optional): - Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. - cliprange_high (float, optional): - Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. - clip_ratio_c (float, optional): - Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. - Defaults to 3.0. loss_agg_mode (str, optional): Aggregation mode for `agg_loss`. Defaults to "token-mean". """ + + assert config is not None + clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio + clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + "clip_ratio_c", 3.0 + ) + + cliprange = clip_ratio + cliprange_low = clip_ratio_low + cliprange_high = clip_ratio_high + assert clip_ratio_c > 1.0, ( "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}." 
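For context on the hunk above: the vanilla PPO loss is now reached through verl's policy-loss registry. `@register_policy_loss("vanilla")` makes `compute_policy_loss_vanilla` resolvable via `get_policy_loss_fn(loss_mode)`, which is why the actor changes below can drop their `if loss_mode == "vanilla"` branch and the explicit clip-ratio keyword arguments (the clip settings are read from `config` inside the loss instead). The registry helpers themselves are not part of this diff; a minimal sketch of how such a register/lookup pair typically works (the dict name and the error handling below are assumptions, only the function names come from the patch):

    from typing import Callable

    POLICY_LOSS_REGISTRY: dict[str, Callable] = {}  # assumed container name, not shown in the diff

    def register_policy_loss(name: str) -> Callable[[Callable], Callable]:
        """Decorator: store `fn` under `name` so actors can look it up by loss_mode."""
        def decorator(fn: Callable) -> Callable:
            POLICY_LOSS_REGISTRY[name] = fn
            return fn
        return decorator

    def get_policy_loss_fn(name: str) -> Callable:
        """Return the loss function registered under `name`; unknown modes fail loudly."""
        if name not in POLICY_LOSS_REGISTRY:
            raise ValueError(f"Unknown policy loss mode: {name!r} (registered: {sorted(POLICY_LOSS_REGISTRY)})")
        return POLICY_LOSS_REGISTRY[name]

With a pair like this, registering "vanilla" once is enough for both `dp_actor` and `megatron_actor` to dispatch every loss mode uniformly, as the following file diffs do.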
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index d5cea36209a..42f4c2a134d 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -26,7 +26,7 @@ import verl.utils.torch_functional as verl_F from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_name, is_cuda_available, is_npu_available from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.profiler import GPUMemoryLogger @@ -397,14 +397,6 @@ def update_policy(self, data: DataProto): old_log_prob = model_inputs["old_log_probs"] advantages = model_inputs["advantages"] - clip_ratio = self.config.clip_ratio - clip_ratio_low = ( - self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - ) - clip_ratio_high = ( - self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - ) - clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode @@ -418,29 +410,15 @@ def update_policy(self, data: DataProto): loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") - if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, - ) - - else: - policy_loss_fn = get_policy_loss_fn(loss_mode) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - loss_agg_mode=loss_agg_mode, - config=self.config, - ) + policy_loss_fn = get_policy_loss_fn(loss_mode) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + ) if entropy_coeff != 0: entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index ce52956d042..4cbccf112ae 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -37,7 +37,7 @@ from torch import nn from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits @@ -414,39 +414,20 @@ def loss_func(output, data, meta_info): old_log_prob = data["old_log_probs"] advantages = data["advantages"] - clip_ratio = self.config.clip_ratio - clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - - clip_ratio_c = self.config.get("clip_ratio_c", 
3.0) entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") - if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, - ) - - else: - policy_loss_fn = get_policy_loss_fn(loss_mode) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - loss_agg_mode=loss_agg_mode, - config=self.config, - ) + policy_loss_fn = get_policy_loss_fn(loss_mode) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + ) stats.update( { From 6025dd338bb663adfa9139c614f0bfff61271c34 Mon Sep 17 00:00:00 2001 From: Frederick Robinson Date: Thu, 24 Jul 2025 20:46:16 -0700 Subject: [PATCH 2/3] fix types --- verl/trainer/ppo/core_algos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 986e4eebbff..7a81941c99b 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -739,7 +739,7 @@ def compute_policy_loss_vanilla( advantages: torch.Tensor, response_mask: torch.Tensor, loss_agg_mode: str = "token-mean", - config: Optional[AlgoConfig] = None, + config: Optional[DictConfig | AlgoConfig] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the clipped policy objective and related metrics for PPO. @@ -761,6 +761,7 @@ def compute_policy_loss_vanilla( """ assert config is not None + assert not isinstance(config, AlgoConfig) clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio From 441e1525ed53eb58a67cade925b2d0b5d84ab0a1 Mon Sep 17 00:00:00 2001 From: Frederick Robinson Date: Thu, 24 Jul 2025 20:55:44 -0700 Subject: [PATCH 3/3] restore `compute_policy_loss` with deprecation warning --- verl/trainer/ppo/core_algos.py | 77 ++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 7a81941c99b..12b3ed791d0 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -30,6 +30,7 @@ import verl.utils.torch_functional as verl_F from verl.trainer.config import AlgoConfig +from verl.utils.import_utils import deprecated PolicyLossFn = Callable[ [ @@ -732,6 +733,82 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str return loss +@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla") +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped policy objective and related metrics for PPO. 
+ + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + clip_ratio_c (float, optional): + Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + Defaults to 3.0. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + """ + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + @register_policy_loss("vanilla") def compute_policy_loss_vanilla( old_log_prob: torch.Tensor,
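For callers migrating off the deprecated `compute_policy_loss`, the replacement path is the registered loss plus the actor config, mirroring the `dp_actor`/`megatron_actor` changes above. Below is a minimal usage sketch, not a definitive snippet from the repo: the tensor shapes and the bare `DictConfig` are illustrative, and in the real trainers `config` is the full actor config, of which the four clip keys shown are the ones the vanilla loss reads in this patch.

    import torch
    from omegaconf import OmegaConf
    from verl.trainer.ppo.core_algos import get_policy_loss_fn

    # Toy inputs with shape (batch_size, response_length); real values come from rollout data.
    batch_size, response_length = 2, 4
    old_log_prob = torch.randn(batch_size, response_length)
    log_prob = old_log_prob + 0.05 * torch.randn(batch_size, response_length)
    advantages = torch.randn(batch_size, response_length)
    response_mask = torch.ones(batch_size, response_length)

    # Clip settings now travel inside the config instead of as explicit keyword arguments.
    actor_config = OmegaConf.create(
        {"clip_ratio": 0.2, "clip_ratio_low": None, "clip_ratio_high": None, "clip_ratio_c": 3.0}
    )

    policy_loss_fn = get_policy_loss_fn("vanilla")  # resolves compute_policy_loss_vanilla
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        loss_agg_mode="token-mean",
        config=actor_config,
    )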