-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[algo, perf] feat: Vectorize RLOO Advantage Estimator - 20x Speedup #3555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request vectorizes the RLOO advantage estimator, leading to a significant performance improvement. The previous implementation using Python loops and dictionaries has been replaced with an efficient, vectorized approach using torch.bincount. This not only boosts performance but also enhances correctness by correctly handling groups with a single response, for which the advantage is now properly set to zero. The new code is more concise and idiomatic for tensor computations. Overall, this is an excellent and well-executed optimization.
|
Could you help add a CI to ensure that they are consistent? |
|
We may want to introduce a new adv estimator called rloo_vector instead of directly over-write the original one |
59e6c0f to
e73112b
Compare
|
@vermouth1992 I don't know your guys' standard workflow YAML but there is a PyTest if you want to add to ci |
…olcengine#3555) Vectorize RLOO advantage estimator 130ms -> 6ms Similar method can be done for other advantage estimators, I just don't have time Implements $$r_i - \frac{\sum_{j\ne i} r_j}{G-1} = \frac{(G-1)r_i - \sum_{j\ne i} r_j}{G-1} = \frac{G r_i - \sum_{j\in g} r_j}{G-1}$$ <img width="2199" height="628" alt="image" src="https://github.com/user-attachments/assets/339e5bd2-6949-4460-a297-34268ffc1764" />
…olcengine#3555) Vectorize RLOO advantage estimator 130ms -> 6ms Similar method can be done for other advantage estimators, I just don't have time Implements $$r_i - \frac{\sum_{j\ne i} r_j}{G-1} = \frac{(G-1)r_i - \sum_{j\ne i} r_j}{G-1} = \frac{G r_i - \sum_{j\in g} r_j}{G-1}$$ <img width="2199" height="628" alt="image" src="https://github.com/user-attachments/assets/339e5bd2-6949-4460-a297-34268ffc1764" />
…olcengine#3555) Vectorize RLOO advantage estimator 130ms -> 6ms Similar method can be done for other advantage estimators, I just don't have time Implements $$r_i - \frac{\sum_{j\ne i} r_j}{G-1} = \frac{(G-1)r_i - \sum_{j\ne i} r_j}{G-1} = \frac{G r_i - \sum_{j\in g} r_j}{G-1}$$ <img width="2199" height="628" alt="image" src="https://github.com/user-attachments/assets/339e5bd2-6949-4460-a297-34268ffc1764" />
…olcengine#3555) Vectorize RLOO advantage estimator 130ms -> 6ms Similar method can be done for other advantage estimators, I just don't have time Implements $$r_i - \frac{\sum_{j\ne i} r_j}{G-1} = \frac{(G-1)r_i - \sum_{j\ne i} r_j}{G-1} = \frac{G r_i - \sum_{j\in g} r_j}{G-1}$$ <img width="2199" height="628" alt="image" src="https://github.com/user-attachments/assets/339e5bd2-6949-4460-a297-34268ffc1764" />
Vectorize RLOO advantage estimator
130ms -> 6ms
Similar method can be done for other advantage estimators, I just don't have time
Implements