
from collections import defaultdict
from enum import Enum
- from typing import Optional
+ from typing import Any, Callable, Optional

import numpy as np
import torch
+ from omegaconf import DictConfig

import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig

- POLICY_LOSS_REGISTRY = {}
+ PolicyLossFn = Callable[
+     [
+         torch.Tensor,  # old_log_prob
+         torch.Tensor,  # log_prob
+         torch.Tensor,  # advantages
+         torch.Tensor,  # response_mask
+         str,  # loss_agg_mode
+         Optional[DictConfig | AlgoConfig],  # config
+     ],
+     tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+ ]

+ POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}

- def register_policy_loss(name):
+
+ def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
    """Register a policy loss function with the given name.

    Args:
@@ -43,7 +56,7 @@ def register_policy_loss(name):
        function: Decorator function that registers the policy loss function.
    """

-     def decorator(func):
+     def decorator(func: PolicyLossFn) -> PolicyLossFn:
        POLICY_LOSS_REGISTRY[name] = func
        return func

@@ -68,10 +81,30 @@ def get_policy_loss_fn(name):
    return POLICY_LOSS_REGISTRY[loss_name]


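For orientation, the sketch below shows how the now-typed registry is meant to be used end to end: a function matching `PolicyLossFn` is registered under a name and later resolved with `get_policy_loss_fn`. The loss name `naive_pg` and the zero-valued metric slots are illustrative assumptions, not part of this change.

```python
# Hedged sketch: the loss name "naive_pg" and the metric placeholders are hypothetical.
import torch

from verl.trainer.ppo.core_algos import get_policy_loss_fn, register_policy_loss


@register_policy_loss("naive_pg")
def naive_pg_loss(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
    # Plain masked policy-gradient term, averaged over response tokens.
    pg_loss = -(log_prob * advantages * response_mask).sum() / response_mask.sum().clamp(min=1)
    zero = torch.zeros_like(pg_loss)
    return pg_loss, zero, zero, zero  # one loss plus three metric slots, matching PolicyLossFn


loss_fn = get_policy_loss_fn("naive_pg")
```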
- ADV_ESTIMATOR_REGISTRY = {}
+ class AdvantageEstimator(str, Enum):
+     """Using an enumeration class to avoid spelling errors in adv_estimator.
+
+     Note(haibin.lin): this enum class is immutable after creation. Extending this
+     enum for new estimators may not be necessary since users can always just call
+     `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
+     estimator instead.
+     """
+
+     GAE = "gae"
+     GRPO = "grpo"
+     REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
+     REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
+     REMAX = "remax"
+     RLOO = "rloo"
+     OPO = "opo"
+     GRPO_PASSK = "grpo_passk"
+     GPG = "gpg"
+
+
+ ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}


- def register_adv_est(name_or_enum):
+ def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
75108 """Decorator to register a advantage estimator function with a given name.

    Args:
@@ -108,26 +141,6 @@ def get_adv_estimator_fn(name_or_enum):
    return ADV_ESTIMATOR_REGISTRY[name]


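Mirroring the note in the enum's docstring, a new advantage estimator does not require extending `AdvantageEstimator`; a plain string key is enough. A minimal sketch, assuming a `(token_level_rewards, response_mask) -> (advantages, returns)` call convention that is not defined in this hunk:

```python
# Hedged sketch: the estimator name and its call signature are assumptions for illustration.
import torch

from verl.trainer.ppo.core_algos import get_adv_estimator_fn, register_adv_est


@register_adv_est("mean_baseline")  # an AdvantageEstimator member would also be accepted
def mean_baseline_adv(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, **kwargs):
    # Reward-to-go per token, with the batch mean subtracted as a baseline.
    returns = (token_level_rewards * response_mask).flip(-1).cumsum(-1).flip(-1)
    advantages = returns - returns.mean(dim=0, keepdim=True)
    return advantages, returns


adv_fn = get_adv_estimator_fn("mean_baseline")
```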
- class AdvantageEstimator(str, Enum):
-     """Using an enumeration class to avoid spelling errors in adv_estimator.
-
-     Note(haibin.lin): this enum class is immutable after creation. Extending this
-     enum for new estimators may not be necessary since users can always just call
-     `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
-     estimator instead.
-     """
-
-     GAE = "gae"
-     GRPO = "grpo"
-     REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
-     REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
-     REMAX = "remax"
-     RLOO = "rloo"
-     OPO = "opo"
-     GRPO_PASSK = "grpo_passk"
-     GPG = "gpg"
-
-
class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
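The controller body falls outside this diff. For context, the referenced paper (Ziegler et al., arXiv:1909.08593) uses a proportional update of the KL coefficient; the sketch below shows that rule with assumed attribute names, not the file's actual implementation.

```python
# Hedged sketch of the adaptive-KL rule from arXiv:1909.08593; attribute names are assumed.
class AdaptiveKLSketch:
    def __init__(self, init_kl_coef: float, target_kl: float, horizon: int):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Clip the proportional error, then nudge the coefficient toward the target KL.
        proportional_error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.value *= 1.0 + proportional_error * n_steps / self.horizon
```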
@@ -822,7 +835,7 @@ def compute_policy_loss_clip_cov(
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
-     config: Optional[AlgoConfig] = None,
+     config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -855,6 +868,10 @@ def compute_policy_loss_clip_cov(
        clip_cov_ub (float, optional):
            Upper bound for clipping covariance. Defaults to 5.0.
    """
+     assert config is not None
+     assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+     assert config.policy_loss is not None
+
    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
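With the new guards, the Clip-Cov path now expects an OmegaConf `DictConfig` (not an `AlgoConfig`) whose `policy_loss` section is present. A minimal sketch of such a config follows; apart from the 0.0002 fallback for `clip_cov_ratio` and the 5.0 default noted for `clip_cov_ub`, the values are illustrative and only the keys visible in this hunk are shown.

```python
# Hedged sketch: a DictConfig satisfying the asserts above; values are illustrative.
from omegaconf import OmegaConf

clip_cov_config = OmegaConf.create(
    {
        "clip_ratio": 0.2,
        "clip_ratio_low": None,  # None falls back to clip_ratio, as in the line above
        "policy_loss": {
            "clip_cov_ratio": 0.0002,
            "clip_cov_ub": 5.0,
        },
    }
)
```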
@@ -912,7 +929,7 @@ def compute_policy_loss_kl_cov(
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
-     config: Optional[AlgoConfig] = None,
+     config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for KL-Cov.
@@ -936,6 +953,10 @@ def compute_policy_loss_kl_cov(
        ppo_kl_coef (float, optional):
            Coefficient for the KL penalty term in the loss. Defaults to 1.
    """
+     assert config is not None
+     assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+     assert config.policy_loss is not None
+
    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0

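The KL-Cov variant reads two keys from the same `policy_loss` section; a matching sketch using the fallback defaults visible above:

```python
# Hedged sketch: only the two keys read in this hunk, set to their visible fallbacks.
from omegaconf import OmegaConf

kl_cov_config = OmegaConf.create({"policy_loss": {"kl_cov_ratio": 0.0002, "ppo_kl_coef": 1.0}})
```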