Support Gumbel-Max Sampling Kernel #574
Conversation
Please consider the latest version flashinfer-ai/flashinfer#2119, which significantly improves performance.
```cpp
template <uint32_t BLOCK_THREADS,
          uint32_t VEC_SIZE,
          int BATCH_SIZE,
```
Do we still need this template arg if we are passing in batch size as a function arg?
Removed
```cpp
constexpr int CONSUMER_NUM_THREADS = 128;  // Grace Hopper setting
#else
// Default settings for other architectures
constexpr int WORKER_NUM_THREADS = 128;
```
We also have `WORKER_NUM_THREADS` defined in `persistent_kernel.cuh`; not sure if it's a duplicated definition here.
Removed
JackFram left a comment:
The sampling task implementation in this PR is already in pretty good shape. I left some comments; feel free to address them. Waiting for other reviewers' input.
```diff
@@ -0,0 +1,228 @@
+/* Copyright 2023-2024 CMU
+ *
```
Per Apache 2.0, you have to keep the original copyright as well, e.g. `Copyright 2023-2025 FlashInfer contributors`, in addition to your own copyright.
Sorry, I forgot to mention this. @STWMichae, please make sure to add the copyright before merging as well.
Done
Description of changes:
This PR adds support for the Gumbel-Max sampling kernel on all architectures. The original kernel code is from flashinfer.
In order to support the sampling in MPK, we:
- Reduced `BLOCK_THREADS` from 1024 to 128/256.
- Made the kernel take `batch_size` as a function argument and iterate through all batches sequentially.

This PR also adds unit tests to verify the correctness of the kernel. It also provides a `demo_sampling.py` that shows the sampling effect during inference.
Related Issues:
#519