diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e68644a010a..4b4c7b8435c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,6 +7,11 @@ repos:
         exclude: ^.*\.(ipynb)$
       - id: ruff-format
 
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.17.0'
+    hooks:
+      - id: mypy
+
   - repo: local
     hooks:
       - id: autogen-trainer-cfg
@@ -29,4 +34,4 @@ repos:
         name: Check license
         entry: python3 tests/special_sanity/check_license.py --directory .
         language: python
-        pass_filenames: false
\ No newline at end of file
+        pass_filenames: false
diff --git a/pyproject.toml b/pyproject.toml
index b508516adc6..9d24ae95b35 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,25 @@ ignore = [
     "UP035",
 ]
 
+# -------------------------------
+# tool.mypy - typechecking config
+# -------------------------------
+[tool.mypy]
+pretty = true
+ignore_missing_imports = true
+explicit_package_bases = true
+follow_imports = "skip"
+
+# Blanket silence
+ignore_errors = true
+
+[[tool.mypy.overrides]]
+module = [
+"verl.trainer.config.algorithm",
+"verl.trainer.ppo.core_algos",
+]
+ignore_errors = false
+
 # -------------------------------
 # tool.setuptools - Additional config
 # -------------------------------
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 5f02675817b..c916b6c91b2 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -22,18 +22,31 @@
 from collections import defaultdict
 from enum import Enum
-from typing import Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import torch
+from omegaconf import DictConfig
 
 import verl.utils.torch_functional as verl_F
 from verl.trainer.config import AlgoConfig
 
-POLICY_LOSS_REGISTRY = {}
+PolicyLossFn = Callable[
+    [
+        torch.Tensor,  # old_log_prob
+        torch.Tensor,  # log_prob
+        torch.Tensor,  # advantages
+        torch.Tensor,  # response_mask
+        str,  # loss_agg_mode
+        Optional[DictConfig | AlgoConfig],  # config
+    ],
+    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+]
+POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
 
 
-def register_policy_loss(name):
+
+def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
     """Register a policy loss function with the given name.
 
     Args:
@@ -43,7 +56,7 @@ def register_policy_loss(name):
         function: Decorator function that registers the policy loss function.
     """
 
-    def decorator(func):
+    def decorator(func: PolicyLossFn) -> PolicyLossFn:
         POLICY_LOSS_REGISTRY[name] = func
         return func
 
@@ -68,10 +81,30 @@ def get_policy_loss_fn(name):
     return POLICY_LOSS_REGISTRY[loss_name]
 
 
-ADV_ESTIMATOR_REGISTRY = {}
+class AdvantageEstimator(str, Enum):
+    """Using an enumeration class to avoid spelling errors in adv_estimator.
+
+    Note(haibin.lin): this enum class is immutable after creation. Extending this
+    enum for new estimators may not be necessary since users can always just call
+    `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
+    estimator instead.
+    """
+
+    GAE = "gae"
+    GRPO = "grpo"
+    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
+    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
+    REMAX = "remax"
+    RLOO = "rloo"
+    OPO = "opo"
+    GRPO_PASSK = "grpo_passk"
+    GPG = "gpg"
+
+
+ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
 
 
-def register_adv_est(name_or_enum):
+def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
     """Decorator to register a advantage estimator function with a given name.
 
     Args:
@@ -108,26 +141,6 @@ def get_adv_estimator_fn(name_or_enum):
     return ADV_ESTIMATOR_REGISTRY[name]
 
 
-class AdvantageEstimator(str, Enum):
-    """Using an enumeration class to avoid spelling errors in adv_estimator.
-
-    Note(haibin.lin): this enum class is immutable after creation. Extending this
-    enum for new estimators may not be necessary since users can always just call
-    `verl.trainer.ppo.core_algos.register` with string name for a custom advantage
-    estimator instead.
-    """
-
-    GAE = "gae"
-    GRPO = "grpo"
-    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
-    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
-    REMAX = "remax"
-    RLOO = "rloo"
-    OPO = "opo"
-    GRPO_PASSK = "grpo_passk"
-    GPG = "gpg"
-
-
 class AdaptiveKLController:
     """
     Adaptive KL controller described in the paper:
@@ -822,7 +835,7 @@ def compute_policy_loss_clip_cov(
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
-    config: Optional[AlgoConfig] = None,
+    config: Optional[DictConfig | AlgoConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -855,6 +868,10 @@ def compute_policy_loss_clip_cov(
             clip_cov_ub (float, optional):
                 Upper bound for clipping covariance. Defaults to 5.0.
     """
+    assert config is not None
+    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+    assert config.policy_loss is not None
+
     clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
     cliprange = config.clip_ratio
     cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
@@ -912,7 +929,7 @@ def compute_policy_loss_kl_cov(
     advantages: torch.Tensor,
     response_mask: torch.Tensor,
     loss_agg_mode: str = "token-mean",
-    config: Optional[AlgoConfig] = None,
+    config: Optional[DictConfig | AlgoConfig] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.
@@ -936,6 +953,10 @@ def compute_policy_loss_kl_cov(
             ppo_kl_coef (float, optional):
                 Coefficient for the KL penalty term in the loss. Defaults to 1.
     """
+    assert config is not None
+    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
+    assert config.policy_loss is not None
+
     kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
     ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0
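
For context, here is a minimal sketch of how a custom loss would plug into the typed registry introduced above. The registry key `my_naive_pg`, the function name, and the loss body are hypothetical; only the decorator, the `PolicyLossFn` call shape, and the module paths come from the diff.

```python
from typing import Optional

import torch
from omegaconf import DictConfig

from verl.trainer.config import AlgoConfig
from verl.trainer.ppo.core_algos import register_policy_loss


@register_policy_loss("my_naive_pg")  # hypothetical registry key
def my_naive_pg_loss(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Plain policy-gradient objective, averaged over valid response tokens only.
    pg_loss = -(advantages * log_prob * response_mask).sum() / response_mask.sum().clamp(min=1.0)
    # The registry contract is a 4-tuple of tensors; the extra slots carry
    # per-batch metrics, so scalar zeros stand in for this toy loss.
    zero = torch.zeros((), device=log_prob.device)
    return pg_loss, zero, zero, zero
```

Because `POLICY_LOSS_REGISTRY` is now `dict[str, PolicyLossFn]`, a registered function whose signature drifts from this shape becomes a mypy error in `verl.trainer.ppo.core_algos`, one of the two modules opted back into checking via `[[tool.mypy.overrides]]`.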
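The new asserts in `compute_policy_loss_clip_cov` and `compute_policy_loss_kl_cov` narrow `Optional[DictConfig | AlgoConfig]` to a populated `DictConfig` at runtime. Below is a sketch of a config that satisfies all three guards; the values are illustrative, and `clip_cov_lb` is assumed by analogy with the `clip_cov_ub` docstring entry visible in the diff.

```python
from omegaconf import OmegaConf

# Illustrative values only, not recommended settings.
config = OmegaConf.create(
    {
        "clip_ratio": 0.2,
        "clip_ratio_low": None,  # falls back to clip_ratio in the loss
        "policy_loss": {
            "clip_cov_ratio": 0.0002,
            "clip_cov_lb": 1.0,   # assumed field, mirroring clip_cov_ub
            "clip_cov_ub": 5.0,
        },
    }
)

# config is not None, is a DictConfig rather than an AlgoConfig dataclass,
# and has a populated policy_loss section, so the guards pass.
assert config.policy_loss is not None
```

Passing an `AlgoConfig` dataclass instead would trip the `isinstance` guard with "passing AlgoConfig not supported yet", which keeps the structured-config path explicitly unsupported until it is typed.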