[Kernel] CUTLASS grouped gemm fp8 MoE kernel #13972
Conversation
Signed-off-by: ElizaWszola <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]> Signed-off-by: ElizaWszola <[email protected]>
```cpp
template <typename Descriptor, typename T>
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
  using Arguments = typename Descriptor::Arguments;
  static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
                std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
  return Arguments{data_ptr, do_broadcast};
}
```
Consider revisiting this interface in a follow up?
tlrmchlsmth left a comment:
Spotted some issues, mainly around the CUDA version and compute capability checks
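The checks under discussion gate the kernel on GPU architecture and CUDA toolkit version. A minimal sketch of that kind of gate in plain Python; the exact threshold values here are assumptions for illustration, not taken from this PR:

```python
# Hypothetical gate for a Hopper-only fp8 grouped-gemm path.
# The SM90 requirement matches the PR's focus; the CUDA version
# floor (12, 3) is an assumed placeholder, not a value from this PR.
def supports_cutlass_moe(sm_version: int, cuda_version: tuple[int, int]) -> bool:
    """Return True when both the compute capability and toolkit allow
    the CUTLASS grouped-gemm fp8 MoE kernel to be used."""
    return sm_version == 90 and cuda_version >= (12, 3)
```

A caller would fall back to another MoE implementation whenever this returns False, rather than failing outright.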
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
```cpp
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
```
Needs to be ENABLE_CUTLASS_MOE_SM90 now
tlrmchlsmth left a comment:
Looks good to me now!
LucasWilkinson left a comment:
Sorry, I had some pending review comments I forgot to submit; submitting them now for posterity. Most were for future PRs anyway.
```diff
 def weak_ref_tensors(
     tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
-) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
```
nit: Does a type union containing Any do anything more than Any?
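For reference, a union that includes `Any` accepts exactly the same set of values as bare `Any`, since `Any` is assignable to and from every type; the extra arms give a static checker nothing. A small plain-Python demonstration:

```python
from typing import Any, Union, get_args

# At runtime, Python keeps the individual members of the union...
U = Union[int, Any]
print(get_args(U))

# ...but to a static type checker, every value is compatible with Any,
# so annotating a parameter as Union[int, Any] is effectively the same
# as annotating it as Any: no call site is ever rejected.
def f(x: Union[int, Any]) -> None:
    pass
```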
```python
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
    return CompressedTensorsWNA16MoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
```
nit: for a future PR, we should abstract this more, like https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/kernels/scaled_mm, to make it easier to adopt as a backend for non-compressed-tensors checkpoints
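The linked `scaled_mm` directory selects among kernel backends rather than branching inline on the config. A hypothetical sketch of that shape (all names here are illustrative, not vLLM's actual API):

```python
from dataclasses import dataclass

@dataclass
class MoEQuantConfig:
    scheme: str       # e.g. "fp8_w8a8"
    sm_version: int   # compute capability, e.g. 90 for Hopper

class CutlassFp8MoE:
    @staticmethod
    def can_implement(cfg: MoEQuantConfig) -> bool:
        # The CUTLASS grouped-gemm fp8 path in this PR targets SM90.
        return cfg.scheme == "fp8_w8a8" and cfg.sm_version >= 90

class TritonFp8MoE:
    @staticmethod
    def can_implement(cfg: MoEQuantConfig) -> bool:
        # Fallback path without the SM90 requirement.
        return cfg.scheme == "fp8_w8a8"

# Ordered by preference: first backend that can implement the config wins.
_PREFERENCE = [CutlassFp8MoE, TritonFp8MoE]

def choose_moe_kernel(cfg: MoEQuantConfig):
    for kernel in _PREFERENCE:
        if kernel.can_implement(cfg):
            return kernel
    raise ValueError(f"no MoE kernel supports {cfg}")
```

New checkpoint formats would then plug in by adding a backend class rather than another `elif` arm.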
```cpp
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
```
future PR: we should see if we can use the now-upstream IsArrayOfPointers support in Sm90RowBroadcast to eliminate this file
Thanks for the PR! It seems the PR currently only supports …
DeepSeekV3's blocked per-token quantization isn't supported by these kernels yet and requires additional work, but this is on our roadmap.
What about per-tensor quantization? We have an fp8 model (quant_method=fp8) quantized through AutoFP8. Could you give some suggestions? I'd like to contribute, but want to hear your suggestions first.
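For background on the per-tensor case: AutoFP8-style quantization stores a single scale per tensor, derived from the tensor's absolute maximum against the fp8 e4m3 range. A minimal plain-Python sketch (the function name is illustrative):

```python
# Largest finite value representable in float8_e4m3fn.
FP8_E4M3_MAX = 448.0

def per_tensor_scale(values: list[float]) -> float:
    """Compute one dequantization scale for the whole tensor:
    quantization stores round(x / scale) in fp8, dequantization
    multiplies back by scale."""
    amax = max(abs(v) for v in values)
    # Guard against an all-zero tensor, where any scale works.
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0
```

This is the per-tensor counterpart of the blocked per-token scheme mentioned above, which instead keeps a scale per block of elements.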
CUTLASS implementation of fp8 MoE kernel.
Tested with
Benchmark (Deepseek V2 Lite, total time of 25 runs)