1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -28,6 +28,7 @@ actor_rollout_ref:
use_dynamic_bsz: False
use_torch_compile: True # False to disable torch compile
clip_ratio: 0.2
clip_ratio_c: 3.0 # lower-bound clipping constant c for Dual-clip PPO (must be > 1.0), see https://arxiv.org/pdf/1912.09729
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
kl_loss_coef: 0.001 # for grpo
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -31,6 +31,7 @@ actor_rollout_ref:
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
clip_ratio: 0.2
clip_ratio_c: 3.0 # lower-bound clipping constant c for Dual-clip PPO (must be > 1.0), see https://arxiv.org/pdf/1912.09729
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
use_torch_compile: True # False to disable torch compile
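For reference, `clip_ratio_c` is the lower-bound constant c introduced by Dual-clip PPO (https://arxiv.org/pdf/1912.09729). Roughly, with ε = `clip_ratio`, c = `clip_ratio_c`, r_t the probability ratio and Â_t the advantage, the per-token objective these two options configure is:

$$
L^{\mathrm{CLIP}}_t=\min\big(r_t\hat A_t,\ \operatorname{clip}(r_t,\,1-\epsilon,\,1+\epsilon)\,\hat A_t\big),\qquad
L^{\mathrm{DUAL}}_t=\begin{cases}\max\big(L^{\mathrm{CLIP}}_t,\ c\,\hat A_t\big) & \hat A_t<0\\ L^{\mathrm{CLIP}}_t & \hat A_t\ge 0\end{cases}
$$

The code below minimizes the negative of this objective, which is why the new branch shows up as a `torch.min` over loss terms rather than a `torch.max` over objective terms.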
22 changes: 18 additions & 4 deletions verl/trainer/ppo/core_algos.py
@@ -269,7 +269,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
return token_level_scores - kl * kl_ratio


def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange, clip_ratio_c=3.0):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

Args:
@@ -283,24 +283,38 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange)
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
clip_ratio_c: (float)
The lower-bound clipping constant c of dual-clip PPO, applied when the advantage is negative; default 3.0. See https://arxiv.org/pdf/1912.09729

Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped

ppo_kl: (float)
the estimated KL divergence between the latest updating policy and the old sampling policy
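pg_clipfrac_lower: (float)
a float number indicating the fraction of tokens with negative advantage whose loss is clipped by the dual-clip lower bound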
"""
assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but got the value: {clip_ratio_c}."

negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)

pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
clip_pg_losses1 = torch.max(pg_losses, pg_losses2)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl

pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), eos_mask)
# We only apply the dual-clip when the advantage is negative.
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

pg_loss = verl_F.masked_mean(pg_losses, eos_mask)

return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


def compute_entropy_loss(logits, eos_mask):
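To make the new branch concrete, here is a small self-contained PyTorch sketch (an illustration only, not the verl API; `dual_clip_pg_loss` and the toy tensors are made up) that mirrors the math in `compute_policy_loss` and shows the lower bound capping the loss of a token with a very large ratio and a negative advantage:

```python
import torch

def dual_clip_pg_loss(old_log_prob, log_prob, advantages, mask,
                      cliprange=0.2, clip_ratio_c=3.0):
    """Standalone sketch of the dual-clip PPO policy loss (token mean over mask)."""
    ratio = torch.exp(log_prob - old_log_prob)
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    clip_pg_losses1 = torch.max(pg_losses, pg_losses2)   # standard PPO clipping
    pg_losses3 = -advantages * clip_ratio_c              # dual-clip lower-bound term
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    # the dual clip is only applied where the advantage is negative
    losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    return (losses * mask).sum() / mask.sum()

# Toy check: token 0 has ratio ~ exp(2.5) ~ 12.2 and a negative advantage,
# so its loss is capped at clip_ratio_c * |advantage| = 3.0 instead of ~12.2.
old_lp = torch.tensor([[-1.0, -1.0]])
new_lp = torch.tensor([[ 1.5, -1.2]])
adv    = torch.tensor([[-1.0,  0.5]])
mask   = torch.ones_like(adv)
print(dual_clip_pg_loss(old_lp, new_lp, adv, mask))  # ~(3.0 - 0.41) / 2 ~= 1.30
```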
14 changes: 9 additions & 5 deletions verl/workers/actor/dp_actor.py
@@ -287,15 +287,18 @@ def update_policy(self, data: DataProto):

clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff
clip_ratio_c = self.config.get('clip_ratio_c', 3.0)

# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio,
clip_ratio_c=clip_ratio_c)
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)

@@ -326,6 +329,7 @@ def update_policy(self, data: DataProto):
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item(),
'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item(),
}
append_to_dict(metrics, data)

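One interpretation note on the new metric (my reading, not stated in the PR): `actor/pg_clipfrac_lower` is the masked fraction of tokens for which the dual-clip lower bound is the active term, which can only happen when the advantage is negative. A tiny sketch with made-up per-token values:

```python
import torch

# Hypothetical per-token values, for illustration only.
advantages      = torch.tensor([[-1.0, -0.1, 0.5]])
clip_pg_losses1 = torch.tensor([[12.2, 0.08, -0.4]])  # standard PPO-clipped losses
clip_ratio_c    = 3.0
pg_losses3      = -advantages * clip_ratio_c          # dual-clip lower-bound term
mask            = torch.ones_like(advantages)

active = torch.gt(clip_pg_losses1, pg_losses3) & (advantages < 0)
pg_clipfrac_lower = (active.float() * mask).sum() / mask.sum()
print(pg_clipfrac_lower)  # 0.3333: only the first token is bounded by clip_ratio_c
```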
22 changes: 15 additions & 7 deletions verl/workers/actor/megatron_actor.py
@@ -277,18 +277,20 @@ def loss_func(output, data, meta_info):

clip_ratio = meta_info['clip_ratio']
entropy_coeff = meta_info['entropy_coeff']
clip_ratio_c = meta_info['clip_ratio_c']

# compute policy loss
logits = output.logits
logits = logits[:, -response_length - 1:-1].contiguous()
logits_back = logits.clone()
log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
logits = logits_back
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio,
clip_ratio_c=clip_ratio_c)
entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask)
policy_loss = pg_loss - entropy_loss * entropy_coeff

@@ -310,7 +312,8 @@ def loss_func(output, data, meta_info):
'actor/entropy_loss': entropy_loss.detach().item(),
'actor/pg_loss': pg_loss.detach().item(),
'actor/pg_clipfrac': pg_clipfrac.detach().item(),
'actor/ppo_kl': ppo_kl.detach().item()
'actor/ppo_kl': ppo_kl.detach().item(),
'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item()
}
append_to_dict(stats, metrics)
return policy_loss, stats
@@ -324,7 +327,12 @@ def forward_step(batch_iter, model):
if forward_only:
meta_info = None
else:
meta_info = {'clip_ratio': self.config.clip_ratio, 'entropy_coeff': self.config.entropy_coeff}
clip_ratio_c = self.config.get('clip_ratio_c', 3.0)
meta_info = {
'clip_ratio': self.config.clip_ratio,
'entropy_coeff': self.config.entropy_coeff,
'clip_ratio_c': clip_ratio_c
}
return output, partial(loss_func, data=batch, meta_info=meta_info)

# batch should be a list of batches inside micro-batches
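Since both actors read the value with `config.get('clip_ratio_c', 3.0)`, existing configs that never set the key keep working unchanged. A minimal sketch of that fallback, assuming an OmegaConf-style config object as verl uses:

```python
from omegaconf import OmegaConf

# A config written before this change: no clip_ratio_c key.
actor_cfg = OmegaConf.create({"clip_ratio": 0.2, "entropy_coeff": 0.001})

# The actor falls back to the dual-clip default of 3.0.
clip_ratio_c = actor_cfg.get("clip_ratio_c", 3.0)
print(clip_ratio_c)  # 3.0
```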