Commit 8a294c7: Address review comments
1 parent: ce8add2

1 file changed, 11 additions and 11 deletions

python/test/unit/language/test_matmul.py

@@ -575,13 +575,12 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale
     tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")


-@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024), [512, 1024, 2048], [2048, 2048, 2048]])
-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(256, 256, 256), (128, 128, 256), (128, 128, 512), [32, 32, 64]])
-@pytest.mark.parametrize("matrix_instr_nonkdim", [16, 32])
+@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
+@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
 @pytest.mark.parametrize("preshuffle", [True, False])
-@pytest.mark.skipif(is_cuda(), reason="AMD specific scale shuffling")
-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
-def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, matrix_instr_nonkdim, preshuffle, device):
+@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Requires hardware support for scaled mfma instructions")
+def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
     # This test primarily evaluates correctness for efficient scale packing for MFMA-scaled instructions.
     #
     # Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
@@ -638,10 +637,10 @@ def shuffle_scales_cdna4(scales: torch.Tensor):
         scales_shuffled = scales.clone()

         sm, sn = scales_shuffled.shape
-        if matrix_instr_nonkdim == 32:
+        if mfma_nonkdim == 32:
             scales_shuffled = scales_shuffled.view(sm // 32, 32, sn // 8, 4, 2, 1)
             scales_shuffled = scales_shuffled.permute(0, 2, 4, 1, 3, 5).contiguous()
-        elif matrix_instr_nonkdim == 16:
+        elif mfma_nonkdim == 16:
             scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
             scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()

@@ -701,8 +700,9 @@ def generate_gemm_afp4wfp4_inputs(M, N, K):
     w = w.T
     triton_out = torch.empty((M, N), device=x.device)

-    # name collision when passing kernel parameter with same name as "meta" parameter
-    mfma_nonkdim = matrix_instr_nonkdim
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim

     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
@@ -711,7 +711,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K):
                                                           x_scales_triton.stride(0), x_scales_triton.stride(1),
                                                           w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
                                                           BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
-                                                          num_stages=1, matrix_instr_nonkdim=matrix_instr_nonkdim)
+                                                          num_stages=1, **kernel_kwargs)
     triton_out = triton_out.to(torch.float32)
     torch.testing.assert_close(torch_out, triton_out)

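As a minimal, standalone sketch (not part of the commit), the snippet below applies the same view/permute steps as the shuffle_scales_cdna4 helper in the diff to a dummy 8-bit scale tensor, for both mfma_nonkdim values. The tensor size, the name shuffle_scales_sketch, and the element-count check are illustrative assumptions.

import torch


def shuffle_scales_sketch(scales: torch.Tensor, mfma_nonkdim: int) -> torch.Tensor:
    # Same reshuffling steps as shuffle_scales_cdna4 in the diff above.
    sm, sn = scales.shape  # sm must be divisible by 32 and sn by 8 for the views below
    if mfma_nonkdim == 32:
        out = scales.view(sm // 32, 32, sn // 8, 4, 2, 1)
        out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    elif mfma_nonkdim == 16:
        out = scales.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
        out = out.permute(0, 3, 5, 2, 4, 1, 6).contiguous()
    else:
        raise ValueError(f"unsupported mfma_nonkdim: {mfma_nonkdim}")
    return out


# One 8-bit scale per 32 operand values; 64x64 is an arbitrary shape meeting the divisibility constraints.
scales = torch.randint(0, 256, (64, 64), dtype=torch.uint8)
for nonkdim in (16, 32):
    shuffled = shuffle_scales_sketch(scales, nonkdim)
    assert shuffled.numel() == scales.numel()  # reshuffling only reorders elements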
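The launch-side change in the last two hunks can be read in isolation as the following pattern: AMD-only launch options such as matrix_instr_nonkdim are collected in a dict and splatted into the kernel launch only when running on HIP. This is a sketch, not the test's code; _copy_kernel and launch_copy are placeholder names, and torch.version.hip stands in for the test's is_hip() helper.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Placeholder kernel; the test launches _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4 instead.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


def launch_copy(src: torch.Tensor, dst: torch.Tensor, mfma_nonkdim: int = 16):
    kernel_kwargs = {}
    if torch.version.hip is not None:
        # matrix_instr_nonkdim is an AMD-specific option, which non-AMD backends may not accept.
        kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim
    grid = (triton.cdiv(src.numel(), 1024),)
    _copy_kernel[grid](src, dst, src.numel(), BLOCK=1024, **kernel_kwargs)

In the test itself the same dict is expanded into the GEMM kernel launch, as shown in the final hunk, so the call site stays identical on CUDA and HIP.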