41 changes: 21 additions & 20 deletions verl/trainer/ppo/core_algos.py
@@ -719,17 +719,15 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
return loss


def compute_policy_loss(
old_log_prob,
log_prob,
advantages,
response_mask,
cliprange=None,
cliprange_low=None,
cliprange_high=None,
clip_ratio_c=3.0,
Collaborator
Although verl does not list compute_policy_loss as an official API, we do notice that many verl extensions import this function in their recipes. I'd recommend marking it with verl.utils.import_utils.deprecated, adding a new function compute_policy_loss_vanilla, and noting the change in #2744. We can then remove compute_policy_loss in the next release.
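A minimal sketch of what that migration could look like, placed next to the new function in core_algos.py (so AlgoConfig and compute_policy_loss_vanilla are already in scope). The exact arguments accepted by the deprecated decorator, and whether AlgoConfig takes all of these fields as constructor kwargs, are assumptions here, not part of this PR:

```python
from verl.utils.import_utils import deprecated  # import path taken from the comment above


@deprecated("compute_policy_loss is deprecated; use compute_policy_loss_vanilla "
            "or register a custom loss via register_policy_loss")  # decorator arguments are an assumption
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode="token-mean",
):
    """Legacy entry point kept for out-of-tree recipes; delegates to the new API."""
    # Field names mirror the attributes compute_policy_loss_vanilla reads from its config;
    # whether AlgoConfig accepts them all in its constructor is an assumption.
    config = AlgoConfig(
        clip_ratio=cliprange,
        clip_ratio_low=cliprange_low,
        clip_ratio_high=cliprange_high,
        clip_ratio_c=clip_ratio_c,
    )
    return compute_policy_loss_vanilla(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
        config=config,
    )
```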

@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
):
config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Contributor

high
The type hint for the config parameter is Optional[AlgoConfig], but this appears to be incorrect. The function accesses attributes like config.clip_ratio, which are defined in the AlgoConfig class. The config object passed from the callers in dp_actor.py and megatron_actor.py is the actor's configuration, which contains the algorithm's configuration.

Additionally, the parameter is marked as Optional with a default of None, but it's treated as mandatory within the function (see the assertion on line 750). Since callers always provide this configuration, it's better to make it a required, non-optional parameter.

Suggested change
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
):
config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: AlgoConfig,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
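One nit on the suggestion as literally written: in Python a parameter without a default (config: AlgoConfig) cannot follow one with a default (loss_agg_mode: str = "token-mean") unless it is keyword-only. Since both call sites in this PR already pass every argument by keyword, one valid variant is sketched below (signature only, dropped into core_algos.py where torch and AlgoConfig are imported; not necessarily the preferred parameter ordering):

```python
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    *,  # everything after this must be passed by keyword
    loss_agg_mode: str = "token-mean",
    config: AlgoConfig,  # required; keyword-only parameters may follow defaulted ones
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    ...
```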

"""
Compute the clipped policy objective and related metrics for PPO.

@@ -745,19 +743,22 @@ def compute_policy_loss(
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
cliprange (float, optional):
Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
Defaults to None (must be provided).
cliprange_low (float, optional):
Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
cliprange_high (float, optional):
Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
clip_ratio_c (float, optional):
Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
Defaults to 3.0.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
"""

assert config is not None
Contributor
high

This assertion is redundant if the config parameter in the function signature is made non-optional, as suggested in a separate comment. Since callers always provide the configuration, making the parameter mandatory in the signature is a cleaner approach and removes the need for this runtime check.

clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
"clip_ratio_c", 3.0
)

cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high

assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
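For readers not familiar with the loss registry used above: the @register_policy_loss("vanilla") decorator records compute_policy_loss_vanilla under the "vanilla" mode, and the workers below look it up with get_policy_loss_fn. A rough, hypothetical sketch of how such a registry typically works; the actual implementation in core_algos.py may differ:

```python
from typing import Callable

# Hypothetical stand-in for the registry defined in verl/trainer/ppo/core_algos.py.
POLICY_LOSS_REGISTRY: dict[str, Callable] = {}


def register_policy_loss(name: str) -> Callable:
    """Decorator that records a policy-loss function under `name`."""
    def decorator(fn: Callable) -> Callable:
        POLICY_LOSS_REGISTRY[name] = fn
        return fn
    return decorator


def get_policy_loss_fn(name: str) -> Callable:
    """Look up a registered loss; unknown modes fail loudly at dispatch time."""
    if name not in POLICY_LOSS_REGISTRY:
        raise ValueError(f"Unknown policy loss mode {name!r}; registered: {sorted(POLICY_LOSS_REGISTRY)}")
    return POLICY_LOSS_REGISTRY[name]
```

This is why the callers below can drop the explicit `if loss_mode == "vanilla"` branch: the vanilla path is now just one more registry entry.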
42 changes: 10 additions & 32 deletions verl/workers/actor/dp_actor.py
@@ -26,7 +26,7 @@

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
from verl.utils.profiler import GPUMemoryLogger
@@ -397,14 +397,6 @@ def update_policy(self, data: DataProto):
old_log_prob = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]

clip_ratio = self.config.clip_ratio
clip_ratio_low = (
self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
)
clip_ratio_high = (
self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
)
clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

@@ -418,29 +410,15 @@

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

if self.config.policy_loss.loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)

else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)

if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
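Relatedly, extensions that previously imported compute_policy_loss directly get a cleaner hook from the registry: a recipe can register its own loss and select it through the actor configuration (the code above reads self.config.policy_loss.loss_mode; the full YAML path is not shown in this diff). A minimal, hypothetical example, assuming a registered loss must return the same four values the callers unpack:

```python
import torch

from verl.trainer.ppo.core_algos import agg_loss, register_policy_loss


@register_policy_loss("no_clip")  # hypothetical mode name chosen for this example
def compute_policy_loss_no_clip(
    old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None
):
    """Plain importance-weighted policy gradient with no clipping (illustration only)."""
    ratio = torch.exp(log_prob - old_log_prob)
    pg_losses = -advantages * ratio
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    # Approximate KL between the old and current policy over response tokens.
    ppo_kl = ((old_log_prob - log_prob) * response_mask).sum() / response_mask.sum().clamp(min=1)
    # No clipping happens here, so both clip fractions are reported as zero.
    zero = torch.zeros((), device=pg_loss.device)
    return pg_loss, zero, ppo_kl, zero
```

The registering module has to be imported before the actor dispatches the loss (for example, from the recipe's entry point) so that the mode is present in the registry when get_policy_loss_fn runs; how a given recipe wires that up is an assumption here.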
39 changes: 10 additions & 29 deletions verl/workers/actor/megatron_actor.py
@@ -37,7 +37,7 @@
from torch import nn

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
@@ -414,39 +414,20 @@ def loss_func(output, data, meta_info):
old_log_prob = data["old_log_probs"]
advantages = data["advantages"]

clip_ratio = self.config.clip_ratio
clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio

clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

if self.config.policy_loss.loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)

else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)

stats.update(
{