-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[algo] refactor: don't special-case compute_policy_loss
#2701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
8048f36
6a76795
bbfe831
6025dd3
441e152
dccd1e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -719,17 +719,15 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str | |||||||||||||||||||||||||||||||||||
| return loss | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def compute_policy_loss( | ||||||||||||||||||||||||||||||||||||
| old_log_prob, | ||||||||||||||||||||||||||||||||||||
| log_prob, | ||||||||||||||||||||||||||||||||||||
| advantages, | ||||||||||||||||||||||||||||||||||||
| response_mask, | ||||||||||||||||||||||||||||||||||||
| cliprange=None, | ||||||||||||||||||||||||||||||||||||
| cliprange_low=None, | ||||||||||||||||||||||||||||||||||||
| cliprange_high=None, | ||||||||||||||||||||||||||||||||||||
| clip_ratio_c=3.0, | ||||||||||||||||||||||||||||||||||||
| @register_policy_loss("vanilla") | ||||||||||||||||||||||||||||||||||||
| def compute_policy_loss_vanilla( | ||||||||||||||||||||||||||||||||||||
| old_log_prob: torch.Tensor, | ||||||||||||||||||||||||||||||||||||
| log_prob: torch.Tensor, | ||||||||||||||||||||||||||||||||||||
| advantages: torch.Tensor, | ||||||||||||||||||||||||||||||||||||
| response_mask: torch.Tensor, | ||||||||||||||||||||||||||||||||||||
| loss_agg_mode: str = "token-mean", | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| config: Optional[AlgoConfig] = None, | ||||||||||||||||||||||||||||||||||||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint for the Additionally, the parameter is marked as
Suggested change
|
||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| Compute the clipped policy objective and related metrics for PPO. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -745,19 +743,22 @@ def compute_policy_loss( | |||||||||||||||||||||||||||||||||||
| Advantage estimates for each action, shape (batch_size, response_length). | ||||||||||||||||||||||||||||||||||||
| response_mask (torch.Tensor): | ||||||||||||||||||||||||||||||||||||
| Mask indicating which tokens to include in the loss, shape (batch_size, response_length). | ||||||||||||||||||||||||||||||||||||
| cliprange (float, optional): | ||||||||||||||||||||||||||||||||||||
| Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. | ||||||||||||||||||||||||||||||||||||
| Defaults to None (must be provided). | ||||||||||||||||||||||||||||||||||||
| cliprange_low (float, optional): | ||||||||||||||||||||||||||||||||||||
| Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. | ||||||||||||||||||||||||||||||||||||
| cliprange_high (float, optional): | ||||||||||||||||||||||||||||||||||||
| Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. | ||||||||||||||||||||||||||||||||||||
| clip_ratio_c (float, optional): | ||||||||||||||||||||||||||||||||||||
| Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. | ||||||||||||||||||||||||||||||||||||
| Defaults to 3.0. | ||||||||||||||||||||||||||||||||||||
| loss_agg_mode (str, optional): | ||||||||||||||||||||||||||||||||||||
| Aggregation mode for `agg_loss`. Defaults to "token-mean". | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| assert config is not None | ||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||
| clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. | ||||||||||||||||||||||||||||||||||||
| clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio | ||||||||||||||||||||||||||||||||||||
| clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio | ||||||||||||||||||||||||||||||||||||
| clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. | ||||||||||||||||||||||||||||||||||||
| "clip_ratio_c", 3.0 | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| cliprange = clip_ratio | ||||||||||||||||||||||||||||||||||||
| cliprange_low = clip_ratio_low | ||||||||||||||||||||||||||||||||||||
| cliprange_high = clip_ratio_high | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| assert clip_ratio_c > 1.0, ( | ||||||||||||||||||||||||||||||||||||
| "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," | ||||||||||||||||||||||||||||||||||||
| + f" but get the value: {clip_ratio_c}." | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
despite that verl does not list
compute_policy_lossas an official API, we do notice that many verl extensions import this function in their recipe. I'd recommend mark it withverl.utils.import_utils.deprecated, add a new function compute_policy_loss_vanilla, note it in #2744. We then remove compute_policy_loss in the next release.