@@ -575,12 +575,11 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
575575 tl .store (c_ptrs , c , mask = c_mask , cache_modifier = ".wt" )
576576
577577
578- @pytest .mark .parametrize ("M, N, K" , [(1024 , 1024 , 1024 ), [ 512 , 1024 , 2048 ], [ 2048 , 2048 , 2048 ] ])
579- @pytest .mark .parametrize ("BLOCK_M, BLOCK_N, BLOCK_K" , [(256 , 256 , 256 ), ( 128 , 128 , 256 ), (128 , 128 , 512 ), [32 , 32 , 64 ]])
578+ @pytest .mark .parametrize ("M, N, K" , [(1024 , 1024 , 1024 )])
579+ @pytest .mark .parametrize ("BLOCK_M, BLOCK_N, BLOCK_K" , [(128 , 128 , 256 ), (64 , 64 , 512 ), [32 , 32 , 64 ]])
580580@pytest .mark .parametrize ("matrix_instr_nonkdim" , [16 , 32 ])
581581@pytest .mark .parametrize ("preshuffle" , [True , False ])
582- @pytest .mark .skipif (is_cuda (), reason = "AMD specific scale shuffling" )
583- @pytest .mark .skipif (not is_hip_cdna4 (), reason = "Requires hardware support for scaled mfma instructions" )
582+ @pytest .mark .skipif (is_hip () and not is_hip_cdna4 (), reason = "Requires hardware support for scaled mfma instructions" )
584583def test_preshuffle_scale_mxfp_cdna4 (M , N , K , BLOCK_M , BLOCK_N , BLOCK_K , matrix_instr_nonkdim , preshuffle , device ):
585584 # This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
586585 #
0 commit comments