[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
Rollout config (YAML):

```diff
@@ -159,7 +159,8 @@ multi_turn:
     format: hermes
 
   # support logging rollout prob for debugging purpose
-  calculate_log_probs: False
+  # "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling
+  calculate_log_probs: True
 
   # [Experimental] agent loop based rollout configs
   agent:
```
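For reference, a minimal sketch of how the two knobs might be enabled together at launch. The exact config key paths are assumptions (the diff only shows the `calculate_log_probs` key here and an `imp_ratio_cap` field on the actor config further down), so treat them as illustrative:

```python
# Hypothetical Hydra-style overrides; the key paths are assumptions, not confirmed by this diff.
overrides = [
    # Ask the rollout engine (vLLM) to log per-token log probs alongside the samples.
    "actor_rollout_ref.rollout.calculate_log_probs=True",
    # Cap for truncated importance sampling; values <= 0 disable the correction.
    "actor_rollout_ref.actor.imp_ratio_cap=2.0",
]
```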
`verl/trainer/ppo/core_algos.py` — two new keyword arguments on the policy-loss functions:

```diff
@@ -747,6 +747,8 @@ def compute_policy_loss(
     cliprange_high=None,
     clip_ratio_c=3.0,
     loss_agg_mode: str = "token-mean",
+    rollout_log_probs=None,
+    imp_ratio_cap=-1,
 ):
```
|
> **Reviewer:** this function is deprecated and you should change `verl.trainer.ppo.core_algos.compute_policy_loss_vanilla` instead
>
> **Author:** icic, I have included the change in `compute_policy_loss_vanilla`. Let me delete the change in the deprecated `compute_policy_loss`.
>
> **Author:** I have fixed it now :)
```diff
     """
     Compute the clipped policy objective and related metrics for PPO.
@@ -807,6 +809,13 @@ def compute_policy_loss(
     )
 
     pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
 
+    if imp_ratio_cap > 0 and rollout_log_probs is not None:
+        # Apply truncated importance sampling
+        imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
+        imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
+        pg_losses = pg_losses * imp_ratio
+
     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
```
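In math, the added block reweights the per-token policy-gradient loss by a truncated importance ratio between the training policy (which produced `old_log_prob`) and the rollout policy (which produced `rollout_log_probs`). A sketch, with $C$ standing for `imp_ratio_cap`:

$$
w_t = \min\!\Big(\exp\big(\mathrm{old\\_log\\_prob}_t - \mathrm{rollout\\_log\\_probs}_t\big),\; C\Big),
\qquad
\tilde{L}^{\mathrm{PG}}_t = w_t \cdot L^{\mathrm{PG}}_t .
$$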
|
|
```diff
@@ -820,6 +829,8 @@ def compute_policy_loss_vanilla(
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
+    rollout_log_probs=None,
+    imp_ratio_cap=-1,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective and related metrics for PPO.
@@ -884,6 +895,13 @@ def compute_policy_loss_vanilla(
     )
 
     pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
 
+    if imp_ratio_cap > 0 and rollout_log_probs is not None:
+        # Apply truncated importance sampling
+        imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
+        imp_ratio = torch.clamp(imp_ratio, max=imp_ratio_cap)
+        pg_losses = pg_losses * imp_ratio
+
     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
```
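As a self-contained illustration of what this hunk computes, here is a minimal PyTorch sketch of truncated importance sampling applied to per-token policy-gradient losses. The tensors and values are made up for the example; only the `exp`/`clamp`/multiply pattern mirrors the diff:

```python
import torch

# Illustrative tensors of shape (batch, response_len); values are arbitrary.
old_log_prob = torch.tensor([[-0.9, -1.2, -0.4]])       # recomputed by the training (FSDP) policy
rollout_log_probs = torch.tensor([[-1.0, -0.8, -0.5]])  # logged by the rollout engine (vLLM)
pg_losses = torch.tensor([[0.3, -0.1, 0.2]])            # per-token clipped PPO losses
imp_ratio_cap = 2.0

if imp_ratio_cap > 0:
    # Per-token ratio pi_train / pi_rollout, truncated from above to bound variance.
    imp_ratio = torch.exp(old_log_prob - rollout_log_probs).clamp(max=imp_ratio_cap)
    pg_losses = pg_losses * imp_ratio

print(pg_losses)  # reweighted per-token losses, later aggregated by agg_loss with the response mask
```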
|
|
```diff
@@ -1270,6 +1288,19 @@ def compute_value_loss(
 
 
+def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
+    """
+    The expectation of the k1 and k3 estimators is the expected value of the KL divergence,
+    but the expected gradient of the k1 and k3 estimators is not the expected gradient of the KL!
+    The k2 estimator, on the other hand, gives the right gradient estimator,
+    so we use a straight-through trick here.
+    """
+    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
+    backward_score = 0.5 * (logprob - ref_logprob).square()
+
+    return backward_score - backward_score.detach() + forward_score.detach()
+
+
 def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
     """Compute KL divergence given logprob and ref_logprob.
     Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
     See more description in http://joschu.net/blog/kl-approx.html
```
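To see why the straight-through expression works, a small standalone sketch (plain PyTorch, illustrative values): the returned tensor equals the forward estimator in value, while its gradient with respect to `logprob` is that of the k2 estimator, 0.5 · (logprob − ref_logprob)².

```python
import torch

# Illustrative log probs; requires_grad on logprob so the gradient path can be inspected.
logprob = torch.tensor([-1.0, -0.5], requires_grad=True)
ref_logprob = torch.tensor([-1.2, -0.4])

# Stand-in for the forward estimator (here k1 = logprob - ref_logprob); any estimator would do.
forward_score = logprob - ref_logprob
backward_score = 0.5 * (logprob - ref_logprob).square()  # k2 estimator

out = backward_score - backward_score.detach() + forward_score.detach()

print(out)            # value: identical to forward_score
out.sum().backward()
print(logprob.grad)   # gradient: logprob - ref_logprob, i.e. d(k2)/d(logprob)
```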
FSDP actor, `update_policy`:

```diff
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
         ]
         if self.config.use_kl_loss:
             select_keys.append("ref_log_prob")
+        if self.config.imp_ratio_cap > 0:
+            select_keys.append("rollout_log_probs")
```
|
> **Reviewer:** Server mode (agent loop) hasn't returned `rollout_log_probs` yet.
>
> **Author:** Hi, I have added a check here before adding `rollout_log_probs`.
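The kind of guard being discussed might look like the sketch below: only request `rollout_log_probs` when the cap is enabled and the rollout actually produced that tensor (server/agent-loop mode may not). The helper name and the exact condition are hypothetical, not the PR's final code:

```python
def select_training_keys(base_keys, batch_keys, imp_ratio_cap):
    """Hypothetical helper: only ask for rollout log probs when truncated importance
    sampling is enabled AND the rollout actually returned them."""
    select_keys = list(base_keys)
    if imp_ratio_cap > 0 and "rollout_log_probs" in batch_keys:
        select_keys.append("rollout_log_probs")
    return select_keys

# Agent-loop batch without rollout log probs -> the key is simply not requested.
print(select_training_keys(["old_log_probs", "advantages"], {"old_log_probs", "advantages"}, 2.0))
```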
```diff
         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -405,6 +407,7 @@ def update_policy(self, data: DataProto):
                 model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                 response_mask = model_inputs["response_mask"]
                 old_log_prob = model_inputs["old_log_probs"]
+                rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.imp_ratio_cap > 0 else None
                 advantages = model_inputs["advantages"]
 
                 entropy_coeff = self.config.entropy_coeff
@@ -435,6 +438,8 @@ def update_policy(self, data: DataProto):
                     response_mask=response_mask,
                     loss_agg_mode=loss_agg_mode,
                     config=self.config,
+                    rollout_log_probs=rollout_log_probs,
+                    imp_ratio_cap=self.config.imp_ratio_cap,
                 )
 
                 if entropy_coeff != 0:
```
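As a quick illustration of why the cap matters at this point in the update: when the rollout engine and the training engine disagree sharply on a token (the rollout-training mismatch in the PR title), the raw ratio can explode and let a single token dominate the micro-batch loss; truncation bounds that contribution. The numbers below are made up:

```python
import torch

old_log_prob = torch.tensor([-0.2, -5.0])       # training (FSDP) policy
rollout_log_probs = torch.tensor([-0.3, -9.0])  # rollout (vLLM) policy disagrees hard on the second token

raw_ratio = torch.exp(old_log_prob - rollout_log_probs)
print(raw_ratio)                 # ~[1.105, 54.598] -- the second token dominates
print(raw_ratio.clamp(max=2.0))  # ~[1.105, 2.000]  -- bounded by imp_ratio_cap
```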
> **Reviewer:** Same here.
>
> **Author:** fixed :)