diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu
index 9021bd0847..ca1bb31acd 100644
--- a/csrc/nv_internal/cpp/kernels/quantization.cu
+++ b/csrc/nv_internal/cpp/kernels/quantization.cu
@@ -240,13 +240,14 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
   }
 }
 
+template <typename T>
 __global__ void block_scale_interleave_kernel(int numBatches, int numRows, int numRowsPadded,
-                                              int numCols, int numColsPadded, uint8_t const* SFIn,
-                                              uint8_t* SFOutput) {
+                                              int numCols, int numColsPadded, T const* SFIn,
+                                              T* SFOutput) {
   for (int rowIdx = blockIdx.x; rowIdx < numRowsPadded; rowIdx += gridDim.x) {
     for (int batchIdx = 0; batchIdx < numBatches; batchIdx++) {
       for (int colIdx = threadIdx.x; colIdx < numColsPadded; colIdx += blockDim.x) {
-        uint8_t sf = 0;
+        T sf = 0;
         if (rowIdx < numRows && colIdx < numCols) {
           int64_t inOffset = batchIdx * numRows * numCols + rowIdx * numCols + colIdx;
           sf = SFIn[inOffset];
@@ -287,19 +288,29 @@ __global__ void block_scale_interleave_reverse_kernel(int numBatches, int numRow
 }
 
 // This is intended for weight loading, so m and n are large, b <= 256
-void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
-                                uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
-                                cudaStream_t stream) {
+template <typename T>
+void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
+                                T* SFOutput, int multiProcessorCount, cudaStream_t stream) {
   // Each thread reads 1 int8 value
   dim3 block(std::min(n_padded, 1024));
   // Get number of blocks per SM (assume we can fully utilize the SM).
   int const numBlocksPerSM = std::max(1u, 4096u / block.x);
   dim3 grid(std::min(m_padded, multiProcessorCount * numBlocksPerSM));
-  block_scale_interleave_kernel<<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn,
-                                                            SFOutput);
+  block_scale_interleave_kernel<T>
+      <<<grid, block, 0, stream>>>(b, m, m_padded, n, n_padded, SFIn, SFOutput);
 }
 
+// Explicit template instantiations for the types used by other compilation units
+template void invokeBlockScaleInterleave<uint8_t>(int b, int m, int m_padded, int n, int n_padded,
+                                                  uint8_t const* SFIn, uint8_t* SFOutput,
+                                                  int multiProcessorCount, cudaStream_t stream);
+template void invokeBlockScaleInterleave<__nv_bfloat16>(int b, int m, int m_padded, int n,
+                                                        int n_padded, __nv_bfloat16 const* SFIn,
+                                                        __nv_bfloat16* SFOutput,
+                                                        int multiProcessorCount,
+                                                        cudaStream_t stream);
+
 // This is intended for weight loading, so m and n are large, b <= 256
 void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
                                        int multiProcessorCount, cudaStream_t stream) {
diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.h b/csrc/nv_internal/tensorrt_llm/kernels/quantization.h
index 8a90578cb4..9809d20a5b 100644
--- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.h
+++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.h
@@ -67,9 +67,9 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
                                        void* input_global_scale, void* mask, bool use_silu_and_mul,
                                        int m_topk, int k, int n_experts, cudaStream_t stream);
 
-void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
-                                uint8_t const* SFIn, uint8_t* SFOutput, int multiProcessorCount,
-                                cudaStream_t stream = 0);
+template <typename T>
+void invokeBlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, T const* SFIn,
+                                T* SFOutput, int multiProcessorCount, cudaStream_t stream = 0);
 
 void invokeBlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* SFIn, uint8_t* SFOutput,
                                        int multiProcessorCount, cudaStream_t stream = 0);
diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
index 673bc27edd..a43ee01a5a 100644
--- a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
+++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
@@ -137,6 +137,41 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
   }
 }
 
+template <typename T>
+void blockScaleInterleaveHost(TensorView blockScale, TensorView interleavedBlockScale) {
+  auto blockScaleShape = blockScale.sizes();
+  auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1;
+  auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0];
+  auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1];
+
+  auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols);
+  auto rows_padded = PadUpFn(rows, 128);
+  auto cols_padded = PadUpFn(cols, 4);
+
+  for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
+    T* interleavedBlockScalePtr =
+        static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
+    for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
+      auto globalRowIdx = eIdx * rows + rIdx;
+      T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
+      for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
+        T sf_ori = 0;
+        if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
+          sf_ori = blockScalePtr[cIdx];
+        }
+        int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
+                                      tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
+        interleavedBlockScalePtr[sf_index] = sf_ori;
+      }
+    }
+  }
+}
+
+template void blockScaleInterleaveHost<uint8_t>(TensorView blockScale,
+                                                TensorView interleavedBlockScale);
+template void blockScaleInterleaveHost<__nv_bfloat16>(TensorView blockScale,
+                                                      TensorView interleavedBlockScale);
+
 // Interleave (and possibly pad) the weights block scaling factor.
// blockScale: [num_experts, rows, cols] or [rows, cols] // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4) @@ -148,7 +183,8 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal CHECK_CPU(blockScale); } CHECK_CONTIGUOUS(blockScale); - CHECK_INPUT_TYPE(blockScale, dl_uint8); + TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16) + << "Block Scale must be uint8 or bfloat16."; auto blockScaleShape = blockScale.sizes(); TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3) << "Block Scale should be 2D or 3D tensor."; @@ -166,27 +202,28 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); const cudaStream_t stream = get_stream(blockScale.device()); - tensorrt_llm::kernels::invokeBlockScaleInterleave( - num_experts, rows, rows_padded, cols, cols_padded, - static_cast(blockScale.data_ptr()), - static_cast(interleavedBlockScale.data_ptr()), smCount, stream); + if (blockScale.dtype() == dl_uint8) { + tensorrt_llm::kernels::invokeBlockScaleInterleave( + num_experts, rows, rows_padded, cols, cols_padded, + static_cast(blockScale.data_ptr()), + static_cast(interleavedBlockScale.data_ptr()), smCount, stream); + } else if (blockScale.dtype() == dl_bfloat16) { + tensorrt_llm::kernels::invokeBlockScaleInterleave( + num_experts, rows, rows_padded, cols, cols_padded, + static_cast<__nv_bfloat16*>(blockScale.data_ptr()), + static_cast<__nv_bfloat16*>(interleavedBlockScale.data_ptr()), smCount, stream); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "block_scale_interleave only supports uint8 and bfloat16."; + } } else { - for (int eIdx = 0; eIdx < static_cast(num_experts); eIdx++) { - uint8_t* interleavedBlockScalePtr = - static_cast(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size; - for (int rIdx = 0; rIdx < static_cast(rows_padded); ++rIdx) { - auto globalRowIdx = eIdx * rows + rIdx; - uint8_t* blockScalePtr = static_cast(blockScale.data_ptr()) + globalRowIdx * cols; - for (int cIdx = 0; cIdx < static_cast(cols_padded); ++cIdx) { - uint8_t sf_ori = 0; - if (rIdx < static_cast(rows) && cIdx < static_cast(cols)) { - sf_ori = blockScalePtr[cIdx]; - } - int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, - tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4); - interleavedBlockScalePtr[sf_index] = sf_ori; - } - } + if (blockScale.dtype() == dl_uint8) { + blockScaleInterleaveHost(blockScale, interleavedBlockScale); + } else if (blockScale.dtype() == dl_bfloat16) { + blockScaleInterleaveHost<__nv_bfloat16>(blockScale, interleavedBlockScale); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "blockScaleInterleaveHost only supports uint8 and bfloat16."; } } } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index fc6393237f..69d4026c19 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -938,6 +938,160 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { } }; +class MxInt4BlockScaleLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + + MxInt4BlockScaleLauncher(TensorView const& routing_logits, + Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm1_weights_scale, + Optional const& gemm1_alpha, + Optional const& gemm1_beta, + Optional 
const& gemm1_clamp_limit, + TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale) + : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, + gemm1_weights, Optional(), Optional(), + gemm2_weights, Optional()), + gemm1_weights_scale(gemm1_weights_scale), + gemm2_weights_scale(gemm2_weights_scale) {} + + void init(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type) { + // currently only support mxint4 x bf16 + auto dtype = hidden_states.dtype(); + if (dtype == dl_bfloat16) { + args->mDtypeElt = btg::Dtype::Bfloat16; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + } + args->mDtypeOut = btg::Dtype::Bfloat16; + + mDtypeAct = btg::Dtype::Bfloat16; + mDtypeWeights = btg::Dtype::MxInt4; + + FusedMoeLauncher::init_common( + std::move(args), tile_tokens_dim, routing_method_type, + /*use_shuffled_weight=*/true, + static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), + static_cast(GatedActType::SwiGlu)); + } + + void check_routing() const override { FusedMoeLauncher::check_routing_common(); } + + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); + + args->mDtypeElt = mDtypeAct; + args->mUseDeepSeekFp8 = false; + // Set expert weights dtype based on routing bias + auto const routing_bias_dtype = + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + + workspace.expert_weights = expert_weights.data_ptr(); + } + + void check_moe() const override { + TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::Bfloat16) + << "Only Bfloat16 is supported by MxInt4 block scale MoE"; + + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be uint8."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_bfloat16) + << "gemm1_weights_scale must be bf16."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be uint8."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_bfloat16) + << "gemm2_weights_scale must be bf16."; + } + + void prepare_moe(int64_t& moe_tactic) override { + args->hidden_states = hidden_states.data_ptr(); + args->hidden_states_scale = nullptr; + args->gemm1_weights = gemm1_weights.data_ptr(); + args->gemm1_weights_scale = gemm1_weights_scale.data_ptr(); + args->gemm1_alpha = + gemm1_alpha.has_value() ? static_cast(gemm1_alpha.value().data_ptr()) : nullptr; + args->gemm1_beta = + gemm1_beta.has_value() ? static_cast(gemm1_beta.value().data_ptr()) : nullptr; + args->gemm1_clamp_limit = gemm1_clamp_limit.has_value() + ? 
static_cast(gemm1_clamp_limit.value().data_ptr()) + : nullptr; + args->gemm2_weights = gemm2_weights.data_ptr(); + args->gemm2_weights_scale = gemm2_weights_scale.data_ptr(); + args->output1_scales_scalar = nullptr; + args->output1_scales_gate_scalar = nullptr; + args->output2_scales_scalar = nullptr; + + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->intermediate_size, + btg::dtypeGetNumBits(mDtypeAct)); + max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + workspace.total_max_padded_tokens, args->hidden_size, + btg::dtypeGetNumBits(btg::Dtype::Bfloat16)); // Output is always BF16 + + auto const gemm1_output_hidden = args->intermediate_size; + gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, dl_bfloat16, + hidden_states.device()); + + // Allocate gemm2_output + gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + // Setup workspace pointers + workspace.hidden_states_scale_linear = nullptr; // MxInt4 doesn't use linear scale + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = nullptr; + // Note: activation_output and activation_output_scale are set by the base class + // prepare_moe_common() when gated activation is used + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + } + + private: + TensorView gemm1_weights_scale; + Optional gemm1_alpha; + Optional gemm1_beta; + Optional gemm1_clamp_limit; + TensorView gemm2_weights_scale; + int32_t max_num_padded_tokens_gemm1{}; + int32_t max_num_padded_tokens_gemm2{}; + + public: + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens) { + Array> valid_configs; + + std::vector tile_sizes(mSupportedTileNums.begin(), mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(tile_sizes, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + auto moe_runner = std::make_unique( + btg::Dtype::Bfloat16, btg::Dtype::MxInt4, + false, // useDeepSeekFp8 + tile_N, GatedActType::SwiGlu, + /*useShuffledMatrixA*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); + + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } + } + + return valid_configs; + } +}; + class FP4BlockScaleLauncher : public FusedMoeLauncher { public: static constexpr std::array mBaseSupportedTileNums = {8, 16, 32, 64}; @@ -1054,7 +1208,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { void check_moe() const override { TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::E2m1 || mDtypeAct == btg::Dtype::Bfloat16 || mDtypeAct == btg::Dtype::E4m3 || mDtypeAct == btg::Dtype::MxE4m3) - << "Only E2m1, Bfloat16, MxE4m3 and E4m3 are supported by block scale MoE"; + << "Only E2m1, Bfloat16, MxE4m3 and E4m3 are supported by Fp4 block scale MoE"; if (mDtypeAct == btg::Dtype::E2m1) { TVM_FFI_ICHECK(mDtypeWeights == btg::Dtype::E2m1) @@ -1628,6 +1782,89 @@ Array trtllm_fp4_block_scale_moe( return selected_launcher->run(config, enable_pdl); } +Array trtllm_mxint4_block_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView 
hidden_states, + TensorView gemm1_weights, TensorView gemm1_weights_scale, Optional gemm1_alpha, + Optional gemm1_beta, Optional gemm1_clamp_limit, + TensorView gemm2_weights, TensorView gemm2_weights_scale, int64_t num_experts, int64_t top_k, + Optional n_group, Optional topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, + int64_t routing_method_type, bool enable_pdl, TensorView output, Array config_index) { + // Determine data types based on input format + int const num_tokens = hidden_states.size(0); + int hidden_size = hidden_states.size(1); + // Just some basic type validation first and leave more checks to the launcher + + int weight_scale_vec_size = + (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); + + TVM_FFI_ICHECK(weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; + + TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16) + << "routing_logits must be float or bfloat16."; + TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape."; + TVM_FFI_ICHECK(!routing_bias.has_value()) << "routing_bias is not supported for MxInt4 MoE."; + + // Determine activation type + TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) + << "weights must be int4 packed in uint8."; + TVM_FFI_ICHECK(hidden_states.dtype() == dl_bfloat16) << "hidden_states must be bf16."; + + // Determine supported tile sizes + std::vector mSupportedTileN(MxInt4BlockScaleLauncher::mSupportedTileNums.begin(), + MxInt4BlockScaleLauncher::mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Create a map of launchers for each tile size + std::unordered_map> launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + // Create MoE arguments for this launcher + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + // For E2m1, hidden_size is already multiplied by 2 above, so use it directly + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); + args->do_finalize = true; + args->output = output.data_ptr(); + args->output_scale = nullptr; + + // Create and initialize launcher for this tile size + auto launcher = std::make_unique( + routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale, + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale); + launcher->init(std::move(args), curr_tile_N, routing_method_type); + + launchers_map[curr_tile_N] = std::move(launcher); + } + + // Extract tile_N and config from config_index + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = -1; // Let the runner choose default + } + + // Get the launcher for the selected tile_N + auto& selected_launcher = launchers_map.at(tile_N); + + // Run the launcher - it 
will create its own runner internally + return selected_launcher->run(config, enable_pdl); +} + Array> trtllm_get_valid_moe_configs( int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, @@ -1636,6 +1873,11 @@ Array> trtllm_get_valid_moe_configs( auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); + if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::MxInt4) { + // MxInt4 MoE + return MxInt4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + } if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::Bfloat16) { // BF16 MoE return Bf16MoeLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, @@ -1680,6 +1922,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_bf16_moe, trtllm_bf16_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_per_tensor_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_moe); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mxint4_block_scale_moe, trtllm_mxint4_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_get_valid_moe_configs, trtllm_get_valid_moe_configs); } // namespace flashinfer diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index b520023b70..717524bc9e 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988" + "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -110,7 +110,7 @@ class CheckSumHash: "a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2" ) TRTLLM_GEN_BMM: str = ( - "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf" + "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 29127f06ac..0d5f6be3b8 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -264,10 +264,10 @@ def block_scale_interleave_sm100( """Swizzle block scale tensor for FP4 format. Args: - unswizzled_sf (torch.Tensor): unswizzled block scale tensor with dtype uint8. + unswizzled_sf (torch.Tensor): unswizzled block scale tensor with dtype uint8 or bfloat16. Returns: - torch.Tensor: output tensor for swizzled block scale with dtype uint8. + torch.Tensor: output tensor for swizzled block scale with dtype uint8 or bfloat16. """ num_experts = unswizzled_sf.shape[0] if unswizzled_sf.dim() == 3 else 1 expert_out_size = _compute_swizzled_layout_sf_size( @@ -275,7 +275,7 @@ def block_scale_interleave_sm100( ) out = torch.empty( (num_experts * expert_out_size,), - dtype=torch.uint8, + dtype=unswizzled_sf.dtype, device=unswizzled_sf.device, ) module.block_scale_interleave_sm100(unswizzled_sf, out) @@ -696,18 +696,18 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. 
Args: - unswizzled_sf (torch.Tensor): Input tensor with dtype uint8. + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. Returns: torch.Tensor: Swizzled tensor with the same shape as input. Raises: - AssertionError: If input dtype is not uint8. + AssertionError: If input dtype is not uint8 or bfloat16. """ # TODO(shuw): check input dtype is uint8 - assert unswizzled_sf.dtype == torch.uint8, ( - f"Input dtype must be uint8, got {unswizzled_sf.dtype}" - ) + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" major, minor = get_compute_capability(unswizzled_sf.device) device_arch = f"{major * 10 + minor}" diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 87c207f5e0..a7d7a368db 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -31,6 +31,7 @@ trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_bf16_moe, + trtllm_mxint4_block_scale_moe, ) from .fused_routing_dsv3 import ( # noqa: F401 @@ -54,5 +55,6 @@ "trtllm_fp4_block_scale_routed_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", + "trtllm_mxint4_block_scale_moe", "NoAuxTc", ] diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7b53c3f82c..b22f4028d4 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -116,13 +116,14 @@ def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid): Int64 = (0, 1, 1, 64, 11) MxE2m1 = (1, 1, 0, 4, 12) MxE4m3 = (1, 1, 0, 8, 13) - UE8m0 = (0, 0, 0, 8, 14) - UInt8 = (0, 0, 1, 8, 15) - UInt16 = (0, 0, 1, 16, 16) - UInt32 = (0, 0, 1, 32, 17) - UInt64 = (0, 0, 1, 64, 18) - UInt128 = (0, 0, 1, 128, 19) - Void = (0, 1, 0, 0, 20) + MxInt4 = (1, 1, 1, 4, 14) + UE8m0 = (0, 0, 0, 8, 15) + UInt8 = (0, 0, 1, 8, 16) + UInt16 = (0, 0, 1, 16, 17) + UInt32 = (0, 0, 1, 32, 18) + UInt64 = (0, 0, 1, 64, 19) + UInt128 = (0, 0, 1, 128, 20) + Void = (0, 1, 0, 0, 21) def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: @@ -131,6 +132,7 @@ def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1, DtypeTrtllmGen.MxE4m3, + DtypeTrtllmGen.MxInt4, ]: return True else: @@ -1153,6 +1155,34 @@ def forward( kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, ) + elif ( + self.dtype_act == DtypeTrtllmGen.Bfloat16 + and self.dtype_weights == DtypeTrtllmGen.MxInt4 + ): + moe_op.trtllm_mxint4_block_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm1_alpha"], + kwargs["gemm1_beta"], + kwargs["gemm1_clamp_limit"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["routing_method_type"], + kwargs["enable_pdl"], + output, + [-1, -1] if tactic == -1 else tactic, + ) else: moe_op.trtllm_fp4_block_scale_moe( routing_logits, @@ -1851,11 +1881,167 @@ def _fake_trtllm_fp4_block_scale_moe( return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( + "flashinfer::trtllm_mxint4_block_scale_moe", + mutates_args=(""), + ) + def trtllm_mxint4_block_scale_moe_op( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + 
hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + num_local_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int, + enable_pdl: Optional[bool] = None, + output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 8192, + ) -> List[torch.Tensor]: + routing_dtype = routing_logits.dtype + hidden_size = hidden_states.shape[-1] + if hidden_states.dtype == torch.uint8: + hidden_size = hidden_size * 2 + num_tokens = hidden_states.shape[0] + + # workspace buffers required by trtllm-gen + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device + ) + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + if output is None: + output = torch.empty( + num_tokens, + hidden_size, + dtype=torch.bfloat16, + device=hidden_states.device, + ) + + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + dtype_act = DtypeTrtllmGen.Bfloat16 + dtype_weights = DtypeTrtllmGen.MxInt4 + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=num_local_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=False, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gated_act_type=GatedActType.SwiGlu, + weight_layout=WeightLayout.BlockMajorK, + use_shuffled_weight=True, + ) + tunning_config = MoERunner.tuning_config_no_hidden_states_scales + inputs = [ + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + ] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_mxint4_block_scale_moe", + [moe_runner], + tunning_config, + inputs, + num_experts=num_experts, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + enable_pdl=enable_pdl, + ) + + # Call the C++ function for block scale MoE + moe_op.trtllm_mxint4_block_scale_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm1_weights_scale, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + num_local_experts, + routed_scaling_factor, + routing_method_type, + enable_pdl, + output, + [-1, -1] if tactic == -1 else tactic, + ) + return output + + @register_fake_op("flashinfer::trtllm_mxint4_block_scale_moe") + def _fake_trtllm_mxint4_block_scale_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + 
gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int, + enable_pdl: bool, + output: Optional[torch.Tensor], + tune_max_num_tokens: int, + ): + seq_len = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + + return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + return SimpleNamespace( trtllm_bf16_moe=trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op, + trtllm_mxint4_block_scale_moe=trtllm_mxint4_block_scale_moe_op, ) @@ -2352,3 +2538,95 @@ def trtllm_fp4_block_scale_routed_moe( output, tune_max_num_tokens, ) + + +@flashinfer_api +def trtllm_mxint4_block_scale_moe( + routing_logits: torch.Tensor, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + enable_pdl: Optional[bool] = None, + output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 8192, +) -> List[torch.Tensor]: + """MxInt4 block scale MoE operation. + + Args: + routing_logits (torch.Tensor): shape [seq_len, num_experts] + Input tensor of routing logits. Supports float32, bfloat16. + hidden_states (torch.Tensor): shape [seq_len, hidden_size] + Tensor of input hidden states. Supports bfloat16. + gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2] + Tensor of FC1 weights. Dtype must be uint8 (packed mxint4) + gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 32] + Scale tensor of FC1 weights. Dtype must be bfloat16. + gemm1_alpha (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu alpha. Dtype is float32. + gemm1_beta (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu beta. Dtype is float32. + gemm1_clamp_limit (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu clamp limit. Dtype is float32. + gemm2_weights (torch.Tensor): shape [num_experts, hidden_size, intermediate_size] + Tensor of FC2 weights. Dtype must be uint8 (packed mxint4) + gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // 32] + Scale tensor of FC2 weights. Dtype must be bfloat16. 
+ num_experts (int): Total number of experts + top_k (int): Number of experts to route to per token + n_group (Optional[int]): Number of expert groups (can be None for some routing methods) + topk_group (Optional[int]): Number of groups to consider for top-k routing (can be None for some routing methods) + intermediate_size (int): Size of intermediate layer + local_expert_offset (int): Offset of local experts in global expert space + local_num_experts (int): Number of experts handled by this device + routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) + routing_method_type (int): Type of routing method to use (default: 0) + - 0: Default (Softmax -> TopK) + - 1: Renormalize (TopK -> Softmax) + - 2: DeepSeekV3 (Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts) + - 3: Llama4 (Top1 -> Sigmoid) + - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) + enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) + output (Optional[torch.Tensor]): shape [seq_len, hidden_size] + Optional inplace output tensor. + Returns: + torch.Tensor: returns the final MoE output. + """ + return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe( + routing_logits, + None, + hidden_states, + gemm1_weights, + gemm1_weights_scale, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + enable_pdl, + output, + tune_max_num_tokens, + ) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index e323125efa..2efcf1ff64 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -786,8 +786,8 @@ def get_shuffle_matrix_a_row_indices( def get_shuffle_matrix_sf_a_row_indices( input_tensor: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: int = 16 ) -> torch.Tensor: - assert input_tensor.dtype == torch.uint8 - assert num_elts_per_sf == 16 + assert input_tensor.dtype == torch.uint8 or input_tensor.dtype == torch.bfloat16 + assert num_elts_per_sf == 16 or num_elts_per_sf == 32 assert input_tensor.dim() == 2, ( f"input_tensor should be a 2D tensor, not {input_tensor.dim()}" diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index f93f20d28e..22d60fb4bd 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -18,6 +18,7 @@ #include #include +#include #include "BatchedGemmOptions.h" #include "KernelParams.h" @@ -122,7 +123,7 @@ struct BatchedGemmData { // Otherwise, shape is [M / 128, K / 128]. // The rightmost dimension is contiguous in memory. // - // If DeepSeek FP8 recipe is not used, but for MxFp{4,8} and NvFp4 formats: + // If DeepSeek FP8 recipe is not used, but for MxFp{4,8}, MxInt4 and NvFp4 formats: // The layout of scaling factors for A is always R128c4 // M must be a multiple of 128. // K must be a multiple of 64. @@ -132,7 +133,8 @@ struct BatchedGemmData { // Where paddedM is M if (routeAct == true && batchM), or // sum(divUpMul(M[bi], tileM) for bi in B) if batchM, // otherwise divUpMul(M, tileM) * B. 
- // Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype::E4m3. + // Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype is Dtype::E4m3 for + // NvFp4, Dtype::UE8m0 for MxFp{4,8} formats, Dtype::Bfloat16 for MxInt4. // // Otherwise should be set to nullptr. void const* mPtrSfA{nullptr}; @@ -476,9 +478,6 @@ class BatchedGemmInterface { BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /*multiProcessorCount*/, bool usePdl = true, std::optional> moduleCache = std::nullopt) { - // Might be used. - (void)usePdl; - (void)moduleCache; // Get options from config and data. auto options = getOptionsFromConfigAndData(config, batchedGemmData); @@ -579,15 +578,18 @@ class BatchedGemmInterface { static_cast(options.mClusterDimY), static_cast(options.mClusterDimZ)}; + // Whether PDL can safely be enabled + const bool pdlSafe = batchedGemmConfig.mOptions.mGridWaitForPrimaryRouting || + batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit || + batchedGemmConfig.mOptions.mGridWaitForPrimaryA || + batchedGemmConfig.mOptions.mGridWaitForPrimaryB; + // Run the kernel. - auto result = trtllm::gen::launchKernel( - (void*)&kernelParams, cudaStream, batchedGemmConfig.mSharedMemSize, cuFunction, block3, - grid3, cluster3, - usePdl && (batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit | - batchedGemmConfig.mOptions.mGridWaitForPrimaryA | - batchedGemmConfig.mOptions.mGridWaitForPrimaryB)); + auto result = trtllm::gen::launchKernel((void*)&kernelParams, cudaStream, + batchedGemmConfig.mSharedMemSize, cuFunction, block3, + grid3, cluster3, usePdl && pdlSafe); if (result != CUDA_SUCCESS) { - return -1; + return result; } // If a module cache has not been given, unload the module to avoid leaking if (!moduleCache.has_value()) { @@ -719,11 +721,8 @@ class BatchedGemmInterface { // Get options from config and data. auto options = getOptionsFromConfigAndData(config, data); - // Is Blackwell? - bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); - // Check options without modifications. 
- return checkAndUpdateBatchedGemmOptions(options, isBlackwell, + return checkAndUpdateBatchedGemmOptions(options, config.mSm, /* updateOptions */ false); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index 6e53d00c17..f3e73a5aac 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -85,16 +85,17 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX, int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, - bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, - int epilogueTileN, bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, - gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, - int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, - int numEpilogueWarps, int numRegsCastAWarps, int numRegsCopySfLdsSttm, - int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + tg::Dtype dtypeMmaB, gemm::EltwiseActType eltwiseActType, bool enablesEarlyExit, + bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, + int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool fuseUtccpWithUtcmma, + bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, + gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, + bool mockAllReduce, int n, int numEpilogueWarps, int numRegsCastAWarps, + int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, @@ -114,20 +115,21 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { gemm::GemmOptions( allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, - enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, - epilogueLdtmBits, epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, - gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, - gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, - k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, 
mockAllReduce, n, - numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, - numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, - numSlicesForSliceK, numStages, numStagesMma, numStagesMmaWithinWorkTile, - numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp, - sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, - tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, - useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap, - usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, - useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize), + eltwiseActType, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, + epilogueLdtmDps, epilogueLdtmBits, epilogueTileM, epilogueTileN, + fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB, + gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, + hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, + mmaKind, mmaM, mmaN, mockAllReduce, n, numEpilogueWarps, numRegsCastAWarps, + numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, + numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma, + numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, + outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, + sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, + transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, + useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap, usePerTokenSfA, + usePerTokenSfB, useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, + useUnrollLoop2xForMma, validM, validN, validK, worldSize), actType, clampBeforeAct), mBatchedM(batchedM), mBatchedN(batchedN), @@ -182,7 +184,7 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, +inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, tg::CudaArch cudaArch, bool updateOptions = true) { bool isValid = true; if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) { @@ -197,10 +199,9 @@ inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool i } if (options.mFusedAct) { // ensure that we check the fused options as well - isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, updateOptions); + isValid = gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, cudaArch, updateOptions); } else { - isValid = - gemm::checkAndUpdateGemmOptions(options, isBlackwell, 1 /* tpGrpSize */, updateOptions); + isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch, 1 /* tpGrpSize */, updateOptions); } bool batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; @@ -346,8 +347,17 @@ inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool i // We do not handle the case where K is not a multiple of TileK. // TMA based load handles the case transparently. 
- if (doesRouteImplUseLdgsts(options.mRouteImpl)) { - TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, "K must be a multiple of TileK"); + if (doesRouteImplUseLdgsts(options.mRouteImpl) && + doesRouteImplUseLdgPlusSts(options.mRouteSfsImpl.value())) { + TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, + "K must be a multiple of TileK when using Ldg based routing"); + } + + if (options.mRouteSfsImpl.has_value() && + (doesRouteImplUseLdgsts(options.mRouteSfsImpl.value()) || + doesRouteImplUseLdgPlusSts(options.mRouteSfsImpl.value()))) { + TLLM_CHECK_ERROR(options.mK % options.mTileK == 0, + "K must be a multiple of tileK when using Ldg based SF routing"); } if (options.mClusterDimX > 1 && batchM && options.mRouteImpl != RouteImpl::NoRoute) { @@ -380,7 +390,7 @@ struct BatchedGemmConfig { int32_t mInstanceIdx{0}; BatchedGemmOptions mOptions; - gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; + tg::CudaArch mSm{tg::CudaArch::Sm100a}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index e9d5a23a65..2a1c371ad8 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -87,6 +87,17 @@ enum class BiasType : uint32_t { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Type of the element-wise activation to apply after the Gemm +enum class EltwiseActType { + None = 0, + // Relu2 (also known as squared Relu) is defined as the following operation: + // act = relu(x0) ^ 2 + // where x0 is the output of the Gemm. + Relu2, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + enum class TileScheduler { // Static scheduler (Non-persistent). Static = 0, diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 559118916d..9fb4a010a4 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -60,15 +60,16 @@ namespace tg = trtllm::gen; // Type of the gated activation enum class ActType { + // clang-format off // For ActType == SwiGlu, ideally we would like to have something like - // gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) * - // sigmoid(alpha * x1 * scaleGate)). + // gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)). // But for now, we use the simplified version - // gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)), + // gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)), // where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales, // beta' = beta / dequantScaleAb, scaleC = quantScaleC * dequantScaleAb. // // GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0. 
+ // clang-format on SwiGlu, // For ActType == GeGlu, we use the simplified version // gatedAct = scaleC' * (x0 + beta') * ((x1 * scaleGate) * phi(alpha * x1 * scaleGate)), @@ -119,7 +120,7 @@ struct GemmGatedActOptions : public gemm::GemmOptions { // Check if the options are valid or not. inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& options, - bool isBlackwell, bool updateOptions = true) { + tg::CudaArch cudaArch, bool updateOptions = true) { // tmpOut is already transposed at this stage auto const hiddenSizeStr = options.mTransposeMmaOutput ? "M" : "N"; auto const hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN; @@ -144,7 +145,7 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); } - auto isValid = gemm::checkAndUpdateGemmOptions(options, isBlackwell, + auto isValid = gemm::checkAndUpdateGemmOptions(options, cudaArch, /* tpGrpSize */ 1, updateOptions); if (!isValid) { @@ -211,7 +212,7 @@ struct GemmGatedActConfig { int32_t mInstanceIdx{0}; GemmGatedActOptions mOptions{}; - gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; + tg::CudaArch mSm{tg::CudaArch::Sm100a}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index af6432f7a0..54daac4a8d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -23,6 +23,7 @@ #include "Enums.h" #include "KernelParams.h" #include "KernelTraits.h" +#include "trtllm/gen/CudaArchDecl.h" #include "trtllm/gen/DtypeDecl.h" #include "trtllm/gen/MmaDecl.h" #include "trtllm/gen/SfLayoutDecl.h" @@ -106,9 +107,9 @@ struct GemmOptions { GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, - tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, - bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, - int epilogueTileM, int epilogueTileN, bool fuseUtccpWithUtcmma, + tg::Dtype dtypeMmaB, EltwiseActType eltwiseActType, bool enablesEarlyExit, + bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, + int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, @@ -139,6 +140,7 @@ struct GemmOptions { mDtypeC{dtypeC}, mDtypeMmaA{dtypeMmaA}, mDtypeMmaB{dtypeMmaB}, + mEltwiseActType{eltwiseActType}, mEnablesEarlyExit{enablesEarlyExit}, mEnablesDelayedEarlyExit{enablesDelayedEarlyExit}, mEnablesGlobalPtxKnobs{enablesGlobalPtxKnobs}, @@ -206,7 +208,6 @@ struct GemmOptions { mValidN{validN}, mValidK{validK}, mWorldSize{worldSize} {} - // The all-reduce algorithm. AllReduceAlgo mAllReduceAlgo{AllReduceAlgo::None}; // The type of bias. @@ -233,6 +234,8 @@ struct GemmOptions { tg::Dtype mDtypeMmaA{tg::Dtype::Void}; // Data type of the B matrix for the MMA, if different from the input type. 
tg::Dtype mDtypeMmaB{tg::Dtype::Void}; + // The type of activation. + EltwiseActType mEltwiseActType{EltwiseActType::None}; // Whether to enable early exit. bool mEnablesEarlyExit{false}; // Whether to enable delayed early exit to overlap @@ -392,14 +395,7 @@ struct GemmOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// -enum class SmVersion { Sm90a, Sm100a, Sm100f, Sm103a }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline bool isSmVersionBlackwell(SmVersion smVersion) { - return smVersion == SmVersion::Sm100a || smVersion == SmVersion::Sm100f || - smVersion == SmVersion::Sm103a; -} +using SmVersion = tg::CudaArch; //////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -421,7 +417,7 @@ struct GemmConfig { int32_t mInstanceIdx{0}; GemmOptions mOptions{}; - SmVersion mSm{SmVersion::Sm100a}; + tg::CudaArch mSm{tg::CudaArch::Sm100a}; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -481,6 +477,9 @@ inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParam ss << "mDtypeMmaB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" << "," << std::endl; + ss << "mEltwiseActType=" + << "gemm::EltwiseActType(" << static_cast(options.mEltwiseActType) << ")" + << "," << std::endl; ss << "mEnablesEarlyExit=" << options.mEnablesEarlyExit << "," << std::endl; ss << "mEnablesDelayedEarlyExit=" << options.mEnablesDelayedEarlyExit << "," << std::endl; ss << "mEnablesGlobalPtxKnobs=" << options.mEnablesGlobalPtxKnobs << "," << std::endl; @@ -610,10 +609,12 @@ inline int32_t getShuffleBlockSize(int epilogueTileM) { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, int tpGrpSize, +inline bool checkAndUpdateGemmOptions(GemmOptions& options, tg::CudaArch cudaArch, int tpGrpSize, bool updateOptions = true) { options.mWorldSize = tpGrpSize; + bool isBlackwell = tg::isArchBlackwell(cudaArch); + if (options.mDtypeB == tg::Dtype::Void) { if (updateOptions) { options.mDtypeB = options.mDtypeA; @@ -639,21 +640,20 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } // If validM/N/K is not specified, then assume the full range of the dimension is valid. - if (options.mValidM == -1) { - options.mValidM = options.mM; - } - if (options.mValidN == -1) { - options.mValidN = options.mN; - } - if (options.mValidK == -1) { - options.mValidK = options.mK; + if (options.mValidM < 0 || options.mValidN < 0 || options.mValidK < 0) { + if (updateOptions) { + options.mValidM = options.mValidM < 0 ? options.mM : options.mValidM; + options.mValidN = options.mValidN < 0 ? options.mN : options.mValidN; + options.mValidK = options.mValidK < 0 ? options.mK : options.mValidK; + } else { + return false; + } } // It must not exceed the padded dimensions. 
if (options.mValidM > options.mM || options.mValidN > options.mN || options.mValidK > options.mK) { TLLM_LOG_WARNING( - options.mValidK <= options.mK, "ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively."); if (updateOptions) { options.mValidM = std::min(options.mValidM, options.mM); @@ -684,10 +684,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in #endif // TLLM_PUBLIC_RELEASE // Check that the A cast is supported. - // Currently, we only support {MxFp4, NvFp4} -> Bf16. + // Currently, we only support {MxFp4, NvFp4, MxInt4} -> Bf16. TLLM_CHECK_ERROR( (options.mDtypeA == options.mDtypeMmaA) || - ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1) && + ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1 || + options.mDtypeA == tg::Dtype::MxInt4) && options.mDtypeMmaA == tg::Dtype::Bfloat16) || (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3), "Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", @@ -1306,6 +1307,16 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } + if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8 && + options.mTileScheduler == TileScheduler::Persistent) { + if (updateOptions) { + options.mUseCustomMmaSchedule = true; + } else { + TLLM_CHECK_ERROR(false, + "TileScheduler::Persistent and !UseCustomMmaSchedule is not supported."); + } + } + if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit) { TLLM_LOG_WARNING( "Only one of early exit and delayed early exit should be enabled. Disabling " @@ -1441,6 +1452,10 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Using more than 4 warps for epilogue does not work with sliceK"); TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, "Using more than 4 warps for epilogue does not work with mUseDeepSeekFp8"); + + auto const numEpilogueWrpGrps = options.mNumEpilogueWarps / 4; + TLLM_CHECK_ERROR(options.mTileN % (options.mEpilogueTileN * numEpilogueWrpGrps) == 0, + "TileN must be a multiple of EpilogueTileN * numEpilogueWrpGrps"); } if (updateOptions) { diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index 800c8546ef..8094f1490e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -85,15 +85,10 @@ template static auto makeTmaShapeStrideAbc(GemmOptions const& options, int sizeM, int sizeN, int sizeK, int tileM, int tileN, int tileK, MatrixType matrixType, int validM = -1, int validN = -1, int validK = -1) { - if (validM == -1) { - validM = sizeM; - } - if (validN == -1) { - validN = sizeN; - } - if (validK == -1) { - validK = sizeK; - } + // Default to padded dimensions if not provided. + validM = validM < 0 ? sizeM : validM; + validN = validN < 0 ? sizeN : validN; + validK = validK < 0 ? sizeK : validK; // Weights matrix is A if we transpose the output of MMA (to have it M-major). // Otherwise, it is B, when the output of MMA is K-major. bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput) || @@ -412,16 +407,23 @@ static KernelParams setKernelParams( // Shape/stride for gmem tensor B. auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc( options, options.mM, useRouteAct ? 
options.mNumTokens : inputNumTokens, options.mK, - options.mTileM, (useRouteAct ? 1 : options.mTileN), options.mTileK, MatrixType::MatrixB); + options.mTileM, (useRouteAct ? 1 : options.mTileN), options.mTileK, MatrixType::MatrixB, + options.mValidM, useRouteAct ? options.mNumTokens : inputNumTokens, options.mValidK); // Build tma descriptor for B. params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, tileShapeB, const_cast<void*>(ptrB)); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || - options.mDtypeA == tg::Dtype::MxE2m1) { - tg::Dtype const dTypeSf = - (options.mDtypeA == tg::Dtype::E2m1) ? tg::Dtype::E4m3 : tg::Dtype::UE8m0; + options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::MxInt4) { + tg::Dtype dTypeSfA{}; + if (options.mDtypeA == tg::Dtype::E2m1) { + dTypeSfA = tg::Dtype::E4m3; + } else if (options.mDtypeA == tg::Dtype::MxInt4) { + dTypeSfA = tg::Dtype::Bfloat16; + } else { + dTypeSfA = tg::Dtype::UE8m0; + } // Build TMA descriptor for gmem A block scaling factors. auto [shapeSfA, strideSfA, tileShapesSfA] = makeTmaShapeStrideSfAb( @@ -429,7 +431,7 @@ static KernelParams setKernelParams( options.mTileM, options.mTileN, options.mTileK, tg::SfLayout::R128c4, options.mSfReshapeFactor, options.mSfBlockSizeA.value_or(tg::dtypeNumEltsPerSf(options.mDtypeA))); - params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSf, shapeSfA, strideSfA, tileShapesSfA, + params.tmaSfA[0] = gemm::buildSfTmaDescriptor(dTypeSfA, shapeSfA, strideSfA, tileShapesSfA, const_cast<void*>(dSfA)); } @@ -449,9 +451,13 @@ static KernelParams setKernelParams( auto numSfsInK = options.mK / numEltsPerSf; numSfsInK = ceilDiv(numSfsInK, 16) * 16; + auto numSfsInValidK = options.mValidK / numEltsPerSf; + numSfsInValidK = ceilDiv(numSfsInValidK, 16) * 16; + auto [shapeSfB, strideSfB, tileShapesSfB] = makeTmaShapeStrideAbc( options, options.mM, options.mNumTokens, numSfsInK, options.mTileM, 1 /* tileN */, - options.mTileK / numEltsPerSf, MatrixType::MatrixB); + options.mTileK / numEltsPerSf, MatrixType::MatrixB, options.mValidM, options.mNumTokens, + numSfsInValidK); params.tmaSfB[0] = gemm::buildNdTmaDescriptor( dTypeSf, options.mMmaKind, shapeSfB, strideSfB, tileShapesSfB, const_cast<void*>(dSfB), /*doSwizzle*/ true); @@ -474,13 +480,16 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = - makeTmaShapeStrideAbc(options, options.mM, ctaOffset * options.mTileN, options.mK, - options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, - options.mValidM, ctaOffset * options.mTileN, options.mValidK); + // NOTE: Output is *always* sanitized across the whole MNK range. This ensures maximum + // compatibility with the next BMM where the unwritten part of the output could be polluted by + // NaNs. + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( + options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); + } else { params.ptrC = ptrC; } @@ -506,9 +515,9 @@ static KernelParams setKernelParams( // The input is padded: // [act0, padding, padding, ... tileM size .., act1, padding, padding, ...]
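Two details from the setKernelParams hunk above, restated as a small illustrative Python sketch (the names below are made up for illustration, not kernel code): the scale-factor element type for A now depends on the A dtype, with MxInt4 carrying Bfloat16 scale factors, and the number of K-dimension scale factors for B is rounded up to a multiple of 16 before the TMA descriptor is built, with the same rounding now also applied to the valid-K range.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Per-A-dtype scale-factor element type after this change (the MxInt4 entry is new).
SF_DTYPE_FOR_A = {"E2m1": "E4m3", "MxE4m3": "UE8m0", "MxE2m1": "UE8m0", "MxInt4": "Bfloat16"}

def num_sfs_in_k(k: int, num_elts_per_sf: int = 32) -> int:
    # One scale factor per num_elts_per_sf elements, padded up to a multiple of 16.
    return ceil_div(k // num_elts_per_sf, 16) * 16

assert SF_DTYPE_FOR_A["MxInt4"] == "Bfloat16"
assert num_sfs_in_k(4096) == 128  # 4096 / 32 = 128, already a multiple of 16
assert num_sfs_in_k(2176) == 80   # 2176 / 32 = 68, rounded up to 80
```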
auto const inputNumTokens = ctaOffset * options.mTileM; - auto [shapeA, strideA, tileShapeA] = - makeTmaShapeStrideAbc(options, inputNumTokens, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc( + options, inputNumTokens, options.mN, options.mK, options.mTileM, options.mTileN, + options.mTileK, MatrixType::MatrixA, inputNumTokens, options.mValidN, options.mValidK); // Build tma descriptor for A. params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, tileShapeA, const_cast<void*>(ptrA)); @@ -551,10 +560,12 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = - makeTmaShapeStrideAbc(options, ctaOffset * options.mTileM, options.mN, options.mK, - options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, - ctaOffset * options.mTileM, options.mValidN, options.mValidK); + // NOTE: Output is *always* sanitized across the whole MNK range. This ensures maximum + // compatibility with the next BMM where the unwritten part of the output could be polluted by + // NaNs. + auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( + options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, + options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 4ea0a91250..d7b0b6b62f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -289,6 +289,7 @@ class KernelTraits { // gmemC reuses loadAb memory for split-K in DSMEM. // Epilogue1 does not reuse and continues after the memory allocated Epilogue0 // NOTE: we can always reuse loadAb SMEM as long as we don't have persistent scheduler.
+ auto const reuseFirstChunksSmemStoreC = doesSplitKUseDsmem(splitK) && resIdx == 0 && !usePersistentScheduler; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index c7b18af138..fa250f8fe9 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -23,8 +23,6 @@ #ifdef TLLM_ENABLE_CUDA #include -#include -#include #endif namespace batchedGemm { @@ -57,7 +55,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if (dtype == tg::Dtype::E2m1) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; - } else if (dtype == tg::Dtype::MxE2m1) { + } else if (dtype == tg::Dtype::MxE2m1 || dtype == tg::Dtype::MxInt4) { if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4) { padMultiplier = 2; tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; @@ -197,9 +195,11 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c std::vector const& strides, const std::vector& tileShapes, void* gmemAddr) { CUtensorMap desc{}; - CUtensorMapDataType tmaDataFormat; + CUtensorMapDataType tmaDataFormat{}; if (dtype == tg::Dtype::E4m3 || dtype == tg::Dtype::UE8m0) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if (dtype == tg::Dtype::Bfloat16) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else { std::cerr << "buildSfTmaDescriptor: unexpected dtype " << tg::dtypeToString(dtype) << std::endl; assert(false); diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h new file mode 100644 index 0000000000..2a1f2dcc78 --- /dev/null +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h @@ -0,0 +1,95 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Be careful when modifying this file as it is included by the generated kernels. For example, do +// not add TLLM_CHECK_* constructs in this file. Thanks! 
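The TmaDescriptor.h hunk above routes MxInt4 through the same packed 4-bit TMA formats as MxE2m1 and teaches buildSfTmaDescriptor to accept Bfloat16, which is needed because MxInt4 block scales are bf16 rather than 8-bit values. A rough Python sketch of the scale-factor format selection; the strings mirror the CUDA tensor-map enum names, but this is illustrative code, not the driver API:

```python
def sf_tma_data_type(dtype: str) -> str:
    # 8-bit scale factors (NvFp4 / MxFp4 paths) stay on the UINT8 tensor-map format.
    if dtype in ("E4m3", "UE8m0"):
        return "CU_TENSOR_MAP_DATA_TYPE_UINT8"
    # New: bf16 scale factors used by the MxInt4 path.
    if dtype == "Bfloat16":
        return "CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"
    raise ValueError(f"buildSfTmaDescriptor: unexpected dtype {dtype}")

assert sf_tma_data_type("Bfloat16") == "CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"
```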
+// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace batchedGemm { + +namespace trtllm { +namespace gen { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class CudaArch { + // Hopper + Sm90a = 0, + // Blackwell + Sm100a, + // Blackwell-family + Sm100f, + // Blackwell Ultra + Sm103a, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool isArchHopper(CudaArch cudaArch) { return cudaArch == CudaArch::Sm90a; } + +inline bool isArchBlackwell(CudaArch cudaArch) { + return cudaArch == CudaArch::Sm100a || cudaArch == CudaArch::Sm100f || + cudaArch == CudaArch::Sm103a; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::string cudaArchToString(CudaArch cudaArch, bool isFull = true) { + switch (cudaArch) { + case CudaArch::Sm90a: + return isFull ? "90a" : "90"; + case CudaArch::Sm100a: + return isFull ? "100a" : "100"; + case CudaArch::Sm100f: + return isFull ? "100f" : "100"; + case CudaArch::Sm103a: + return isFull ? "103a" : "103"; + default: + assert(false); + return ""; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline CudaArch stringToCudaArch(std::string const& str) { + if (str == "90a") { + return CudaArch::Sm90a; + } else if (str == "100a") { + return CudaArch::Sm100a; + } else if (str == "100f") { + return CudaArch::Sm100f; + } else if (str == "103a") { + return CudaArch::Sm103a; + } else { + assert(false); + return CudaArch::Sm100a; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gen +} // namespace trtllm +} // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h index 0866256492..355cfba961 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h @@ -70,13 +70,14 @@ enum class Dtype : uint32_t { Int64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 11u), MxE2m1 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 4u, /*uid*/ 12u), MxE4m3 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 13u), - UE8m0 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 14u), - UInt8 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 8u, /*uid*/ 15u), - UInt16 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 16u, /*uid*/ 16u), - UInt32 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 32u, /*uid*/ 17u), - UInt64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 18u), - UInt128 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 128u, /*uid*/ 19u), - Void = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 0u, /*uid*/ 20u), + MxInt4 = TLLM_ENCODE_DTYPE(/*block*/ 1u, /*signed*/ 1u, /*int*/ 1u, /*bits*/ 4u, /*uid*/ 14u), + UE8m0 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 0u, /*bits*/ 8u, /*uid*/ 15u), + UInt8 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, 
/*bits*/ 8u, /*uid*/ 16u), + UInt16 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 16u, /*uid*/ 17u), + UInt32 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 32u, /*uid*/ 18u), + UInt64 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 64u, /*uid*/ 19u), + UInt128 = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 0u, /*int*/ 1u, /*bits*/ 128u, /*uid*/ 20u), + Void = TLLM_ENCODE_DTYPE(/*block*/ 0u, /*signed*/ 1u, /*int*/ 0u, /*bits*/ 0u, /*uid*/ 21u), // clang-format on #undef TLLM_ENCODE_DTYPE @@ -160,6 +161,8 @@ inline std::string dtypeToString(Dtype dtype) { return "MxE4m3"; case Dtype::MxE2m1: return "MxE2m1"; + case Dtype::MxInt4: + return "MxInt4"; case Dtype::UE8m0: return "UE8m0"; case Dtype::UInt8: @@ -201,6 +204,7 @@ inline int dtypeNumEltsPerSf(Dtype dtype) { return 16; case Dtype::MxE2m1: case Dtype::MxE4m3: + case Dtype::MxInt4: return 32; default: assert(false); @@ -218,6 +222,8 @@ inline Dtype dtypeGetBlockSfType(Dtype dtype) { case Dtype::MxE2m1: case Dtype::MxE4m3: return Dtype::UE8m0; + case Dtype::MxInt4: + return Dtype::Bfloat16; default: assert(false); return Dtype::Void; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 1a78243593..c1ee02b91f 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -40,6 +40,7 @@ trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_bf16_moe, + trtllm_mxint4_block_scale_moe, ) from flashinfer.fused_moe.core import ( get_w2_permute_indices_with_cache, @@ -582,6 +583,216 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} +# ==================================================================================== +# MxInt4 Block Scale Quantization Implementation +# ==================================================================================== + + +def mxint4_quantize( + x: torch.Tensor, sf_vec_size: int = 32 +) -> tuple[torch.Tensor, torch.Tensor]: + x_reshaped = x.reshape(-1, sf_vec_size) + x_max = x_reshaped.max(dim=-1, keepdim=True)[0].to(torch.float32) + x_min = x_reshaped.min(dim=-1, keepdim=True)[0].to(torch.float32) + x_max = x_max * 8.0 / 7.0 + amax = torch.where(x_max > -x_min, x_max, -x_min) + scales = amax / 8.0 + x_scaled = x_reshaped * scales.reciprocal() + x_int8 = ( + x_scaled.round().clamp(-8, 7).to(torch.int8).reshape(-1, sf_vec_size // 2, 2) + ) + x_int4 = (x_int8[..., 0] & 0x0F) | ((x_int8[..., 1] & 0x0F) << 4) + return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.reshape( + -1, sf_vec_size + ) + + +class MxInt4BlockScaleMoe(Moe): + """MxInt4 MoE implementation with block scaling (DeepSeek style).""" + + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): + """Quantize weights to MxInt4 with block scaling.""" + num_experts = gemm1_weights.shape[0] + intermediate_size = gemm1_weights.shape[1] // 2 + hidden_size = gemm1_weights.shape[ + 2 + ] # [num_experts, 2*intermediate_size, hidden_size] + + # Quantize weights to MxInt4 + sf_vec_size = 32 + gemm1_weights_int4, gemm1_scales = mxint4_quantize(gemm1_weights, sf_vec_size) + gemm2_weights_int4, gemm2_scales = mxint4_quantize(gemm2_weights, sf_vec_size) + gemm1_scales = gemm1_scales.to(torch.bfloat16).reshape( + num_experts, + 2 * intermediate_size, + hidden_size // sf_vec_size, + ) + gemm2_scales = gemm2_scales.to(torch.bfloat16).reshape( + num_experts, hidden_size, intermediate_size // sf_vec_size + ) + return { + 
"hidden_states_scale_global": None, + "gemm1_weights": gemm1_weights_int4, + "gemm2_weights": gemm2_weights_int4, + "gemm1_scales": gemm1_scales, + "gemm2_scales": gemm2_scales, + "gemm1_scales_global": None, + "gemm2_scales_global": None, + } + + def quantize_inputs(self, hidden_states, *unused_args): + """No scaling for hidden states.""" + return { + "hidden_states": hidden_states.to(torch.bfloat16), + "hidden_states_scale": None, + } + + def prepare_static_weights_for_kernel( + self, + args_dequant, + args, + gemm1_weights_orig, + gemm2_weights_orig, + hidden_size, + intermediate_size, + num_experts, + weight_processing, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + + epilogue_tile_m = 128 + gemm1_weights_mxint4_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_mxint4_shuffled = [] + gemm2_scales_shuffled = [] + + for i in range(num_experts): + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + args.gemm1_weights[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_shuffled = ( + args.gemm1_weights[i] + .view(torch.uint8)[permute_indices.to(args.gemm1_weights.device)] + .contiguous() + ) + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + args.gemm1_scales[i].view(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=32, + ) + gemm1_scales_shuffled.append( + block_scale_interleave( + args.gemm1_scales[i] + .view(torch.bfloat16)[ + permute_sf_indices.to(args.gemm1_scales.device) + ] + .contiguous() + ) + ) + + permute_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + args.gemm2_weights[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_shuffled = ( + args.gemm2_weights[i] + .view(torch.uint8)[permute_indices.to(args.gemm2_weights.device)] + .contiguous() + ) + + permute_sf_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + args.gemm2_scales[i].view(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + block_scale_interleave( + args.gemm2_scales[i] + .view(torch.bfloat16)[ + permute_sf_indices.to(args.gemm2_scales.device) + ] + .contiguous() + ) + ) + + block_k = 128 + gemm1_weights_shuffled = convert_to_block_layout( + gemm1_weights_shuffled, block_k + ) + gemm2_weights_shuffled = convert_to_block_layout( + gemm2_weights_shuffled.view(torch.uint8), block_k + ) + + gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) + gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) + + gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) + gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) + gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) + gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) + + return { + "gemm1_weights": gemm1_weights_mxint4_shuffled, + "gemm1_scales": gemm1_scales_shuffled, + "gemm2_weights": gemm2_weights_mxint4_shuffled, + "gemm2_scales": gemm2_scales_shuffled, + } + + def call_moe( + self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs + ): + """Call MoE with runtime input quantization + kernel execution (done at runtime).""" + expert_logits = 
kwargs["expert_logits"] + num_experts = kwargs["num_experts"] + top_k = kwargs["top_k"] + n_groups = kwargs["n_groups"] + top_k_groups = kwargs["top_k_groups"] + intermediate_size = kwargs["intermediate_size"] + routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) + + # Use autotuner for optimal kernel selection + with autotune(enable_autotune): + output = trtllm_mxint4_block_scale_moe( + expert_logits, # float + hidden_states_orig, + static_data["gemm1_weights"], + static_data["gemm1_scales"], + None, + None, + None, + static_data["gemm2_weights"], + static_data["gemm2_scales"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + 1.0, + routing_method_type=routing_method_type, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + ) + return output.to(torch.float) + + def compute_reference(self, args): + return run_moe_reference_mxint4(args) + + def get_tolerances(self): + """Get MXINT4-specific accuracy tolerances.""" + return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} + + # ==================================================================================== # FP8 Block Scale Quantization Implementation # ==================================================================================== @@ -1726,10 +1937,7 @@ def run_moe_dequant(args, quant_mode: QuantMode): .to(torch.float) ) args.c_global_sf = 1.0 - elif quant_mode == QuantMode.BF16: - activation_output = activation_output.to(torch.bfloat16).to(torch.float) - args.c_global_sf = 1.0 - else: # mxfp4Bf16 + else: # Bf16, MxFp4xBf16, MxInt4xBf16 activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 @@ -1965,6 +2173,57 @@ def run_moe_reference_bf16(args): return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant +def run_moe_reference_mxint4(args): + sf_vec_size = 32 + + hidden_states_dequant = args.hidden_states.to(torch.bfloat16).to(torch.float) + + num_experts = args.gemm1_weights.shape[0] + + def dequantize(weights, scales): + k = weights.shape[-1] * 2 + n = weights.shape[-2] + # Unpack two 4-bit values (stored in two's-complement) from each byte + weights_int8 = ( + torch.stack([weights & 0x0F, (weights >> 4) & 0x0F], dim=-1) + .reshape(num_experts, n, k) + .to(torch.int8) + ) + + # Interpret nibbles as signed 4-bit two's-complement values in [-8, 7] + weights_int8 = torch.where(weights_int8 < 8, weights_int8, weights_int8 - 16) + + weights_float = weights_int8.to(torch.float) + scales_expanded = ( + scales.to(torch.bfloat16) + .to(torch.float) + .repeat_interleave(sf_vec_size, dim=-1) + .reshape(weights_float.shape) + ) + return weights_float * scales_expanded + + gemm1_weights_dequant = dequantize(args.gemm1_weights, args.gemm1_scales) + gemm2_weights_dequant = dequantize(args.gemm2_weights, args.gemm2_scales) + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + args.gated_act_type, + ) + + return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant + + def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): """Unified actual computation that delegates to implementation-specific methods.""" # 1. 
Prepare static weights for the kernel (offline processing) @@ -2219,6 +2478,7 @@ def run_moe_test( pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"), ], ) @pytest.mark.parametrize( @@ -2239,6 +2499,7 @@ def run_moe_test( FP8BlockScaleMoe, FP4Moe, BF16Moe, + MxInt4BlockScaleMoe, ], "compatible_intermediate_size": [384, 768, 1024], "enable_autotune": True, @@ -2260,6 +2521,7 @@ def run_moe_test( FP8BlockScaleMoe, FP4Moe, BF16Moe, + MxInt4BlockScaleMoe, ], "compatible_intermediate_size": [384, 1024], "enable_autotune": False, @@ -2281,6 +2543,7 @@ def run_moe_test( FP8BlockScaleMoe, FP4Moe, BF16Moe, + MxInt4BlockScaleMoe, ], "compatible_intermediate_size": [512], "enable_autotune": True, @@ -2312,7 +2575,11 @@ def run_moe_test( { "use_shuffled_weight": True, "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe, BF16Moe], + "compatible_moe_impls": [ + FP8BlockScaleMoe, + BF16Moe, + MxInt4BlockScaleMoe, + ], }, id="Shuffled_BlockMajorK", ), diff --git a/tests/moe/utils.py b/tests/moe/utils.py index ebaf85c189..5a8f932117 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -30,6 +30,7 @@ class QuantMode(IntEnum): FP8_BLOCK_SCALE = 4 FP8_PER_TENSOR = 5 BF16 = 6 + MXINT4_BF16_BF16 = 7 def skip_checks( @@ -85,6 +86,13 @@ def skip_checks( f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" ) + if type(moe_impl).__name__ == "MxInt4BlockScaleMoe" and ( + intermediate_size % 256 != 0 or hidden_size % 256 != 0 + ): + pytest.skip( + f"Incompatible: intermediate_size={intermediate_size} or hidden_size={hidden_size} with MXINT4_BF16_BF16 quantization" + ) + # TODO(jimmzhou): enable MxFP4xBf16 on SM103 if ( is_fp4_moe
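To see that the test's packing and the reference dequantization agree, here is a self-contained round-trip check (a sketch with hypothetical helper names, not part of the test suite): mxint4_quantize stores two signed 4-bit values per byte with the low nibble first, and dequantize() inside run_moe_reference_mxint4 reverses exactly that before applying the bf16 block scales.

```python
import torch

def pack_int4_pairs(x_int8: torch.Tensor) -> torch.Tensor:
    # x_int8: int8 values in [-8, 7], last dim even; one byte per consecutive pair,
    # low nibble = element 2i, high nibble = element 2i+1 (same layout as mxint4_quantize).
    pairs = x_int8.reshape(*x_int8.shape[:-1], -1, 2)
    return (pairs[..., 0] & 0x0F) | ((pairs[..., 1] & 0x0F) << 4)

def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    vals = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1).to(torch.int8)
    # Reinterpret nibbles as two's-complement signed 4-bit values in [-8, 7].
    return torch.where(vals < 8, vals, vals - 16)

x = torch.randint(-8, 8, (4, 32), dtype=torch.int8)
assert torch.equal(unpack_int4_pairs(pack_int4_pairs(x)), x)
```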