2 changes: 1 addition & 1 deletion docs/examples/config.rst
@@ -501,7 +501,7 @@ Algorithm

- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``
- ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False.
- ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to
calculate the kl divergence between actor and reference policy. For
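For context, each `adv_estimator` name resolves to a function registered in `verl/trainer/ppo/core_algos.py` (see the diff further down). A minimal lookup sketch, assuming `get_adv_estimator_fn` accepts the registered name string:

```python
# Sketch only: resolve a configured estimator name to its implementation.
# Assumes get_adv_estimator_fn accepts the registered string name
# ("rloo_vectorized"); the exact call signature may differ.
from verl.trainer.ppo.core_algos import get_adv_estimator_fn

adv_fn = get_adv_estimator_fn("rloo_vectorized")
print(adv_fn.__name__)  # expected: compute_rloo_vectorized_outcome_advantage
```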
2 changes: 1 addition & 1 deletion examples/ppo_trainer/README.md
@@ -43,7 +43,7 @@ Most critic configs are similar to those of actors. Note that the critic model i

- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator

- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo
- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo, rloo_vectorized

## Advanced Extensions

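As an illustration of the extension hook referenced under "Advanced Extensions": a custom estimator can be registered the same way `rloo_vectorized` is registered in the `core_algos.py` diff below. The name and function here are hypothetical, and this assumes `register_adv_est` also accepts names outside the built-in `AdvantageEstimator` enum:

```python
# Hypothetical example: register a trivial outcome-style estimator under a
# made-up name, following the @register_adv_est decorator pattern used for
# rloo_vectorized in verl/trainer/ppo/core_algos.py.
import torch
from verl.trainer.ppo.core_algos import register_adv_est

@register_adv_est("my_outcome_est")  # hypothetical estimator name
def compute_my_outcome_advantage(token_level_rewards, response_mask, index, **kwargs):
    # Sequence-level score broadcast over response tokens, with no baseline.
    scores = token_level_rewards.sum(dim=-1)
    adv = scores.unsqueeze(-1) * response_mask
    return adv, adv
```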
71 changes: 70 additions & 1 deletion tests/trainer/ppo/test_core_algos_on_cpu.py
@@ -15,11 +15,18 @@
import random
import unittest

import numpy as np
import pytest
import torch

import verl.trainer.ppo.core_algos
from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est
from verl.trainer.ppo.core_algos import (
compute_gae_advantage_return,
compute_rloo_outcome_advantage,
compute_rloo_vectorized_outcome_advantage,
get_adv_estimator_fn,
register_adv_est,
)


def mock_test_fn():
@@ -188,5 +195,67 @@ def test_multi_turn_compute_gae_advantage_return():
print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")


def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
"""Create a numpy index array ensuring each group has at least 2 samples."""
assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
counts: list[int] = [2] * num_groups
remaining = batch_size - 2 * num_groups
for _ in range(remaining):
counts[random.randrange(num_groups)] += 1
index = []
for gid, c in enumerate(counts):
index.extend([gid] * c)
random.shuffle(index)
return np.asarray(index, dtype=np.int64)


def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
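    """Create a random 0/1 response mask with at least one non-zero token per row."""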
mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()
rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]
if len(rows_without_one) > 0:
mask[rows_without_one, -1] = 1.0
return mask


@pytest.mark.parametrize(
"batch_size,seq_len,num_groups,seed",
[
(64, 128, 5, 0),
(128, 256, 8, 1),
(512, 512, 10, 2),
],
)
def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
index = _make_group_index(batch_size, num_groups)
response_mask = _rand_mask(batch_size, seq_len)
base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
token_level_rewards = base_rewards * response_mask
adv1, ret1 = compute_rloo_outcome_advantage(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
)
adv2, ret2 = compute_rloo_vectorized_outcome_advantage(
token_level_rewards=token_level_rewards,
response_mask=response_mask,
index=index,
)
# Print concise diagnostics for visibility during test runs
adv_max_diff = (adv1 - adv2).abs().max().item()
ret_max_diff = (ret1 - ret2).abs().max().item()
total_mask_tokens = int(response_mask.sum().item())
print(
f"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} "
f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
)
assert adv1.shape == adv2.shape == (batch_size, seq_len)
assert ret1.shape == ret2.shape == (batch_size, seq_len)
assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)


if __name__ == "__main__":
unittest.main()
39 changes: 39 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -102,6 +102,7 @@ class AdvantageEstimator(str, Enum):
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"
RLOO_VECTORIZED = "rloo_vectorized"


ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -685,6 +686,44 @@ def compute_gpg_outcome_advantage(
return scores, scores


@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED) # or simply: @register_adv_est("rloo_vectorized")
def compute_rloo_vectorized_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
config: Optional[AlgoConfig] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
config: (AlgoConfig) algorithm config

Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)

with torch.no_grad():
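        # Map arbitrary group ids in `index` to contiguous integers 0..num_groups-1,
        # then use bincount to get the per-sample group size c and the group sum:
        # adv_i = (c * s_i - group_sum) / (c - 1) == s_i - mean of the other scores in the group.
        # The (c > 1) factor zeroes the advantage for singleton groups.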
inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device)

c = torch.bincount(inv)[inv].to(scores.dtype)
adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1)

adv = adv.unsqueeze(-1) * response_mask

return adv, adv
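As a sanity check on the algebra above: the closed form `(c * s_i - group_sum) / (c - 1)` equals each score minus the mean of the other scores in its group, which is the usual RLOO baseline. A minimal sketch with made-up scores for a single group of three samples:

```python
import torch

# Toy sequence-level scores for one prompt group (hypothetical values).
scores = torch.tensor([1.0, 2.0, 6.0])
c = scores.numel()
group_sum = scores.sum()

# Closed form used in compute_rloo_vectorized_outcome_advantage.
vectorized = (c * scores - group_sum) / (c - 1)

# Explicit leave-one-out baseline: each score minus the mean of the others.
reference = scores - (group_sum - scores) / (c - 1)

assert torch.allclose(vectorized, reference)
print(vectorized)  # tensor([-3.0000, -1.5000,  4.5000])
```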


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
"""Compute token-level rewards with KL penalty.
