Skip to content

Commit 4ae5729

Browse files
frradoseyosey
authored andcommitted
[CI] feat: add mypy to pre-commit (volcengine#2614)
1 parent 6c64894 commit 4ae5729

File tree

3 files changed

+74
-29
lines changed

3 files changed

+74
-29
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ repos:
77
exclude: ^.*\.(ipynb)$
88
- id: ruff-format
99

10+
- repo: https://github.com/pre-commit/mirrors-mypy
11+
rev: 'v1.17.0'
12+
hooks:
13+
- id: mypy
14+
1015
- repo: local
1116
hooks:
1217
- id: autogen-trainer-cfg
@@ -29,4 +34,4 @@ repos:
2934
name: Check license
3035
entry: python3 tests/special_sanity/check_license.py --directory .
3136
language: python
32-
pass_filenames: false
37+
pass_filenames: false

pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,25 @@ ignore = [
6565
"UP035",
6666
]
6767

68+
# -------------------------------
69+
# tool.mypy - typechecking config
70+
# -------------------------------
71+
[tool.mypy]
72+
pretty = true
73+
ignore_missing_imports = true
74+
explicit_package_bases = true
75+
follow_imports = "skip"
76+
77+
# Blanket silence
78+
ignore_errors = true
79+
80+
[[tool.mypy.overrides]]
81+
module = [
82+
"verl.trainer.config.algorithm",
83+
"verl.trainer.ppo.core_algos",
84+
]
85+
ignore_errors = false
86+
6887
# -------------------------------
6988
# tool.setuptools - Additional config
7089
# -------------------------------

verl/trainer/ppo/core_algos.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,31 @@
2222

2323
from collections import defaultdict
2424
from enum import Enum
25-
from typing import Optional
25+
from typing import Any, Callable, Optional
2626

2727
import numpy as np
2828
import torch
29+
from omegaconf import DictConfig
2930

3031
import verl.utils.torch_functional as verl_F
3132
from verl.trainer.config import AlgoConfig
3233

33-
POLICY_LOSS_REGISTRY = {}
34+
PolicyLossFn = Callable[
35+
[
36+
torch.Tensor, # old_log_prob
37+
torch.Tensor, # log_prob
38+
torch.Tensor, # advantages
39+
torch.Tensor, # response_mask
40+
str, # loss_agg_mode
41+
Optional[DictConfig | AlgoConfig], # config
42+
],
43+
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
44+
]
3445

46+
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
3547

36-
def register_policy_loss(name):
48+
49+
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
3750
"""Register a policy loss function with the given name.
3851
3952
Args:
@@ -43,7 +56,7 @@ def register_policy_loss(name):
4356
function: Decorator function that registers the policy loss function.
4457
"""
4558

46-
def decorator(func):
59+
def decorator(func: PolicyLossFn) -> PolicyLossFn:
4760
POLICY_LOSS_REGISTRY[name] = func
4861
return func
4962

@@ -68,10 +81,30 @@ def get_policy_loss_fn(name):
6881
return POLICY_LOSS_REGISTRY[loss_name]
6982

7083

71-
ADV_ESTIMATOR_REGISTRY = {}
84+
class AdvantageEstimator(str, Enum):
85+
"""Using an enumeration class to avoid spelling errors in adv_estimator.
86+
87+
Note(haibin.lin): this enum class is immutable after creation. Extending this
88+
enum for new estimators may not be necessary since users can always just call
89+
`verl.trainer.ppo.core_algos.register` with string name for a custom advantage
90+
estimator instead.
91+
"""
92+
93+
GAE = "gae"
94+
GRPO = "grpo"
95+
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
96+
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
97+
REMAX = "remax"
98+
RLOO = "rloo"
99+
OPO = "opo"
100+
GRPO_PASSK = "grpo_passk"
101+
GPG = "gpg"
102+
103+
104+
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
72105

73106

74-
def register_adv_est(name_or_enum):
107+
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
75108
"""Decorator to register a advantage estimator function with a given name.
76109
77110
Args:
@@ -108,26 +141,6 @@ def get_adv_estimator_fn(name_or_enum):
108141
return ADV_ESTIMATOR_REGISTRY[name]
109142

110143

111-
class AdvantageEstimator(str, Enum):
112-
"""Using an enumeration class to avoid spelling errors in adv_estimator.
113-
114-
Note(haibin.lin): this enum class is immutable after creation. Extending this
115-
enum for new estimators may not be necessary since users can always just call
116-
`verl.trainer.ppo.core_algos.register` with string name for a custom advantage
117-
estimator instead.
118-
"""
119-
120-
GAE = "gae"
121-
GRPO = "grpo"
122-
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
123-
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
124-
REMAX = "remax"
125-
RLOO = "rloo"
126-
OPO = "opo"
127-
GRPO_PASSK = "grpo_passk"
128-
GPG = "gpg"
129-
130-
131144
class AdaptiveKLController:
132145
"""
133146
Adaptive KL controller described in the paper:
@@ -822,7 +835,7 @@ def compute_policy_loss_clip_cov(
822835
advantages: torch.Tensor,
823836
response_mask: torch.Tensor,
824837
loss_agg_mode: str = "token-mean",
825-
config: Optional[AlgoConfig] = None,
838+
config: Optional[DictConfig | AlgoConfig] = None,
826839
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
827840
"""
828841
Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -855,6 +868,10 @@ def compute_policy_loss_clip_cov(
855868
clip_cov_ub (float, optional):
856869
Upper bound for clipping covariance. Defaults to 5.0.
857870
"""
871+
assert config is not None
872+
assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
873+
assert config.policy_loss is not None
874+
858875
clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
859876
cliprange = config.clip_ratio
860877
cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
@@ -912,7 +929,7 @@ def compute_policy_loss_kl_cov(
912929
advantages: torch.Tensor,
913930
response_mask: torch.Tensor,
914931
loss_agg_mode: str = "token-mean",
915-
config: Optional[AlgoConfig] = None,
932+
config: Optional[DictConfig | AlgoConfig] = None,
916933
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
917934
"""
918935
Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -936,6 +953,10 @@ def compute_policy_loss_kl_cov(
936953
ppo_kl_coef (float, optional):
937954
Coefficient for the KL penalty term in the loss. Defaults to 1.
938955
"""
956+
assert config is not None
957+
assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
958+
assert config.policy_loss is not None
959+
939960
kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
940961
ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
941962

0 commit comments

Comments
 (0)