From f371901d01096de536574a6335642dfef63e30da Mon Sep 17 00:00:00 2001
From: none0663
Date: Thu, 27 Mar 2025 11:46:07 +0800
Subject: [PATCH 1/5] dual_clip_ppo for lower bound clip

---
 verl/trainer/config/ppo_megatron_trainer.yaml |  2 ++
 verl/trainer/config/ppo_trainer.yaml          |  2 ++
 verl/trainer/ppo/core_algos.py                | 17 +++++++++++++----
 verl/workers/actor/dp_actor.py                |  6 +++++-
 verl/workers/actor/megatron_actor.py          | 15 +++++++++++++--
 5 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index 370816eecb2..f28e481d1a5 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -28,6 +28,8 @@ actor_rollout_ref:
     use_dynamic_bsz: False
     use_torch_compile: True # False to disable torch compile
     clip_ratio: 0.2
+    use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
+    clip_ratio_c: 3 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     entropy_coeff: 0.001
     use_kl_loss: False # True for GRPO
     kl_loss_coef: 0.001 # for grpo
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 2a9f0d333b6..122c54ec40a 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -31,6 +31,8 @@ actor_rollout_ref:
     ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
     clip_ratio: 0.2
+    use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
+    clip_ratio_c: 3i # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     entropy_coeff: 0.001
     use_kl_loss: False # True for GRPO
     use_torch_compile: True # False to disable torch compile
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 4d0b8edfff5..02885b90663 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -269,7 +269,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     return token_level_scores - kl * kl_ratio
 
 
-def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
+def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, use_dual_clip=False, clip_ratio_c=3):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
 
     Args:
@@ -283,13 +283,18 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange)
             shape: (bs, response_length)
         cliprange: (float)
             The clip range used in PPO. See https://arxiv.org/abs/1707.06347
+        use_dual_clip: (float)
+            The use_dual_clip ppo. See https://arxiv.org/pdf/1912.09729
+        clip_ratio_c: (float)
+            THe lower bound of the ratio, defalut 3. See https://arxiv.org/pdf/1912.09729
 
     Returns:
         pg_loss: `a scalar torch.Tensor`
             policy gradient loss computed via PPO
         pg_clipfrac: (float)
             a float number indicating the fraction of policy gradient loss being clipped
-
+        ppo_kl: (float)
+            the estimated KL divergence between the latest updating policy and the old sampling policy
     """
     negative_approx_kl = log_prob - old_log_prob
     ratio = torch.exp(negative_approx_kl)
@@ -297,9 +302,13 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange)
 
     pg_losses = -advantages * ratio
     pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
-
-    pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
+    if not use_dual_clip:
+        pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
+    else:
+        pg_losses3 = -advantages * clip_ratio_c
+        pg_loss = verl_F.masked_mean(torch.min(pg_losses3, torch.max(pg_losses, pg_losses2)), eos_mask)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
+
     return pg_loss, pg_clipfrac, ppo_kl
 
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index adb6b3927c2..58ec087cc59 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -287,6 +287,8 @@ def update_policy(self, data: DataProto):
 
                 clip_ratio = self.config.clip_ratio
                 entropy_coeff = self.config.entropy_coeff
+                use_dual_clip = self.config.get('use_dual_clip', False)
+                clip_ratio_c = self.config.get('clip_ratio_c', 3)
 
                 # all return: (bsz, response_length)
                 entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
@@ -295,7 +297,9 @@ def update_policy(self, data: DataProto):
                                                                               log_prob=log_prob,
                                                                               advantages=advantages,
                                                                               eos_mask=response_mask,
-                                                                              cliprange=clip_ratio)
+                                                                              cliprange=clip_ratio,
+                                                                              use_dual_clip=use_dual_clip,
+                                                                              clip_ratio_c=clip_ratio_c)
 
                 # compute entropy loss from entropy
                 entropy_loss = verl_F.masked_mean(entropy, response_mask)
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index 2c52c3e5971..56e7bea2b1d 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -277,6 +277,8 @@ def loss_func(output, data, meta_info):
 
             clip_ratio = meta_info['clip_ratio']
             entropy_coeff = meta_info['entropy_coeff']
+            use_dual_clip = meta_info['use_dual_clip']
+            clip_ratio_c = meta_info['clip_ratio_c']
 
             # compute policy loss
             logits = output.logits
@@ -288,7 +290,9 @@ def loss_func(output, data, meta_info):
                                                                           log_prob=log_prob,
                                                                           advantages=advantages,
                                                                           eos_mask=response_mask,
-                                                                          cliprange=clip_ratio)
+                                                                          cliprange=clip_ratio,
+                                                                          use_dual_clip=use_dual_clip,
+                                                                          clip_ratio_c=clip_ratio_c)
 
             entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
             policy_loss = pg_loss - entropy_loss * entropy_coeff
@@ -324,7 +328,14 @@ def forward_step(batch_iter, model):
             if forward_only:
                 meta_info = None
             else:
-                meta_info = {'clip_ratio': self.config.clip_ratio, 'entropy_coeff': self.config.entropy_coeff}
+                use_dual_clip = self.config.get('use_dual_clip', False)
+                clip_ratio_c = self.config.get('clip_ratio_c', 3)
+                meta_info = {
+                    'clip_ratio': self.config.clip_ratio,
+                    'entropy_coeff': self.config.entropy_coeff,
+                    'use_dual_clip': use_dual_clip,
+                    'clip_ratio_c': clip_ratio_c
+                }
             return output, partial(loss_func, data=batch, meta_info=meta_info)
 
         # batch should be a list of batches inside micro-batches

From 18fc31c3a67d63b990e909af16e56b0ffbc686a4 Mon Sep 17 00:00:00 2001
From: none0663
Date: Thu, 27 Mar 2025 14:11:46 +0800
Subject: [PATCH 2/5] fix yaml clip_ratio_c type error

---
 verl/trainer/config/ppo_megatron_trainer.yaml | 2 +-
 verl/trainer/config/ppo_trainer.yaml          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index f28e481d1a5..0b82d37fbdd 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -29,7 +29,7 @@ actor_rollout_ref:
     use_torch_compile: True # False to disable torch compile
     clip_ratio: 0.2
     use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
-    clip_ratio_c: 3 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
+    clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     entropy_coeff: 0.001
     use_kl_loss: False # True for GRPO
     kl_loss_coef: 0.001 # for grpo
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 122c54ec40a..8cbfe4b591a 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -32,7 +32,7 @@ actor_rollout_ref:
     grad_clip: 1.0
     clip_ratio: 0.2
     use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
-    clip_ratio_c: 3i # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
+    clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     entropy_coeff: 0.001
     use_kl_loss: False # True for GRPO
     use_torch_compile: True # False to disable torch compile

From c5bc81540b393d2cce45390b2d8747c734f648cd Mon Sep 17 00:00:00 2001
From: none0663
Date: Fri, 28 Mar 2025 15:03:39 +0800
Subject: [PATCH 3/5] remove the use option and add pg_clipfrac_lower metric

---
 verl/trainer/config/ppo_megatron_trainer.yaml |  1 -
 verl/trainer/config/ppo_trainer.yaml          |  1 -
 verl/trainer/ppo/core_algos.py                | 18 ++++++++--------
 verl/workers/actor/dp_actor.py                | 18 ++++++++--------
 verl/workers/actor/megatron_actor.py          | 21 ++++++++-----------
 5 files changed, 27 insertions(+), 32 deletions(-)

diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index 0b82d37fbdd..87ff11a3718 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -28,7 +28,6 @@ actor_rollout_ref:
     use_dynamic_bsz: False
     use_torch_compile: True # False to disable torch compile
     clip_ratio: 0.2
-    use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
     entropy_coeff: 0.001
     use_kl_loss: False # True for GRPO
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 8cbfe4b591a..f2ee2f902f6 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -31,7 +31,6 @@ actor_rollout_ref:
     ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
     clip_ratio: 0.2
-    use_dual_clip: False # add Dual-clip PPO from https://arxiv.org/pdf/1912.09729
    clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
    entropy_coeff: 0.001
    use_kl_loss: False # True for GRPO
    use_torch_compile: True # False to disable torch compile
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 02885b90663..47fa97bcf50 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -269,7 +269,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     return token_level_scores - kl * kl_ratio
 
 
-def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, use_dual_clip=False, clip_ratio_c=3):
+def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, clip_ratio_c=3.0):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
 
     Args:
@@ -283,8 +283,6 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
             shape: (bs, response_length)
         cliprange: (float)
             The clip range used in PPO. See https://arxiv.org/abs/1707.06347
-        use_dual_clip: (float)
-            The use_dual_clip ppo. See https://arxiv.org/pdf/1912.09729
         clip_ratio_c: (float)
             THe lower bound of the ratio, defalut 3. See https://arxiv.org/pdf/1912.09729
 
@@ -302,14 +300,16 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
 
     pg_losses = -advantages * ratio
     pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
-    if not use_dual_clip:
-        pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
-    else:
-        pg_losses3 = -advantages * clip_ratio_c
-        pg_loss = verl_F.masked_mean(torch.min(pg_losses3, torch.max(pg_losses, pg_losses2)), eos_mask)
+
+    pg_losses3 = -advantages * clip_ratio_c
+    max_pg_losses = torch.max(pg_losses, pg_losses2)
+
+    pg_loss = verl_F.masked_mean(torch.min(pg_losses3, max_pg_losses), eos_mask)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
-    return pg_loss, pg_clipfrac, ppo_kl
+    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(max_pg_losses, pg_losses3).float(), eos_mask)
+
+    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
 
 
 def compute_entropy_loss(logits, eos_mask):
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 58ec087cc59..3037e9a4c14 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -287,19 +287,18 @@ def update_policy(self, data: DataProto):
 
                 clip_ratio = self.config.clip_ratio
                 entropy_coeff = self.config.entropy_coeff
-                use_dual_clip = self.config.get('use_dual_clip', False)
-                clip_ratio_c = self.config.get('clip_ratio_c', 3)
+                clip_ratio_c = self.config.get('clip_ratio_c', 3.0)
 
                 # all return: (bsz, response_length)
                 entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
 
-                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
-                                                                              log_prob=log_prob,
-                                                                              advantages=advantages,
-                                                                              eos_mask=response_mask,
-                                                                              cliprange=clip_ratio,
-                                                                              use_dual_clip=use_dual_clip,
-                                                                              clip_ratio_c=clip_ratio_c)
+                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(
+                    old_log_prob=old_log_prob,
+                    log_prob=log_prob,
+                    advantages=advantages,
+                    eos_mask=response_mask,
+                    cliprange=clip_ratio,
+                    clip_ratio_c=clip_ratio_c)
 
                 # compute entropy loss from entropy
                 entropy_loss = verl_F.masked_mean(entropy, response_mask)
@@ -330,6 +329,7 @@ def update_policy(self, data: DataProto):
                     'actor/pg_loss': pg_loss.detach().item(),
                     'actor/pg_clipfrac': pg_clipfrac.detach().item(),
                     'actor/ppo_kl': ppo_kl.detach().item(),
+                    'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item(),
                 }
                 append_to_dict(metrics, data)
 
diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index 56e7bea2b1d..7c9f3701ad9 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -277,7 +277,6 @@ def loss_func(output, data, meta_info):
 
             clip_ratio = meta_info['clip_ratio']
             entropy_coeff = meta_info['entropy_coeff']
-            use_dual_clip = meta_info['use_dual_clip']
             clip_ratio_c = meta_info['clip_ratio_c']
 
             # compute policy loss
@@ -286,13 +285,12 @@ def loss_func(output, data, meta_info):
             logits_back = logits.clone()
             log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
             logits = logits_back
-            pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
-                                                                          log_prob=log_prob,
-                                                                          advantages=advantages,
-                                                                          eos_mask=response_mask,
-                                                                          cliprange=clip_ratio,
-                                                                          use_dual_clip=use_dual_clip,
-                                                                          clip_ratio_c=clip_ratio_c)
+            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
+                                                                                             log_prob=log_prob,
+                                                                                             advantages=advantages,
+                                                                                             eos_mask=response_mask,
+                                                                                             cliprange=clip_ratio,
+                                                                                             clip_ratio_c=clip_ratio_c)
 
             entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
             policy_loss = pg_loss - entropy_loss * entropy_coeff
@@ -314,7 +312,8 @@ def loss_func(output, data, meta_info):
                 'actor/entropy_loss': entropy_loss.detach().item(),
                 'actor/pg_loss': pg_loss.detach().item(),
                 'actor/pg_clipfrac': pg_clipfrac.detach().item(),
-                'actor/ppo_kl': ppo_kl.detach().item()
+                'actor/ppo_kl': ppo_kl.detach().item(),
+                'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item()
             }
             append_to_dict(stats, metrics)
             return policy_loss, stats
@@ -328,12 +327,10 @@ def forward_step(batch_iter, model):
             if forward_only:
                 meta_info = None
             else:
-                use_dual_clip = self.config.get('use_dual_clip', False)
-                clip_ratio_c = self.config.get('clip_ratio_c', 3)
+                clip_ratio_c = self.config.get('clip_ratio_c', 3.0)
                 meta_info = {
                     'clip_ratio': self.config.clip_ratio,
                     'entropy_coeff': self.config.entropy_coeff,
-                    'use_dual_clip': use_dual_clip,
                     'clip_ratio_c': clip_ratio_c
                 }
             return output, partial(loss_func, data=batch, meta_info=meta_info)

From 0d918c80ec6e734b90d31d2b778f0c6168a0b54e Mon Sep 17 00:00:00 2001
From: none0663
Date: Fri, 28 Mar 2025 22:44:24 +0800
Subject: [PATCH 4/5] fix dual bug, only apply when advantages < 0

---
 verl/trainer/ppo/core_algos.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 47fa97bcf50..da3b436ae10 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -284,7 +284,7 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
         cliprange: (float)
             The clip range used in PPO. See https://arxiv.org/abs/1707.06347
         clip_ratio_c: (float)
-            THe lower bound of the ratio, defalut 3. See https://arxiv.org/pdf/1912.09729
+            THe lower bound of the ratio for dual-clip PPO, defalut 3. See https://arxiv.org/pdf/1912.09729
 
     Returns:
         pg_loss: `a scalar torch.Tensor`
@@ -294,6 +294,8 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
         ppo_kl: (float)
             the estimated KL divergence between the latest updating policy and the old sampling policy
     """
+    assert clip_ratio_c > 1.0 , f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."
+
     negative_approx_kl = log_prob - old_log_prob
     ratio = torch.exp(negative_approx_kl)
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
@@ -301,13 +303,16 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
     pg_losses = -advantages * ratio
     pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
 
-    pg_losses3 = -advantages * clip_ratio_c
-    max_pg_losses = torch.max(pg_losses, pg_losses2)
-
-    pg_loss = verl_F.masked_mean(torch.min(pg_losses3, max_pg_losses), eos_mask)
+    clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
     pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
-    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(max_pg_losses, pg_losses3).float(), eos_mask)
+    pg_losses3 = -advantages * clip_ratio_c
+    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
+    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), eos_mask)
+    # We only apply the dual-clip when the advantage is negative.
+    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
+
+    pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
 
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

From 7fa176d5c6ce29fb2161967c9cbe786c6140ee1d Mon Sep 17 00:00:00 2001
From: none0663
Date: Sat, 29 Mar 2025 13:25:43 +0800
Subject: [PATCH 5/5] fix yaml format

---
 verl/trainer/ppo/core_algos.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index da3b436ae10..08411b0e19a 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -294,8 +294,8 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange,
         ppo_kl: (float)
             the estimated KL divergence between the latest updating policy and the old sampling policy
     """
-    assert clip_ratio_c > 1.0 , f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."
-
+    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}."
+
     negative_approx_kl = log_prob - old_log_prob
     ratio = torch.exp(negative_approx_kl)
     ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
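
Reviewer note: the sketch below is a self-contained illustration of the dual-clip PPO loss that this series converges on (the state of compute_policy_loss after PATCH 4/5). It is not the verl module itself: masked_mean here is a stand-in for verl_F.masked_mean, the demo tensors at the bottom are made up, and pg_clipfrac_lower is written as the fraction of response tokens where the lower bound actually binds, which is one reasonable reading of the metric this series adds.

import torch


def masked_mean(values, mask):
    # Mean over response tokens only (mask is 1 where the token counts).
    return (values * mask).sum() / mask.sum()


def dual_clip_ppo_loss(old_log_prob, log_prob, advantages, eos_mask,
                       cliprange=0.2, clip_ratio_c=3.0):
    # clip_ratio_c is the lower-bound constant c > 1 from Dual-clip PPO
    # (https://arxiv.org/pdf/1912.09729).
    assert clip_ratio_c > 1.0
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = masked_mean(-negative_approx_kl, eos_mask)

    # Standard PPO clipping (loss form, i.e. negated objective).
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
    pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)

    # Dual clip: cap the per-token loss at clip_ratio_c * (-advantage),
    # applied only where the advantage is negative.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = masked_mean(
        (torch.gt(clip_pg_losses1, pg_losses3) & (advantages < 0)).float(), eos_mask)
    final_pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    pg_loss = masked_mean(final_pg_losses, eos_mask)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


if __name__ == "__main__":
    torch.manual_seed(0)
    bs, resp_len = 2, 4
    old_log_prob = torch.zeros(bs, resp_len)
    log_prob = 0.5 * torch.randn(bs, resp_len)
    advantages = torch.randn(bs, resp_len)
    eos_mask = torch.ones(bs, resp_len)
    print(dual_clip_ppo_loss(old_log_prob, log_prob, advantages, eos_mask))

The lower bound only matters when the advantage is negative: there the standard clipped term can grow without bound as the ratio grows, and clip_ratio_c * (-advantage) caps the per-token loss, which is why PATCH 4/5 restricts the dual clip to advantages < 0.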