Skip to content

Commit b407a47

Browse files
committed
Address review comments
1 parent ce8add2 commit b407a47

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -575,12 +575,11 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
575575
tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")
576576

577577

578-
@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048], [2048, 2048, 2048]])
579-
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256), (128, 128, 512), [32, 32, 64]])
578+
@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
579+
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
580580
@pytest.mark.parametrize("matrix_instr_nonkdim", [16, 32])
581581
@pytest.mark.parametrize("preshuffle", [True, False])
582-
@pytest.mark.skipif(is_cuda(), reason="AMD specific scale shuffling")
583-
@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
582+
@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
584583
def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, matrix_instr_nonkdim, preshuffle, device):
585584
# This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
586585
#

0 commit comments

Comments
 (0)