Commit 26a734e

[algo, perf] feat: Vectorize RLOO Advantage Estimator - 20x Speedup (#3555)
Vectorize the RLOO advantage estimator: 130 ms → 6 ms. A similar approach could be applied to the other advantage estimators; I just don't have time. Implements

$$r_i - \frac{\sum_{j\ne i} r_j}{G-1} = \frac{(G-1)\,r_i - \sum_{j\ne i} r_j}{G-1} = \frac{G\,r_i - \sum_{j\in g} r_j}{G-1}$$

(Attached benchmark screenshot: https://github.com/user-attachments/assets/339e5bd2-6949-4460-a297-34268ffc1764)
Parent: 69b0127 · Commit: 26a734e
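To make the identity concrete, here is a small standalone check (illustrative only, not part of the commit): the naive per-sample leave-one-out baseline and the closed form $\frac{G\,r_i - \sum_j r_j}{G-1}$ produce the same advantages for a single group of outcome rewards.

```python
import torch

# One group of G scalar outcome rewards (illustrative values).
rewards = torch.tensor([1.0, 0.5, -0.2, 2.0])
G = rewards.numel()

# Naive RLOO: subtract the mean of the other G-1 rewards, one sample at a time.
naive = torch.stack([r - (rewards.sum() - r) / (G - 1) for r in rewards])

# Closed form used by the vectorized estimator: (G * r_i - sum_j r_j) / (G - 1).
vectorized = (G * rewards - rewards.sum()) / (G - 1)

assert torch.allclose(naive, vectorized)
```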

4 files changed (+111, -3 lines)

docs/examples/config.rst

Lines changed: 1 addition & 1 deletion
@@ -501,7 +501,7 @@ Algorithm
 
 - ``gamma``: discount factor
 - ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``
 - ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False.
 - ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to
   calculate the kl divergence between actor and reference policy. For

examples/ppo_trainer/README.md

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ Most critic configs are similar to those of actors. Note that the critic model i
 
 - `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator
 
-- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo
+- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo, rloo_vectorized
 
 ## Advanced Extensions

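For readers wiring this up: once `rloo_vectorized` is selected as `algorithm.adv_estimator`, the trainer presumably resolves it through the estimator registry in `core_algos.py` below. A hedged sketch of that lookup follows; `AdvantageEstimator` and `get_adv_estimator_fn` are real names in `verl.trainer.ppo.core_algos` (see the test imports in the next file), but the exact call pattern shown is an assumption, not code from this commit.

```python
# Sketch only: resolve the newly registered estimator by name.
# Assumes get_adv_estimator_fn accepts the estimator name / enum value.
from verl.trainer.ppo.core_algos import AdvantageEstimator, get_adv_estimator_fn

adv_fn = get_adv_estimator_fn(AdvantageEstimator.RLOO_VECTORIZED)
# Expected to return compute_rloo_vectorized_outcome_advantage, which takes
# token_level_rewards, response_mask, and index (as exercised in the test below).
```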
tests/trainer/ppo/test_core_algos_on_cpu.py

Lines changed: 70 additions & 1 deletion
@@ -15,11 +15,18 @@
 import random
 import unittest
 
+import numpy as np
 import pytest
 import torch
 
 import verl.trainer.ppo.core_algos
-from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est
+from verl.trainer.ppo.core_algos import (
+    compute_gae_advantage_return,
+    compute_rloo_outcome_advantage,
+    compute_rloo_vectorized_outcome_advantage,
+    get_adv_estimator_fn,
+    register_adv_est,
+)
 
 
 def mock_test_fn():
@@ -188,5 +195,67 @@ def test_multi_turn_compute_gae_advantage_return():
     print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")
 
 
+def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
+    """Create a numpy index array ensuring each group has at least 2 samples."""
+    assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
+    counts: list[int] = [2] * num_groups
+    remaining = batch_size - 2 * num_groups
+    for _ in range(remaining):
+        counts[random.randrange(num_groups)] += 1
+    index = []
+    for gid, c in enumerate(counts):
+        index.extend([gid] * c)
+    random.shuffle(index)
+    return np.asarray(index, dtype=np.int64)
+
+
+def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
+    mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()
+    rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]
+    if len(rows_without_one) > 0:
+        mask[rows_without_one, -1] = 1.0
+    return mask
+
+
+@pytest.mark.parametrize(
+    "batch_size,seq_len,num_groups,seed",
+    [
+        (64, 128, 5, 0),
+        (128, 256, 8, 1),
+        (512, 512, 10, 2),
+    ],
+)
+def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    index = _make_group_index(batch_size, num_groups)
+    response_mask = _rand_mask(batch_size, seq_len)
+    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
+    token_level_rewards = base_rewards * response_mask
+    adv1, ret1 = compute_rloo_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+    adv2, ret2 = compute_rloo_vectorized_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+    # Print concise diagnostics for visibility during test runs
+    adv_max_diff = (adv1 - adv2).abs().max().item()
+    ret_max_diff = (ret1 - ret2).abs().max().item()
+    total_mask_tokens = int(response_mask.sum().item())
+    print(
+        f"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} "
+        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
+    )
+    assert adv1.shape == adv2.shape == (batch_size, seq_len)
+    assert ret1.shape == ret2.shape == (batch_size, seq_len)
+    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
+    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()

verl/trainer/ppo/core_algos.py

Lines changed: 39 additions & 0 deletions
@@ -102,6 +102,7 @@ class AdvantageEstimator(str, Enum):
     OPO = "opo"
     GRPO_PASSK = "grpo_passk"
     GPG = "gpg"
+    RLOO_VECTORIZED = "rloo_vectorized"
 
 
 ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -685,6 +686,44 @@ def compute_gpg_outcome_advantage(
     return scores, scores
 
 
+@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)  # or simply: @register_adv_est("rloo_vectorized")
+def compute_rloo_vectorized_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    config: Optional[AlgoConfig] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
+
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        response_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        config: (AlgoConfig) algorithm config
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    scores = token_level_rewards.sum(dim=-1)
+
+    with torch.no_grad():
+        inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device)
+
+        c = torch.bincount(inv)[inv].to(scores.dtype)
+        adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1)
+
+        adv = adv.unsqueeze(-1) * response_mask
+
+    return adv, adv
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     """Compute token-level rewards with KL penalty.
 

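The core of the new function is the pair of `torch.bincount` calls: `np.unique(index, return_inverse=True)` turns arbitrary prompt ids into dense group ids, `torch.bincount(inv)[inv]` broadcasts each group's size back onto its members, and `torch.bincount(inv, weights=scores)[inv]` does the same for each group's reward sum, so the closed form from the commit message applies element-wise with no Python loop over groups. A minimal standalone illustration with toy tensors (not repo code):

```python
import numpy as np
import torch

# Prompt ids for 6 samples: two groups of sizes 4 and 2 (illustrative).
index = np.array(["p0", "p1", "p0", "p0", "p1", "p0"])
scores = torch.tensor([1.0, 0.5, -0.2, 2.0, 1.5, 0.3])

# Dense group ids: "p0" -> 0, "p1" -> 1, one id per sample.
inv = torch.from_numpy(np.unique(index, return_inverse=True)[1])

c = torch.bincount(inv)[inv].to(scores.dtype)         # group size per sample: [4, 2, 4, 4, 2, 4]
group_sum = torch.bincount(inv, weights=scores)[inv]  # group reward sum per sample

# Leave-one-out advantage: (G * r_i - sum_j r_j) / (G - 1), zeroed for singleton groups.
adv = ((c * scores - group_sum) / (c - 1).clamp_min(1)) * (c > 1)
print(adv)
```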