Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ repos:
exclude: ^.*\.(ipynb)$
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.17.0'
hooks:
- id: mypy

- repo: local
hooks:
- id: autogen-trainer-cfg
Expand All @@ -29,4 +34,4 @@ repos:
name: Check license
entry: python3 tests/special_sanity/check_license.py --directory .
language: python
pass_filenames: false
pass_filenames: false
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ ignore = [
"UP035",
]

# -------------------------------
# tool.mypy - typechecking config
# -------------------------------
[tool.mypy]
pretty = true
ignore_missing_imports = true
explicit_package_bases = true
follow_imports = "skip"

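# Blanket silence: ignore type errors in every module by default; modules opt in to checking via the overrides below.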
ignore_errors = true

[[tool.mypy.overrides]]
module = ["verl.trainer.ppo.core_algos"]
ignore_errors = false

# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
Expand Down
50 changes: 25 additions & 25 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections import defaultdict
from enum import Enum
from typing import Optional
from typing import Any, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -68,10 +68,30 @@ def get_policy_loss_fn(name):
return POLICY_LOSS_REGISTRY[loss_name]


ADV_ESTIMATOR_REGISTRY = {}
class AdvantageEstimator(str, Enum):
"""Using an enumeration class to avoid spelling errors in adv_estimator.

Note(haibin.lin): this enum class is immutable after creation. Extending this
enum for new estimators may not be necessary, since users can always call
`verl.trainer.ppo.core_algos.register_adv_est` with a string name to register a
custom advantage estimator instead.
"""

def register_adv_est(name_or_enum):
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"


# Maps an estimator name to its registered advantage-estimator function.
# Looked up by `get_adv_estimator_fn`; presumably populated by the
# `register_adv_est` decorator (its body is not shown here -- confirm).
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}


def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
"""Decorator to register a advantage estimator function with a given name.

Args:
Expand Down Expand Up @@ -108,26 +128,6 @@ def get_adv_estimator_fn(name_or_enum):
return ADV_ESTIMATOR_REGISTRY[name]


class AdvantageEstimator(str, Enum):
    """Enumeration of advantage estimator names.

    Using an enumeration class avoids spelling errors in adv_estimator. Because
    the class also derives from ``str``, each member compares equal to its
    string value (e.g. ``AdvantageEstimator.GAE == "gae"``).

    Note(haibin.lin): this enum class is immutable after creation. Extending this
    enum for new estimators may not be necessary, since users can always call
    `verl.trainer.ppo.core_algos.register_adv_est` with a string name to register a
    custom advantage estimator instead.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"


class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
Expand Down Expand Up @@ -822,7 +822,7 @@ def compute_policy_loss_clip_cov(
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[AlgoConfig] = None,
config: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
Expand Down Expand Up @@ -912,7 +912,7 @@ def compute_policy_loss_kl_cov(
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[AlgoConfig] = None,
config: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for KL-Cov.
Expand Down