diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index 700e2ebcae7..d1713d1425f 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -501,7 +501,7 @@ Algorithm
 
 - ``gamma``: discount factor
 - ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``
 - ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False.
 - ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to calculate the kl divergence between actor and reference policy. For
diff --git a/examples/ppo_trainer/README.md b/examples/ppo_trainer/README.md
index cde0c9be51f..7b7261c5dd1 100644
--- a/examples/ppo_trainer/README.md
+++ b/examples/ppo_trainer/README.md
@@ -43,7 +43,7 @@ Most critic configs are similar to those of actors. Note that the critic model i
 
 - `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator
 
-- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo
+- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo, rloo_vectorized
 
 ## Advanced Extensions
diff --git a/tests/trainer/ppo/test_core_algos_on_cpu.py b/tests/trainer/ppo/test_core_algos_on_cpu.py
index 087a0d2f129..4fbd118ec5b 100644
--- a/tests/trainer/ppo/test_core_algos_on_cpu.py
+++ b/tests/trainer/ppo/test_core_algos_on_cpu.py
@@ -15,11 +15,18 @@
 import random
 import unittest
 
+import numpy as np
 import pytest
 import torch
 
 import verl.trainer.ppo.core_algos
-from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est
+from verl.trainer.ppo.core_algos import (
+    compute_gae_advantage_return,
+    compute_rloo_outcome_advantage,
+    compute_rloo_vectorized_outcome_advantage,
+    get_adv_estimator_fn,
+    register_adv_est,
+)
 
 
 def mock_test_fn():
@@ -188,5 +195,67 @@ def test_multi_turn_compute_gae_advantage_return():
     print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")
 
 
+def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
+    """Create a numpy index array ensuring each group has at least 2 samples."""
+    assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
+    counts: list[int] = [2] * num_groups
+    remaining = batch_size - 2 * num_groups
+    for _ in range(remaining):
+        counts[random.randrange(num_groups)] += 1
+    index = []
+    for gid, c in enumerate(counts):
+        index.extend([gid] * c)
+    random.shuffle(index)
+    return np.asarray(index, dtype=np.int64)
+
+
+def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
+    mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()
+    rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]
+    if len(rows_without_one) > 0:
+        mask[rows_without_one, -1] = 1.0
+    return mask
+
+
+@pytest.mark.parametrize(
+    "batch_size,seq_len,num_groups,seed",
+    [
+        (64, 128, 5, 0),
+        (128, 256, 8, 1),
+        (512, 512, 10, 2),
+    ],
+)
+def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    index = _make_group_index(batch_size, num_groups)
+    response_mask = _rand_mask(batch_size, seq_len)
+    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
+    token_level_rewards = base_rewards * response_mask
+    adv1, ret1 = compute_rloo_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+    adv2, ret2 = compute_rloo_vectorized_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+    # Print concise diagnostics for visibility during test runs
+    adv_max_diff = (adv1 - adv2).abs().max().item()
+    ret_max_diff = (ret1 - ret2).abs().max().item()
+    total_mask_tokens = int(response_mask.sum().item())
+    print(
+        f"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} "
+        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
+    )
+    assert adv1.shape == adv2.shape == (batch_size, seq_len)
+    assert ret1.shape == ret2.shape == (batch_size, seq_len)
+    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
+    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 7a9103c4dfb..dee03c6f78f 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -102,6 +102,7 @@ class AdvantageEstimator(str, Enum):
     OPO = "opo"
     GRPO_PASSK = "grpo_passk"
     GPG = "gpg"
+    RLOO_VECTORIZED = "rloo_vectorized"
 
 
 ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -685,6 +686,44 @@ def compute_gpg_outcome_advantage(
     return scores, scores
 
 
+@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)  # or simply: @register_adv_est("rloo_vectorized")
+def compute_rloo_vectorized_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    config: Optional[AlgoConfig] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Vectorized computation of the RLOO advantage, based on https://arxiv.org/abs/2402.14740
+
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        response_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        config: (AlgoConfig) algorithm config
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    scores = token_level_rewards.sum(dim=-1)
+
+    with torch.no_grad():
+        inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device)
+
+        c = torch.bincount(inv)[inv].to(scores.dtype)
+        adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1)
+
+        adv = adv.unsqueeze(-1) * response_mask
+
+    return adv, adv
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     """Compute token-level rewards with KL penalty.
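
For reviewers, a minimal standalone sketch (not part of the diff) of the leave-one-out identity the vectorized estimator relies on: for a group of size n with scalar scores s_i, the RLOO advantage s_i minus the mean of the other n-1 scores equals (n * s_i - sum of the group) / (n - 1), which is what the torch.bincount expressions in the new function compute per sample. The toy group ids and variable names below (index, scores, loo, vec) are illustrative only, not taken from the codebase.

import numpy as np
import torch

# Two illustrative prompt groups (ids 0 and 1); values are arbitrary.
index = np.array([0, 0, 0, 1, 1])
scores = torch.tensor([1.0, 2.0, 3.0, 10.0, 20.0], dtype=torch.float64)

# Reference: leave-one-out advantage computed with an explicit loop.
loo = torch.empty_like(scores)
for i in range(len(scores)):
    same_group = torch.from_numpy(index == index[i])
    others = scores[same_group & (torch.arange(len(scores)) != i)]
    loo[i] = scores[i] - others.mean()

# Vectorized form, mirroring the bincount trick in compute_rloo_vectorized_outcome_advantage.
inv = torch.from_numpy(np.unique(index, return_inverse=True)[1])
c = torch.bincount(inv)[inv].to(scores.dtype)          # per-sample group size n
group_sum = torch.bincount(inv, weights=scores)[inv]   # per-sample group sum
vec = (c * scores - group_sum) / (c - 1).clamp_min(1) * (c > 1)

assert torch.allclose(loo, vec.to(loo.dtype))

In the diff, the resulting per-sample advantage is broadcast over response_mask, and single-sample groups get a zero advantage via the (c > 1) factor; the new parametrized test asserts numerical agreement with the existing compute_rloo_outcome_advantage. With the doc updates, the estimator should be selectable like the others, presumably via algorithm.adv_estimator=rloo_vectorized.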