
Commit ea44424

frrad and frederrx authored
[algo] refactor: don't special-case compute_policy_loss (#2701)
### What does this PR do?

Currently the vanilla policy loss mode is special-cased. This moves vanilla onto the shared interface and stops special-casing it.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Fred <[email protected]>
1 parent 0f5ab5c commit ea44424


3 files changed: +99 -61 lines changed

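The "shared interface" the description refers to is verl's policy-loss registry: each loss mode is registered under a string key with `register_policy_loss` and later resolved with `get_policy_loss_fn`. As a rough orientation for the diffs below, here is a minimal, hypothetical sketch of that pattern in plain Python; the two function names mirror `verl/trainer/ppo/core_algos.py`, but the bodies are illustrative and not verl's actual implementation.

```python
from typing import Callable

# Hypothetical sketch of a decorator-based registry. verl's real
# register_policy_loss / get_policy_loss_fn live in
# verl/trainer/ppo/core_algos.py and may differ in detail.
_POLICY_LOSS_REGISTRY: dict[str, Callable] = {}


def register_policy_loss(name: str) -> Callable:
    """Register a policy-loss function under a string key such as "vanilla"."""

    def decorator(fn: Callable) -> Callable:
        _POLICY_LOSS_REGISTRY[name] = fn
        return fn

    return decorator


def get_policy_loss_fn(loss_mode: str) -> Callable:
    """Look up a registered loss function by its loss_mode key."""
    if loss_mode not in _POLICY_LOSS_REGISTRY:
        raise ValueError(f"Unknown policy loss mode: {loss_mode}")
    return _POLICY_LOSS_REGISTRY[loss_mode]
```

Once "vanilla" is registered like any other mode, the actors can call `get_policy_loss_fn(loss_mode)` unconditionally instead of branching on `loss_mode == "vanilla"`, which is exactly what the three diffs below do.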

verl/trainer/ppo/core_algos.py

Lines changed: 79 additions & 0 deletions
@@ -30,6 +30,7 @@
 
 import verl.utils.torch_functional as verl_F
 from verl.trainer.config import AlgoConfig
+from verl.utils.import_utils import deprecated
 
 PolicyLossFn = Callable[
     [
@@ -732,6 +733,7 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
     return loss
 
 
+@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla")
 def compute_policy_loss(
     old_log_prob,
     log_prob,
@@ -807,6 +809,83 @@ def compute_policy_loss(
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
 
 
+@register_policy_loss("vanilla")
+def compute_policy_loss_vanilla(
+    old_log_prob: torch.Tensor,
+    log_prob: torch.Tensor,
+    advantages: torch.Tensor,
+    response_mask: torch.Tensor,
+    loss_agg_mode: str = "token-mean",
+    config: Optional[DictConfig | AlgoConfig] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Compute the clipped policy objective and related metrics for PPO.
+
+    Adapted from
+    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
+
+    Args:
+        old_log_prob (torch.Tensor):
+            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
+        log_prob (torch.Tensor):
+            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
+        advantages (torch.Tensor):
+            Advantage estimates for each action, shape (batch_size, response_length).
+        response_mask (torch.Tensor):
+            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
+        loss_agg_mode (str, optional):
+            Aggregation mode for `agg_loss`. Defaults to "token-mean".
+    """
+
+    assert config is not None
+    assert not isinstance(config, AlgoConfig)
+    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
+    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
+    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
+    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
+        "clip_ratio_c", 3.0
+    )
+
+    cliprange = clip_ratio
+    cliprange_low = clip_ratio_low
+    cliprange_high = clip_ratio_high
+
+    assert clip_ratio_c > 1.0, (
+        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+        + f" but get the value: {clip_ratio_c}."
+    )
+
+    negative_approx_kl = log_prob - old_log_prob
+    # Clamp negative_approx_kl for stability
+    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+    ratio = torch.exp(negative_approx_kl)
+    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+    pg_losses1 = -advantages * ratio
+    if cliprange_low is None:
+        cliprange_low = cliprange
+    if cliprange_high is None:
+        cliprange_high = cliprange
+    pg_losses2 = -advantages * torch.clamp(
+        ratio, 1 - cliprange_low, 1 + cliprange_high
+    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
+    clip_pg_losses1 = torch.maximum(
+        pg_losses1, pg_losses2
+    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
+    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
+
+    pg_losses3 = -advantages * clip_ratio_c
+    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+    pg_clipfrac_lower = verl_F.masked_mean(
+        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
+    )
+
+    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+
+    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+
+
 @register_policy_loss("gpg")
 def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
     """Adapted from

verl/workers/actor/dp_actor.py

Lines changed: 10 additions & 32 deletions
@@ -26,7 +26,7 @@
 
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
 from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
 from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
 from verl.utils.profiler import GPUMemoryLogger
@@ -407,14 +407,6 @@ def update_policy(self, data: DataProto):
                     old_log_prob = model_inputs["old_log_probs"]
                     advantages = model_inputs["advantages"]
 
-                    clip_ratio = self.config.clip_ratio
-                    clip_ratio_low = (
-                        self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
-                    )
-                    clip_ratio_high = (
-                        self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
-                    )
-                    clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
                     entropy_coeff = self.config.entropy_coeff
                     loss_agg_mode = self.config.loss_agg_mode
 
@@ -428,29 +420,15 @@ def update_policy(self, data: DataProto):
 
                     loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
 
-                    if self.config.policy_loss.loss_mode == "vanilla":
-                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
-                            old_log_prob=old_log_prob,
-                            log_prob=log_prob,
-                            advantages=advantages,
-                            response_mask=response_mask,
-                            cliprange=clip_ratio,
-                            cliprange_low=clip_ratio_low,
-                            cliprange_high=clip_ratio_high,
-                            clip_ratio_c=clip_ratio_c,
-                            loss_agg_mode=loss_agg_mode,
-                        )
-
-                    else:
-                        policy_loss_fn = get_policy_loss_fn(loss_mode)
-                        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                            old_log_prob=old_log_prob,
-                            log_prob=log_prob,
-                            advantages=advantages,
-                            response_mask=response_mask,
-                            loss_agg_mode=loss_agg_mode,
-                            config=self.config,
-                        )
+                    policy_loss_fn = get_policy_loss_fn(loss_mode)
+                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                        old_log_prob=old_log_prob,
+                        log_prob=log_prob,
+                        advantages=advantages,
+                        response_mask=response_mask,
+                        loss_agg_mode=loss_agg_mode,
+                        config=self.config,
+                    )
 
                     if entropy_coeff != 0:
                         entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
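With the branch gone, "vanilla" travels through the same `get_policy_loss_fn` path as every other mode. Assuming verl is installed, a call along the following lines sketches that path outside the actor; the tensor shapes and config values are made up, and the config only needs the clip fields that `compute_policy_loss_vanilla` reads (`clip_ratio`, `clip_ratio_low`, `clip_ratio_high`, and optionally `clip_ratio_c`).

```python
import torch
from omegaconf import OmegaConf

from verl.trainer.ppo.core_algos import get_policy_loss_fn

# Illustrative config: clip_ratio_low/high fall back to clip_ratio,
# and clip_ratio_c defaults to 3.0 when absent.
config = OmegaConf.create({"clip_ratio": 0.2, "clip_ratio_low": None, "clip_ratio_high": None})

policy_loss_fn = get_policy_loss_fn("vanilla")  # same lookup as for any other registered mode
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
    old_log_prob=torch.randn(2, 8),
    log_prob=torch.randn(2, 8),
    advantages=torch.randn(2, 8),
    response_mask=torch.ones(2, 8),
    loss_agg_mode="token-mean",
    config=config,
)
```

Inside the trainer, `config=self.config` supplies these fields from the actor configuration, as the diff above shows.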

verl/workers/actor/megatron_actor.py

Lines changed: 10 additions & 29 deletions
@@ -37,7 +37,7 @@
 from torch import nn
 
 from verl import DataProto
-from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty
+from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
 from verl.utils.device import get_device_id, get_torch_device
 from verl.utils.megatron.pipeline_parallel import make_batch_generator
 from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
@@ -414,39 +414,20 @@ def loss_func(output, data, meta_info):
            old_log_prob = data["old_log_probs"]
            advantages = data["advantages"]
 
-            clip_ratio = self.config.clip_ratio
-            clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
-            clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
-
-            clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
            entropy_coeff = self.config.entropy_coeff
            loss_agg_mode = self.config.loss_agg_mode
 
            loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
 
-            if self.config.policy_loss.loss_mode == "vanilla":
-                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
-                    old_log_prob=old_log_prob,
-                    log_prob=log_prob,
-                    advantages=advantages,
-                    response_mask=response_mask,
-                    cliprange=clip_ratio,
-                    cliprange_low=clip_ratio_low,
-                    cliprange_high=clip_ratio_high,
-                    clip_ratio_c=clip_ratio_c,
-                    loss_agg_mode=loss_agg_mode,
-                )
-
-            else:
-                policy_loss_fn = get_policy_loss_fn(loss_mode)
-                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                    old_log_prob=old_log_prob,
-                    log_prob=log_prob,
-                    advantages=advantages,
-                    response_mask=response_mask,
-                    loss_agg_mode=loss_agg_mode,
-                    config=self.config,
-                )
+            policy_loss_fn = get_policy_loss_fn(loss_mode)
+            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                old_log_prob=old_log_prob,
+                log_prob=log_prob,
+                advantages=advantages,
+                response_mask=response_mask,
+                loss_agg_mode=loss_agg_mode,
+                config=self.config,
+            )
 
            stats.update(
                {
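The Megatron actor now mirrors the FSDP actor exactly, and the old `compute_policy_loss` entry point survives only behind the `@deprecated(...)` decorator added in `core_algos.py`. The implementation of verl's `deprecated` helper in `verl.utils.import_utils` is not part of this diff; the sketch below shows only the general shape such a shim typically takes and is not verl's code.

```python
import functools
import warnings


def deprecated(replacement: str):
    """Illustrative deprecation decorator: warn callers and point them at the new symbol."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{fn.__qualname__} is deprecated; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```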
