@@ -592,10 +592,15 @@ def mxint4_quantize(
592592 x : torch .Tensor , sf_vec_size : int = 32
593593) -> tuple [torch .Tensor , torch .Tensor ]:
594594 x_reshaped = x .reshape (- 1 , sf_vec_size )
595- amax = torch .abs (x_reshaped ).max (dim = - 1 , keepdim = True )[0 ].to (torch .float32 )
596- scales = amax / 7.0
595+ x_max = x_reshaped .max (dim = - 1 , keepdim = True )[0 ].to (torch .float32 )
596+ x_min = x_reshaped .min (dim = - 1 , keepdim = True )[0 ].to (torch .float32 )
597+ x_max = x_max * 8.0 / 7.0
598+ amax = torch .where (x_max > - x_min , x_max , - x_min )
599+ scales = amax / 8.0
597600 x_scaled = x_reshaped * scales .reciprocal ()
598- x_int8 = x_scaled .to (torch .int8 ).reshape (- 1 , sf_vec_size // 2 , 2 )
601+ x_int8 = (
602+ x_scaled .round ().clamp (- 8 , 7 ).to (torch .int8 ).reshape (- 1 , sf_vec_size // 2 , 2 )
603+ )
599604 x_int4 = (x_int8 [..., 0 ] & 0x0F ) | ((x_int8 [..., 1 ] & 0x0F ) << 4 )
600605 return x_int4 .reshape (* x .shape [:- 1 ], x .shape [- 1 ] // 2 ), scales .reshape (
601606 - 1 , sf_vec_size
@@ -655,12 +660,12 @@ def prepare_static_weights_for_kernel(
655660 ):
656661 """Prepare quantized weights for kernel (done offline with weights)."""
657662
658- # TODO: is this correct for mxint4 x bf16 kernel?
659663 epilogue_tile_m = 128
660-
661- # TODO: should we shuffle the weights and/or scales here?
662664 gemm1_weights_mxint4_shuffled = []
665+ gemm1_scales_shuffled = []
663666 gemm2_weights_mxint4_shuffled = []
667+ gemm2_scales_shuffled = []
668+
664669 for i in range (num_experts ):
665670 # Calculate the permute indices for the following:
666671 # 1. Reorder rows of W1 and scales for fused gated activation
@@ -676,6 +681,21 @@ def prepare_static_weights_for_kernel(
676681 .view (torch .uint8 )[permute_indices .to (args .gemm1_weights .device )]
677682 .contiguous ()
678683 )
684+ permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices (
685+ self ._cache_permute_indices ,
686+ args .gemm1_scales [i ].view (torch .bfloat16 ),
687+ epilogue_tile_m ,
688+ num_elts_per_sf = 32 ,
689+ )
690+ gemm1_scales_shuffled .append (
691+ block_scale_interleave (
692+ args .gemm1_scales [i ]
693+ .view (torch .bfloat16 )[
694+ permute_sf_indices .to (args .gemm1_scales .device )
695+ ]
696+ .contiguous ()
697+ )
698+ )
679699
680700 permute_indices = get_w2_permute_indices_with_cache (
681701 self ._cache_permute_indices ,
@@ -688,25 +708,43 @@ def prepare_static_weights_for_kernel(
688708 .contiguous ()
689709 )
690710
711+ permute_sf_indices = get_w2_permute_indices_with_cache (
712+ self ._cache_permute_indices ,
713+ args .gemm2_scales [i ].view (torch .bfloat16 ),
714+ epilogue_tile_m ,
715+ num_elts_per_sf = 16 ,
716+ )
717+ gemm2_scales_shuffled .append (
718+ block_scale_interleave (
719+ args .gemm2_scales [i ]
720+ .view (torch .bfloat16 )[
721+ permute_sf_indices .to (args .gemm2_scales .device )
722+ ]
723+ .contiguous ()
724+ )
725+ )
726+
691727 block_k = 128
692728 gemm1_weights_shuffled = convert_to_block_layout (
693729 gemm1_weights_shuffled , block_k
694730 )
695731 gemm2_weights_shuffled = convert_to_block_layout (
696- gemm2_weights_shuffled , block_k
732+ gemm2_weights_shuffled . view ( torch . uint8 ) , block_k
697733 )
698734
699735 gemm1_weights_mxint4_shuffled .append (gemm1_weights_shuffled )
700736 gemm2_weights_mxint4_shuffled .append (gemm2_weights_shuffled )
701737
702738 gemm1_weights_mxint4_shuffled = torch .stack (gemm1_weights_mxint4_shuffled )
703739 gemm2_weights_mxint4_shuffled = torch .stack (gemm2_weights_mxint4_shuffled )
740+ gemm1_scales_shuffled = torch .stack (gemm1_scales_shuffled ).view (torch .bfloat16 )
741+ gemm2_scales_shuffled = torch .stack (gemm2_scales_shuffled ).view (torch .bfloat16 )
704742
705743 return {
706744 "gemm1_weights" : gemm1_weights_mxint4_shuffled ,
707- "gemm1_scales" : args . gemm1_scales ,
745+ "gemm1_scales" : gemm1_scales_shuffled ,
708746 "gemm2_weights" : gemm2_weights_mxint4_shuffled ,
709- "gemm2_scales" : args . gemm2_scales ,
747+ "gemm2_scales" : gemm2_scales_shuffled ,
710748 }
711749
712750 def call_moe (
@@ -2145,10 +2183,17 @@ def run_moe_reference_mxint4(args):
21452183 def dequantize (weights , scales ):
21462184 k = weights .shape [- 1 ] * 2
21472185 n = weights .shape [- 2 ]
2148- weights_int8 = torch .stack (
2149- [weights & 0x0F , (weights >> 4 ) & 0x0F ], dim = - 1
2150- ).reshape (num_experts , n , k )
2151- weights_float = weights_int8 .to (torch .bfloat16 ).to (torch .float )
2186+ # Unpack two 4-bit values (stored in two's-complement) from each byte
2187+ weights_int8 = (
2188+ torch .stack ([weights & 0x0F , (weights >> 4 ) & 0x0F ], dim = - 1 )
2189+ .reshape (num_experts , n , k )
2190+ .to (torch .int8 )
2191+ )
2192+
2193+ # Interpret nibbles as signed 4-bit two's-complement values in [-8, 7]
2194+ weights_int8 = torch .where (weights_int8 < 8 , weights_int8 , weights_int8 - 16 )
2195+
2196+ weights_float = weights_int8 .to (torch .float )
21522197 scales_expanded = (
21532198 scales .to (torch .bfloat16 )
21542199 .to (torch .float )
@@ -2427,12 +2472,12 @@ def run_moe_test(
24272472@pytest .mark .parametrize (
24282473 "moe_impl" ,
24292474 [
2430- pytest .param (BF16Moe (), id = "BF16xBF16" ),
2431- pytest .param (FP8BlockScaleMoe (), id = "FP8_Block" ),
2432- pytest .param (FP8PerTensorMoe (), id = "FP8_Tensor" ),
2433- pytest .param (FP4Moe (quant_mode = QuantMode .FP4_NVFP4_NVFP4 ), id = "NvFP4xNvFP4" ),
2434- pytest .param (FP4Moe (quant_mode = QuantMode .FP4_MXFP4_MXFP8 ), id = "MxFP4xMxFP8" ),
2435- pytest .param (FP4Moe (quant_mode = QuantMode .FP4_MXFP4_Bf16 ), id = "MxFP4xBf16" ),
2475+ # pytest.param(BF16Moe(), id="BF16xBF16"),
2476+ # pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
2477+ # pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
2478+ # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
2479+ # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
2480+ # pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
24362481 pytest .param (MxInt4BlockScaleMoe (), id = "MxInt4xBf16" ),
24372482 ],
24382483)
0 commit comments