@@ -575,13 +575,12 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
     tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")


-@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048], [2048, 2048, 2048]])
-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256), (128, 128, 512), [32, 32, 64]])
-@pytest.mark.parametrize("matrix_instr_nonkdim", [16, 32])
+@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
+@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
 @pytest.mark.parametrize("preshuffle", [True, False])
-@pytest.mark.skipif(is_cuda(), reason="AMD specific scale shuffling")
-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
-def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, matrix_instr_nonkdim, preshuffle, device):
+@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
     # This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
     #
     # Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
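For context on the packing the comment describes: with one 8-bit scale per group of 32 operand values, the scale tensor for an (M, K) operand has shape (M, K // 32). A minimal sketch, assuming the group runs along K and OCP MX-style e8m0 scales (a biased power-of-two exponent); neither the shape helper nor the decoder below is part of the diff:

```python
import torch

M, K = 1024, 1024
a_scales_shape = (M, K // 32)  # one uint8 scale per 32-element group along K

def decode_e8m0(s: torch.Tensor) -> torch.Tensor:
    # Assumed e8m0 semantics: the byte is a biased exponent, value = 2**(s - 127).
    return torch.exp2(s.to(torch.float32) - 127)

print(a_scales_shape)                                            # (1024, 32)
print(decode_e8m0(torch.tensor([127, 130], dtype=torch.uint8)))  # [1., 8.]
```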
@@ -638,10 +637,10 @@ def shuffle_scales_cdna4(scales: torch.Tensor):
         scales_shuffled = scales.clone()

         sm, sn = scales_shuffled.shape
-        if matrix_instr_nonkdim == 32:
+        if mfma_nonkdim == 32:
             scales_shuffled = scales_shuffled.view(sm // 32, 32, sn // 8, 4, 2, 1)
             scales_shuffled = scales_shuffled.permute(0, 2, 4, 1, 3, 5).contiguous()
-        elif matrix_instr_nonkdim == 16:
+        elif mfma_nonkdim == 16:
             scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
             scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()

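The two view/permute pairs above define the preshuffled layouts for the 16- and 32-wide MFMA variants. As a sanity check, a standalone sketch (`shuffle_demo` is a hypothetical helper mirroring the branches above on a single 32x8 scale tile; it is not from the PR) shows that each path is a pure reordering of the same elements:

```python
import torch

def shuffle_demo(scales: torch.Tensor, mfma_nonkdim: int) -> torch.Tensor:
    # Mirrors the two view/permute branches from the hunk above.
    sm, sn = scales.shape
    if mfma_nonkdim == 32:
        out = scales.view(sm // 32, 32, sn // 8, 4, 2, 1)
        out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    else:  # mfma_nonkdim == 16
        out = scales.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
        out = out.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
    return out.flatten()

tile = torch.arange(32 * 8, dtype=torch.int32).view(32, 8)
for dim in (16, 32):
    shuffled = shuffle_demo(tile, dim)
    # A pure permutation: same elements, different linear order.
    assert torch.equal(shuffled.sort().values, tile.flatten())
    print(dim, shuffled[:8].tolist())
```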
@@ -701,8 +700,9 @@ def generate_gemm_afp4wfp4_inputs(M, N, K):
     w = w.T
     triton_out = torch.empty((M, N), device=x.device)

-    # name collision when passing kernel parameter with same name as "meta" parameter
-    mfma_nonkdim = matrix_instr_nonkdim
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim

     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
@@ -711,7 +711,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K):
                                                           x_scales_triton.stride(0), x_scales_triton.stride(1),
                                                           w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
                                                           BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
-                                                          num_stages=1, matrix_instr_nonkdim=matrix_instr_nonkdim)
+                                                          num_stages=1, **kernel_kwargs)
     triton_out = triton_out.to(torch.float32)
     torch.testing.assert_close(torch_out, triton_out)

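The `kernel_kwargs` indirection is what lets the skip condition relax from `not is_hip_cdna4()` to `is_hip() and not is_hip_cdna4()`: `matrix_instr_nonkdim` is an AMD-specific launch option, so it is forwarded only on HIP builds, and `**kernel_kwargs` expands to nothing on other backends. A minimal sketch of the pattern, using `torch.version.hip` as a stand-in for the test suite's `is_hip()` helper:

```python
import torch

def backend_kernel_kwargs(mfma_nonkdim: int) -> dict:
    # Forward AMD-only launch options only when running on a HIP build;
    # elsewhere the dict stays empty and **kwargs adds nothing to the call.
    if torch.version.hip is not None:
        return {"matrix_instr_nonkdim": mfma_nonkdim}
    return {}

# kernel[grid](..., num_warps=8, num_stages=1, **backend_kernel_kwargs(16))
```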