79 changes: 79 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -30,6 +30,7 @@

import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
+from verl.utils.import_utils import deprecated

PolicyLossFn = Callable[
    [
@@ -732,6 +733,7 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
    return loss


+@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla")
def compute_policy_loss(
    old_log_prob,
    log_prob,
@@ -807,6 +809,83 @@ def compute_policy_loss
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


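Note: the @deprecated marker applied to compute_policy_loss above redirects callers to compute_policy_loss_vanilla, which this PR adds next. verl's actual helper is not shown in this diff; a minimal sketch of how such a decorator typically works (the message wording and warning category are assumptions):

    import functools
    import warnings

    def deprecated(replacement: str):
        """Mark a function as deprecated and point callers at `replacement`."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                warnings.warn(
                    f"{fn.__name__} is deprecated; use {replacement} instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                return fn(*args, **kwargs)
            return wrapper
        return decorator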
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.

Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
"""

assert config is not None
Reviewer comment (Contributor, severity: high): This assertion is redundant if the config parameter in the function signature is made non-optional, as suggested in a separate comment. Since callers always provide the configuration, making the parameter mandatory in the signature is a cleaner approach and removes the need for this runtime check.
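A minimal sketch of that suggestion (hypothetical, not part of this PR) — making config keyword-only and mandatory lets the runtime check go away:

    def compute_policy_loss_vanilla(
        old_log_prob: torch.Tensor,
        log_prob: torch.Tensor,
        advantages: torch.Tensor,
        response_mask: torch.Tensor,
        loss_agg_mode: str = "token-mean",
        *,
        config: DictConfig | AlgoConfig,  # required: no Optional, no default
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ...  # body unchanged, minus `assert config is not None`

This works because the actors below always pass the configuration by keyword (config=self.config).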

+    assert not isinstance(config, AlgoConfig)
+    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
+    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
+    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
+    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
+        "clip_ratio_c", 3.0
+    )
+
+    cliprange = clip_ratio
+    cliprange_low = clip_ratio_low
+    cliprange_high = clip_ratio_high
+
+    assert clip_ratio_c > 1.0, (
+        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+        + f" but get the value: {clip_ratio_c}."
+    )
+
+    negative_approx_kl = log_prob - old_log_prob
+    # Clamp negative_approx_kl for stability
+    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+    ratio = torch.exp(negative_approx_kl)
+    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+    pg_losses1 = -advantages * ratio
+    if cliprange_low is None:
+        cliprange_low = cliprange
+    if cliprange_high is None:
+        cliprange_high = cliprange
+    pg_losses2 = -advantages * torch.clamp(
+        ratio, 1 - cliprange_low, 1 + cliprange_high
+    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
+    clip_pg_losses1 = torch.maximum(
+        pg_losses1, pg_losses2
+    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
+    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
+
+    pg_losses3 = -advantages * clip_ratio_c
+    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+    pg_clipfrac_lower = verl_F.masked_mean(
+        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
+    )
+
+    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+
+    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


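Note: for tokens with negative advantage, dual-clip PPO additionally caps the loss at -advantages * clip_ratio_c, so a ratio that has drifted far from 1 cannot dominate the update. A self-contained numeric check of the formulas above (illustrative values only):

    import torch

    advantages = torch.tensor([-1.0])   # negative advantage
    ratio = torch.tensor([10.0])        # policy moved far from the old policy
    clip_low, clip_high, clip_c = 0.8, 1.2, 3.0  # i.e. 1 - 0.2 and 1 + 0.2

    pg1 = -advantages * ratio                             # 10.0
    pg2 = -advantages * ratio.clamp(clip_low, clip_high)  # 1.2
    clipped = torch.maximum(pg1, pg2)                     # 10.0: standard clip keeps the max
    pg3 = -advantages * clip_c                            # 3.0
    dual = torch.min(pg3, clipped)                        # 3.0: dual clip bounds the loss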
@register_policy_loss("gpg")
def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
"""Adapted from
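Both loss functions above register themselves under a string key, and the actors below now resolve every mode through that registry. The registry itself is outside this diff; a plausible dict-backed sketch of the register/get pair (names match the imports, implementation is an assumption):

    POLICY_LOSS_REGISTRY: dict = {}  # maps mode name -> loss function

    def register_policy_loss(name: str):
        """Decorator: register a policy-loss function under `name`."""
        def decorator(fn):
            POLICY_LOSS_REGISTRY[name] = fn
            return fn
        return decorator

    def get_policy_loss_fn(name: str):
        """Look up a registered policy-loss function by its mode name."""
        if name not in POLICY_LOSS_REGISTRY:
            raise ValueError(f"Unknown policy loss mode: {name}")
        return POLICY_LOSS_REGISTRY[name]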
42 changes: 10 additions & 32 deletions verl/workers/actor/dp_actor.py
@@ -26,7 +26,7 @@

import verl.utils.torch_functional as verl_F
from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
from verl.utils.profiler import GPUMemoryLogger
Expand Down Expand Up @@ -407,14 +407,6 @@ def update_policy(self, data: DataProto):
                old_log_prob = model_inputs["old_log_probs"]
                advantages = model_inputs["advantages"]

-                clip_ratio = self.config.clip_ratio
-                clip_ratio_low = (
-                    self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
-                )
-                clip_ratio_high = (
-                    self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
-                )
-                clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
                entropy_coeff = self.config.entropy_coeff
                loss_agg_mode = self.config.loss_agg_mode

@@ -428,29 +420,15 @@

                loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

if self.config.policy_loss.loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)

else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)
+                policy_loss_fn = get_policy_loss_fn(loss_mode)
+                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                    old_log_prob=old_log_prob,
+                    log_prob=log_prob,
+                    advantages=advantages,
+                    response_mask=response_mask,
+                    loss_agg_mode=loss_agg_mode,
+                    config=self.config,
+                )

                if entropy_coeff != 0:
                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
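With the clip parameters now read inside the loss function itself, the actor only forwards self.config. A hedged sketch of the fields the vanilla loss consumes, written as an OmegaConf config (field names taken from the diff above; values illustrative):

    from omegaconf import OmegaConf

    actor_config = OmegaConf.create(
        {
            "clip_ratio": 0.2,        # PPO clipping epsilon
            "clip_ratio_low": None,   # falls back to clip_ratio
            "clip_ratio_high": None,  # falls back to clip_ratio
            "clip_ratio_c": 3.0,      # dual-clip lower bound, must be > 1.0
            "entropy_coeff": 0.0,
            "loss_agg_mode": "token-mean",
            "policy_loss": {"loss_mode": "vanilla"},
        }
    )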
39 changes: 10 additions & 29 deletions verl/workers/actor/megatron_actor.py
@@ -37,7 +37,7 @@
from torch import nn

from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
@@ -414,39 +414,20 @@ def loss_func(output, data, meta_info):
            old_log_prob = data["old_log_probs"]
            advantages = data["advantages"]

-            clip_ratio = self.config.clip_ratio
-            clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
-            clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
-
-            clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
            entropy_coeff = self.config.entropy_coeff
            loss_agg_mode = self.config.loss_agg_mode

            loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

if self.config.policy_loss.loss_mode == "vanilla":
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
cliprange=clip_ratio,
cliprange_low=clip_ratio_low,
cliprange_high=clip_ratio_high,
clip_ratio_c=clip_ratio_c,
loss_agg_mode=loss_agg_mode,
)

else:
policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
)
+            policy_loss_fn = get_policy_loss_fn(loss_mode)
+            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                old_log_prob=old_log_prob,
+                log_prob=log_prob,
+                advantages=advantages,
+                response_mask=response_mask,
+                loss_agg_mode=loss_agg_mode,
+                config=self.config,
+            )

            stats.update(
                {
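The Megatron actor gets the same registry dispatch as the FSDP actor, which also makes custom losses pluggable without touching either worker. A hypothetical extension (not in this PR), assuming it lives alongside the functions in core_algos.py:

    @register_policy_loss("my_custom")
    def compute_policy_loss_custom(old_log_prob, log_prob, advantages, response_mask,
                                   loss_agg_mode="token-mean", config=None):
        # Hypothetical REINFORCE-style objective: importance ratio, no clipping.
        negative_approx_kl = log_prob - old_log_prob
        ratio = torch.exp(negative_approx_kl)
        ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
        pg_loss = agg_loss(loss_mat=-advantages * ratio, loss_mask=response_mask,
                           loss_agg_mode=loss_agg_mode)
        zero = torch.zeros_like(pg_loss)  # placeholder clip-fraction metrics
        return pg_loss, zero, ppo_kl, zero

    # selected purely through config: actor.policy_loss.loss_mode = "my_custom"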