Closed
Description
Feature request
Prior art
This work borrows the core idea from PR #3555 (vectorized RLOO).
Motivation
- Add AdvantageEstimator.GRPO_VECTORIZED
- Add group-wise vectorized helpers
- Add a CI test
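A CI test for the vectorized estimator could check two things: that it reproduces the reshape-based computation when groups are contiguous and equal-sized, and that shuffling the samples only permutes the result. Below is a minimal pytest-style sketch. grpo_advantages_vectorized, the test name, and eps=1e-4 are hypothetical. The sketch assumes GRPO's usual (reward - group mean) / (group std + eps) normalization with torch's default unbiased std.

```python
import torch


def grpo_advantages_vectorized(rewards, group_idx, num_groups, eps=1e-4):
    # Hypothetical bincount-based GRPO advantages: (r - group mean) / (group std + eps).
    counts = torch.bincount(group_idx, minlength=num_groups).to(rewards.dtype)
    mean = torch.bincount(group_idx, weights=rewards, minlength=num_groups) / counts
    centered = rewards - mean.index_select(0, group_idx)
    # Unbiased (n - 1) variance, matching torch.Tensor.std's default.
    var = torch.bincount(group_idx, weights=centered**2, minlength=num_groups) / (counts - 1)
    return centered / (var.sqrt().index_select(0, group_idx) + eps)


def test_grpo_vectorized_matches_reshape_reference():
    torch.manual_seed(0)
    num_groups, group_size = 8, 4
    rewards = torch.randn(num_groups * group_size)
    # Contiguous, equal-size groups: [0, 0, 0, 0, 1, 1, 1, 1, ...]
    group_idx = torch.arange(num_groups).repeat_interleave(group_size)

    # Reference: reshape into one row per group and normalize row-wise.
    grouped = rewards.view(num_groups, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    expected = ((grouped - mean) / (std + 1e-4)).view(-1)

    actual = grpo_advantages_vectorized(rewards, group_idx, num_groups)
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    # Shuffling samples must permute the advantages in the same way.
    perm = torch.randperm(rewards.numel())
    shuffled = grpo_advantages_vectorized(rewards[perm], group_idx[perm], num_groups)
    torch.testing.assert_close(shuffled, expected[perm], rtol=1e-5, atol=1e-5)
```

The permutation check matters because the bincount path, unlike a view/reshape, does not require samples of the same group to be adjacent.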
Your contribution
Group-wise helpers
To support GRPO (and future per-group estimators) cleanly and efficiently:
• segment_sum(values, group_idx, G): computes per-group sums, where G is the number of groups, using torch.bincount(..., minlength=G) with values as the bincount weights (CPU/GPU friendly, no external dependencies).
• segment_count(group_idx, G): computes per-group counts via torch.bincount.
• gather_by_group(stat, group_idx): maps a per-group tensor back to per-sample via stat.index_select(0, group_idx).
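Here is a minimal sketch of the three helpers as described, plus a hypothetical grpo_advantages function that shows one way they might compose into GRPO's group-normalized advantage. It assumes group_idx is a 1-D int64 tensor with values in [0, G). It also assumes the usual GRPO normalization (reward minus group mean, divided by group standard deviation plus a small eps). grpo_advantages and eps=1e-4 are illustrative, not the proposed API.

```python
import torch


def segment_sum(values: torch.Tensor, group_idx: torch.Tensor, G: int) -> torch.Tensor:
    # Per-group sum of `values`: bincount adds each value into its group's bin.
    # minlength=G keeps the output length at G; empty groups get 0.
    return torch.bincount(group_idx, weights=values, minlength=G)


def segment_count(group_idx: torch.Tensor, G: int) -> torch.Tensor:
    # Number of samples in each group (int64).
    return torch.bincount(group_idx, minlength=G)


def gather_by_group(stat: torch.Tensor, group_idx: torch.Tensor) -> torch.Tensor:
    # Map a per-group tensor of shape (G, ...) back to per-sample shape (B, ...).
    return stat.index_select(0, group_idx)


def grpo_advantages(rewards: torch.Tensor, group_idx: torch.Tensor, G: int,
                    eps: float = 1e-4) -> torch.Tensor:
    # Hypothetical GRPO advantage: (r - group mean) / (group std + eps).
    counts = segment_count(group_idx, G).to(rewards.dtype)
    mean = segment_sum(rewards, group_idx, G) / counts
    centered = rewards - gather_by_group(mean, group_idx)
    # Unbiased (n - 1) variance, like torch.Tensor.std's default; a single-sample
    # group gives 0/0 = NaN here, just as torch.std does on one element.
    var = segment_sum(centered ** 2, group_idx, G) / (counts - 1)
    return centered / (gather_by_group(var.sqrt(), group_idx) + eps)


# Groups need not be contiguous: samples 0 and 2 form group 0, samples 1 and 3 group 1.
rewards = torch.tensor([1.0, 10.0, 3.0, 20.0])
group_idx = torch.tensor([0, 1, 0, 1])
print(grpo_advantages(rewards, group_idx, G=2))  # approximately [-0.7071, -0.7071, 0.7071, 0.7071]
```

Each statistic costs one bincount (a reduction) plus one index_select (a gather). The whole advantage computation therefore stays loop-free regardless of how many groups the batch contains.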
This achieves:
• No Python loops over groups or samples.
• O(B) time for a batch of B samples, with few kernel launches (a handful of reductions and gathers).
• Reusable as a drop-in for other estimators in future PRs.
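To illustrate the drop-in point, here is a hypothetical RLOO estimator written with the same three helpers. The helpers are redefined so the snippet runs on its own. It uses the standard leave-one-out baseline, the mean reward of the other samples in the group: b_i = (S_g - r_i) / (n_g - 1). This is not the code from PR #3555, and rloo_advantages is an illustrative name.

```python
import torch


def segment_sum(values, group_idx, G):
    # Per-group sums via bincount with the values as weights.
    return torch.bincount(group_idx, weights=values, minlength=G)


def segment_count(group_idx, G):
    # Per-group sample counts.
    return torch.bincount(group_idx, minlength=G)


def gather_by_group(stat, group_idx):
    # Per-group statistic -> per-sample tensor.
    return stat.index_select(0, group_idx)


def rloo_advantages(rewards, group_idx, G):
    # Hypothetical RLOO advantages built only from the group-wise helpers.
    sums = gather_by_group(segment_sum(rewards, group_idx, G), group_idx)
    counts = gather_by_group(segment_count(group_idx, G), group_idx).to(rewards.dtype)
    # Leave-one-out baseline: mean reward of the *other* samples in the same group.
    # Groups of size 1 have no other samples and divide by zero here.
    baseline = (sums - rewards) / (counts - 1)
    return rewards - baseline


rewards = torch.tensor([1.0, 2.0, 3.0])
group_idx = torch.tensor([0, 0, 0])
print(rloo_advantages(rewards, group_idx, G=1))  # tensor([-1.5000,  0.0000,  1.5000])
```

Only the final combination step differs from the GRPO sketch. The per-group reductions and gathers are identical, which is what makes the helpers reusable across estimators.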