|
41 | 41 | torch.Tensor, # response_mask |
42 | 42 | str, # loss_agg_mode |
43 | 43 | Optional[DictConfig | AlgoConfig], # config |
| 44 | + torch.Tensor | None, # rollout_log_probs |
44 | 45 | ], |
45 | 46 | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], |
46 | 47 | ] |
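
For context, every registered policy loss must match this callable signature, and the PR adds the optional `rollout_log_probs` tensor as the last slot. A minimal sketch of a custom loss conforming to the updated interface; the function name and body below are illustrative, not part of this PR:

```python
import torch

@register_policy_loss("my_custom_loss")  # registry decorator used elsewhere in this file
def compute_policy_loss_custom(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config=None,  # Optional[DictConfig | AlgoConfig] in the real signature
    rollout_log_probs: torch.Tensor | None = None,  # new optional argument
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Simple advantage-weighted surrogate over valid response tokens.
    per_token = -advantages * log_prob * response_mask
    pg_loss = per_token.sum() / response_mask.sum().clamp(min=1)
    zero = torch.zeros_like(pg_loss)
    return pg_loss, zero, zero, zero  # loss plus placeholder metrics
```
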
@@ -820,7 +821,7 @@ def compute_policy_loss_vanilla( |
820 | 821 | response_mask: torch.Tensor, |
821 | 822 | loss_agg_mode: str = "token-mean", |
822 | 823 | config: Optional[DictConfig | AlgoConfig] = None, |
823 | | - rollout_log_probs=None, |
| 824 | + rollout_log_probs: torch.Tensor | None = None, |
824 | 825 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
825 | 826 | """ |
826 | 827 | Compute the clipped policy objective and related metrics for PPO. |
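
Because the new argument defaults to `None`, existing call sites keep working unchanged; callers that have rollout-engine log-probs can now pass them through. A hedged usage sketch showing only the call shape and the four-tensor return; tensor names, shapes, and `algo_config` are assumptions about the surrounding trainer code, not part of this diff:

```python
import torch

batch, resp_len = 4, 16  # illustrative shapes
old_log_prob = torch.randn(batch, resp_len)
log_prob = torch.randn(batch, resp_len, requires_grad=True)
advantages = torch.randn(batch, resp_len)
response_mask = torch.ones(batch, resp_len)
rollout_log_probs = torch.randn(batch, resp_len)  # log-probs from the rollout engine, if available

pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
    old_log_prob, log_prob, advantages, response_mask,
    loss_agg_mode="token-mean",
    config=algo_config,  # assumed: the trainer's AlgoConfig/DictConfig with clip settings
    rollout_log_probs=rollout_log_probs,  # optional; omit or pass None for the previous behaviour
)
```
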
@@ -909,6 +910,7 @@ def compute_policy_loss_gspo( |
909 | 910 | response_mask: torch.Tensor, |
910 | 911 | loss_agg_mode: str = "seq-mean-token-mean", |
911 | 912 | config: Optional[DictConfig | ActorConfig] = None, |
| 913 | + rollout_log_probs: torch.Tensor | None = None, |
912 | 914 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
913 | 915 | """ |
914 | 916 | Compute the clipped policy objective and related metrics for GSPO. |
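
For readers unfamiliar with GSPO: it clips a sequence-level importance ratio, the length-normalized geometric mean of per-token ratios (this is the GSPO paper's definition, not something taken from this diff, which only adds the signature plumbing). A minimal sketch of that quantity, independent of verl's actual implementation:

```python
import torch

def gspo_sequence_ratio_sketch(
    old_log_prob: torch.Tensor,   # (batch, seq_len)
    log_prob: torch.Tensor,       # (batch, seq_len)
    response_mask: torch.Tensor,  # (batch, seq_len), 1 for valid response tokens
) -> torch.Tensor:
    # s_i = exp( (1/|y_i|) * sum_t [log pi(y_t) - log pi_old(y_t)] )
    delta = (log_prob - old_log_prob) * response_mask
    valid_tokens = response_mask.sum(dim=-1).clamp(min=1)
    return torch.exp(delta.sum(dim=-1) / valid_tokens)  # (batch,) sequence-level ratio
```
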
@@ -967,7 +969,15 @@ def compute_policy_loss_gspo( |
967 | 969 | |
968 | 970 | |
969 | 971 | @register_policy_loss("gpg") |
970 | | -def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None): |
| 972 | +def compute_policy_loss_gpg( |
| 973 | + old_log_prob: torch.Tensor, |
| 974 | + log_prob: torch.Tensor, |
| 975 | + advantages: torch.Tensor, |
| 976 | + response_mask: torch.Tensor, |
| 977 | + loss_agg_mode: str = "token-mean", |
| 978 | + config: Optional[DictConfig | AlgoConfig] = None, |
| 979 | + rollout_log_probs: torch.Tensor | None = None, |
| 980 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
971 | 981 | """Adapted from |
972 | 982 | https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 |
973 | 983 | Args: |
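
The linked GPG trainer computes an unclipped, advantage-weighted policy-gradient objective; a hedged simplification of that loss (a sketch, not verl's exact implementation):

```python
import torch

def gpg_loss_sketch(
    log_prob: torch.Tensor,       # (batch, seq_len)
    advantages: torch.Tensor,     # (batch, seq_len)
    response_mask: torch.Tensor,  # (batch, seq_len)
) -> torch.Tensor:
    per_token_loss = -advantages * log_prob          # plain policy gradient, no ratio clipping
    masked = per_token_loss * response_mask          # ignore prompt and padding positions
    return masked.sum() / response_mask.sum().clamp(min=1)  # "token-mean" aggregation
```
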
@@ -995,6 +1005,7 @@ def compute_policy_loss_clip_cov( |
995 | 1005 | response_mask: torch.Tensor, |
996 | 1006 | loss_agg_mode: str = "token-mean", |
997 | 1007 | config: Optional[DictConfig | AlgoConfig] = None, |
| 1008 | + rollout_log_probs: torch.Tensor | None = None, |
998 | 1009 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
999 | 1010 | """ |
1000 | 1011 | Compute the clipped policy objective and related metrics for Clip-Cov. |
@@ -1089,6 +1100,7 @@ def compute_policy_loss_kl_cov( |
1089 | 1100 | response_mask: torch.Tensor, |
1090 | 1101 | loss_agg_mode: str = "token-mean", |
1091 | 1102 | config: Optional[DictConfig | AlgoConfig] = None, |
| 1103 | + rollout_log_probs: torch.Tensor | None = None, |
1092 | 1104 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
1093 | 1105 | """ |
1094 | 1106 | Compute the clipped policy objective and related metrics for KL-Cov. |
@@ -1160,6 +1172,7 @@ def compute_policy_loss_geo_mean( |
1160 | 1172 | response_mask: torch.Tensor, |
1161 | 1173 | loss_agg_mode: str = "token-mean", |
1162 | 1174 | config: Optional[DictConfig | AlgoConfig] = None, |
| 1175 | + rollout_log_probs: torch.Tensor | None = None, |
1163 | 1176 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
1164 | 1177 | """ |
1165 | 1178 | Compute the clipped policy objective and related metrics for GMPO. |
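
With every registered loss now accepting `rollout_log_probs`, a trainer can dispatch to any of them through one uniform call. A sketch of such a call site, assuming a registry lookup paired with the `register_policy_loss` decorator above; the helper's name and the config path are illustrative assumptions:

```python
# `get_policy_loss_fn` is an assumed registry lookup; names are illustrative.
loss_fn = get_policy_loss_fn(config.policy_loss.loss_mode)  # e.g. "vanilla", "gspo", "gpg", ...
pg_loss, *metrics = loss_fn(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    loss_agg_mode,
    config,
    rollout_log_probs,  # now uniformly accepted by all registered losses
)
```
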
|