
Commit 8222437

fix: address review
Signed-off-by: Nikita Korobov <[email protected]>
1 parent 7e9ff16 commit 8222437


3 files changed, +16 -16 lines changed


csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ void blockScaleInterleaveHost(TensorView blockScale, TensorView interleavedBlock
       auto globalRowIdx = eIdx * rows + rIdx;
       T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
       for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
-        uint8_t sf_ori = 0;
+        T sf_ori = 0;
         if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
           sf_ori = blockScalePtr[cIdx];
         }
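
The one-line change matters because blockScaleInterleaveHost is templated on the scale element type T: for mxint4 the block scales are bfloat16, so staging each value through a uint8 temporary would truncate it. A minimal Python sketch of the idea, not the library's exact interleave/swizzle (names and sizes here are illustrative only):

import torch

def pad_block_scales(block_scale: torch.Tensor, cols_padded: int) -> torch.Tensor:
    """Copy [rows, cols] block scales into a zero-padded [rows, cols_padded] buffer,
    keeping the source dtype (bfloat16 for mxint4 scales, uint8 for fp8 scales)."""
    rows, cols = block_scale.shape
    out = torch.zeros(rows, cols_padded, dtype=block_scale.dtype)  # plays the role of `T sf_ori = 0`
    out[:, :cols] = block_scale  # plays the role of `sf_ori = blockScalePtr[cIdx]`
    return out

bf16_scales = torch.randn(8, 6, dtype=torch.bfloat16)
padded = pad_block_scales(bf16_scales, cols_padded=8)
assert padded.dtype == torch.bfloat16  # values survive; a uint8 staging buffer would have mangled them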

flashinfer/fused_moe/core.py

Lines changed: 8 additions & 8 deletions
@@ -2568,22 +2568,22 @@ def trtllm_mxint4_block_scale_moe(
     Args:
         routing_logits (torch.Tensor): shape [seq_len, num_experts]
             Input tensor of routing logits. Supports float32, bfloat16.
-        hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
-            Tensor of input hidden states. Supports bfloat16, mxfp8, and nvfp4 (packed into uint8)
+        hidden_states (torch.Tensor): shape [seq_len, hidden_size]
+            Tensor of input hidden states. Supports bfloat16.
         gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
-            Tensor of FC1 weights. Dtype must be uint8 (packed fp4)
-        gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)]
-            Scale tensor of FC1 weights. Dtype must be float8.
+            Tensor of FC1 weights. Dtype must be uint8 (packed mxint4)
+        gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 32]
+            Scale tensor of FC1 weights. Dtype must be bfloat16.
         gemm1_alpha (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu alpha. Dtype is float32.
         gemm1_beta (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu beta. Dtype is float32.
         gemm1_clamp_limit (Optional[torch.Tensor]): shape [num_experts]
             Tensor of swiglu clamp limit. Dtype is float32.
         gemm2_weights (torch.Tensor): shape [num_experts, hidden_size, intermediate_size]
-            Tensor of FC2 weights. Dtype must be uint8 (packed fp4)
-        gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)]
-            Scale tensor of FC2 weights. Dtype must be float8.
+            Tensor of FC2 weights. Dtype must be uint8 (packed mxint4)
+        gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // 32]
+            Scale tensor of FC2 weights. Dtype must be bfloat16.
         num_experts (int): Total number of experts
         top_k (int): Number of experts to route to per token
         n_group (Optional[int]): Number of expert groups (can be None for some routing methods)
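
For reference, the shapes and dtypes documented above translate to tensors like the following. This is a hedged illustration of the argument layout only: the sizes are made up, and the remaining parameters of trtllm_mxint4_block_scale_moe are not shown in this hunk, so this is not a complete call.

import torch

seq_len, hidden_size, intermediate_size = 4, 256, 512
num_experts, top_k = 8, 2

routing_logits = torch.randn(seq_len, num_experts, dtype=torch.bfloat16)
hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16)

# FC1 weights: uint8 with two int4 values packed per byte, per the [.., hidden_size // 2] shape above;
# one bfloat16 scale per 32-element block along the reduction dimension.
gemm1_weights = torch.randint(0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2), dtype=torch.uint8)
gemm1_weights_scale = torch.randn(num_experts, 2 * intermediate_size, hidden_size // 32, dtype=torch.bfloat16)

# FC2 weights and scales, with the shapes documented above.
gemm2_weights = torch.randint(0, 256, (num_experts, hidden_size, intermediate_size), dtype=torch.uint8)
gemm2_weights_scale = torch.randn(num_experts, hidden_size, intermediate_size // 32, dtype=torch.bfloat16)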

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 7 additions & 7 deletions
@@ -789,7 +789,7 @@ def compute_reference(self, args):
         return run_moe_reference_mxint4(args)

     def get_tolerances(self):
-        """Get FP4-specific accuracy tolerances."""
+        """Get MXINT4-specific accuracy tolerances."""
         return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}

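
The {"atol", "rtol", "percent"} triple reads as "at least `percent` of the output elements must be close to the reference under the given absolute/relative tolerances". A hedged sketch of such a check; the actual helper used by the test suite is not part of this diff:

import torch

def fraction_close(out: torch.Tensor, ref: torch.Tensor, atol: float, rtol: float) -> float:
    # Fraction of elements satisfying |out - ref| <= atol + rtol * |ref|.
    return torch.isclose(out.float(), ref.float(), atol=atol, rtol=rtol).float().mean().item()

tol = {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
ref = torch.randn(64, 128)
out = ref + 0.01 * torch.randn_like(ref)
assert fraction_close(out, ref, tol["atol"], tol["rtol"]) >= tol["percent"]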

@@ -2472,12 +2472,12 @@ def run_moe_test(
 @pytest.mark.parametrize(
     "moe_impl",
     [
-        # pytest.param(BF16Moe(), id="BF16xBF16"),
-        # pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
-        # pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
-        # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
-        # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
-        # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
+        pytest.param(BF16Moe(), id="BF16xBF16"),
+        pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
+        pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
         pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"),
     ],
 )
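
With the full parameter list restored, a single configuration can still be selected by its id through pytest's -k filter (for example -k MxInt4xBf16), since the ids above are what pytest uses to name each parametrization.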
