Conversation

@ElizaWszola ElizaWszola commented Feb 27, 2025

CUTLASS implementation of fp8 MoE kernel.

Tested with:

from vllm import LLM

llm = LLM(model="nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8",
          trust_remote_code=True,
          tensor_parallel_size=2)

Benchmark (DeepSeek V2 Lite, total time of 25 runs):

[--------------------------------------------------------------------------------------------------------- Quant Matmul ---------------------------------------------------------------------------------------------------------]
                                                                                                                                    |  triton_moe  |  triton_moe_cuda_graphs  |  grouped_gemm_moe  |  grouped_gemm_moe_cuda_graphs
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((1, 2048, 1408))               |      3.6     |            2.6           |         3.6        |               3.3            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((4, 2048, 1408))               |      6.8     |            6.7           |         4.7        |               4.3            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((8, 2048, 1408))               |     10.1     |           10.0           |         5.6        |               5.1            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((16, 2048, 1408))              |     15.0     |           14.9           |         6.8        |               6.3            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((32, 2048, 1408))              |     16.9     |           16.8           |         7.3        |               6.9            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((64, 2048, 1408))              |     17.0     |           16.9           |         7.6        |               7.1            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((128, 2048, 1408))             |      8.5     |            8.4           |         8.1        |               7.6            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((256, 2048, 1408))             |      9.1     |            9.0           |         9.0        |               8.5            
      nm-testing/deepseekv2-lite, num_experts=64, topk=6, per_act_token=False per_out_ch=False, MKN=((512, 2048, 1408))             |     10.9     |           10.8           |        10.6        |              10.1          
(times are in ms)  
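The table layout above is `torch.utils.benchmark`'s Compare output. As a rough, dependency-free illustration of the "total time of 25 runs" measurement (the harness shape here is assumed, not the actual benchmark script; `triton_moe` / `grouped_gemm_moe` are placeholder names):

```python
import time

def total_time_ms(fn, runs=25, warmup=3):
    """Total wall-clock time of `runs` calls to fn, in milliseconds."""
    for _ in range(warmup):      # warm up caches / compilation before timing
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) * 1000.0

# Usage sketch (names are placeholders for the real kernels):
#   t_triton  = total_time_ms(lambda: triton_moe(x))
#   t_cutlass = total_time_ms(lambda: grouped_gemm_moe(x))
```

For GPU kernels the real harness must also synchronize the device around the timed region, which `torch.utils.benchmark` handles internally.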

ElizaWszola and others added 30 commits December 6, 2024 14:36
Signed-off-by: ElizaWszola <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Comment on lines +111 to +117
template <typename Descriptor, typename T>
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
  using Arguments = typename Descriptor::Arguments;
  static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
                std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
  return Arguments{data_ptr, do_broadcast};
}
Consider revisiting this interface in a follow up?

@tlrmchlsmth tlrmchlsmth left a comment

Spotted some issues, mainly around the CUDA version and compute capability checks

    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90

Needs to be ENABLE_CUTLASS_MOE_SM90 now

@tlrmchlsmth tlrmchlsmth left a comment

Looks good to me now!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 26, 2025
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) March 27, 2025 00:51
@robertgshaw2-redhat robertgshaw2-redhat merged commit 9239bf7 into vllm-project:main Mar 27, 2025
67 checks passed
@LucasWilkinson LucasWilkinson left a comment

Sorry, I had some pending review comments that I forgot to submit; submitting them now for posterity. Most were for future PRs anyway.

 def weak_ref_tensors(
     tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
-) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
nit: Does a type union containing Any do anything more than Any?
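To illustrate the nit (a quick sketch, not code from the PR, using int as a stand-in for torch.Tensor): a static checker treats Union[X, Any] the same as plain Any, since Any is assignable to and from every type; at runtime, typing merely keeps both members:

```python
from typing import Any, Union, get_args

# At runtime, typing keeps Any as an ordinary union member rather than
# collapsing the union:
members = get_args(Union[int, Any])

# To a static checker, however, Union[int, Any] accepts exactly the same
# values as plain Any, because Any is compatible with everything -- so the
# extra union members document intent but add no actual type checking.
def f(x: Union[int, Any]) -> None:
    pass
```

So the longer annotation only buys readability, which may still be a reasonable trade-off for documenting the expected shapes.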


if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
    return CompressedTensorsWNA16MoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)

nit: for a future PR, we should abstract this more, like https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/kernels/scaled_mm, to make it easier to adopt as a backend for non-compressed-tensors checkpoints
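The scaled_mm abstraction being referenced boils down to a "can_implement / first match wins" selection over candidate kernel backends. A hedged Python sketch of that shape (class and function names are illustrative, not vLLM's actual API):

```python
class MoEKernelBackend:
    """One candidate backend; subclasses report whether they support a config."""
    @classmethod
    def can_implement(cls, weight_quant, input_quant, capability):
        raise NotImplementedError

class CutlassFp8MoE(MoEKernelBackend):
    @classmethod
    def can_implement(cls, weight_quant, input_quant, capability):
        # e.g. requires fp8 w8a8 weights/activations and an SM90 (Hopper) GPU
        return weight_quant == "fp8" and input_quant == "fp8" and capability >= 90

class TritonMoE(MoEKernelBackend):
    @classmethod
    def can_implement(cls, weight_quant, input_quant, capability):
        return True  # generic fallback

def choose_moe_backend(weight_quant, input_quant, capability,
                       backends=(CutlassFp8MoE, TritonMoE)):
    # The first backend that can implement the configuration wins.
    for backend in backends:
        if backend.can_implement(weight_quant, input_quant, capability):
            return backend
    raise ValueError("no MoE backend supports this configuration")
```

Dispatching this way keeps checkpoint-format code (compressed-tensors, fp8, etc.) from hard-coding per-kernel predicates like `_is_fp8_w8a8_sm90`.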

  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
future PR: we should see if we can use the now-upstream IsArrayOfPointers support in Sm90RowBroadcast to eliminate this file

@li2haipeng

Thanks for the PR! It seems the PR currently only supports quant_method=compressed_tensor. I'm wondering if you have tested it on DeepSeek V3 or other quant_method=fp8 models? Are there plans to support them?

@tlrmchlsmth

> Thanks for the PR! Seems now the PR only supports quant_method=compressed_tensor. I'm wondering if you have tested it on DeepSeek V3 or other quant_method=fp8 models? Do we have plans to support them?

DeepSeekV3's blocked per-token quantization isn't supported by these kernels yet and requires additional work, but this is on our roadmap.
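For context, a hedged sketch (not vLLM code) of why this needs kernel work: per-tensor fp8 quantization carries a single scale for the whole weight matrix, while DeepSeek-V3-style blocked quantization keeps one scale per tile, so the GEMM must rescale per tile rather than once per matmul. 448.0 is the largest normal value of float8_e4m3fn; the block size is a parameter here (DeepSeek V3 uses 128x128 tiles for weights):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest normal value representable in float8_e4m3fn

def per_tensor_scale(w):
    """One scale for the entire tensor (AutoFP8-style per-tensor quant)."""
    return np.abs(w).max() / FP8_E4M3_MAX

def blocked_scales(w, block=128):
    """One scale per (block x block) tile (DeepSeek-V3-style blocked quant).

    A kernel consuming these must apply a different scale per tile inside
    the GEMM epilogue -- the extra work these kernels don't implement yet.
    """
    rows, cols = w.shape
    scales = np.empty((rows // block, cols // block))
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            tile = w[i:i + block, j:j + block]
            scales[i // block, j // block] = np.abs(tile).max() / FP8_E4M3_MAX
    return scales
```

Per-tensor scales, by contrast, fold into a single epilogue multiply, which is why supporting AutoFP8-style checkpoints is a much smaller step than blocked quantization.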

Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
@shixianc shixianc commented Jun 7, 2025

> DeepSeekV3's blocked per-token quantization isn't supported by these kernels yet and requires additional work, but this is on our roadmap

What about per-tensor quantization? We have an fp8 model (quant_method=fp8) quantized through AutoFP8. Could you give some suggestions? I'd like to contribute, but I'd like to hear your thoughts first.
