238 changes: 145 additions & 93 deletions python/test/unit/language/test_matmul.py
@@ -421,9 +421,9 @@ def block_scale_mxfp_matmul( #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_STAGES: tl.constexpr, USE_2D_SCALE_LOAD: tl.constexpr):
## This kernel assumes a_scale and b_scale are coming in with shapes
## [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimial performance
## on nvidia sm100+ HW
# This kernel assumes a_scale and b_scale are coming in with shapes
# [BLOCK_M(or N) // 128, BLOCK_K // 128, 32, 4, 4] for optimal performance
# on nvidia sm100+ HW
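# Illustration only (not part of the kernel): for BLOCK_M = 128 and BLOCK_K = 256 the scale
# tile holds BLOCK_M * BLOCK_K // 32 = 1024 e8m0 entries, and the preshuffled layout
# [128 // 128, 256 // 128, 32, 4, 4] = [1, 2, 32, 4, 4] holds 1 * 2 * 32 * 4 * 4 = 1024
# entries as well, i.e. the same scales rearranged for faster hardware access.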
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
@@ -482,18 +482,21 @@ def block_scale_mxfp_matmul( #


@triton.jit
def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
stride_asm, stride_ask, stride_bsn, stride_bsk,
# Meta-parameters
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
def _gemm_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_asm, stride_ask,
stride_bsn, stride_bsk,
# Meta-parameters
DTYPE_A: tl.constexpr, DTYPE_B: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, mfma_nonkdim: tl.constexpr,
preshuffle: tl.constexpr, fast_math: tl.constexpr):
"""Kernel for computing the matmul C = A x B.
A and B operand formats are selected via DTYPE_A and DTYPE_B (e2m1/mxfp4, e5m2/e4m3 mxfp8, fp16, or bf16).
A_scales and B_scales are in e8m0 format and may be None for non-microscaled operands.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""

PACK_FACTOR_A: tl.constexpr = 2 if DTYPE_A == "e2m1" else 1
PACK_FACTOR_B: tl.constexpr = 2 if DTYPE_B == "e2m1" else 1
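# Illustration only (follows from the fp4 packing noted in the comment below): e2m1 stores
# two 4-bit values per uint8 container, so a logical K tile of BLOCK_K elements spans
# BLOCK_K // PACK_FACTOR containers; this is why the K offsets and per-iteration pointer
# advances use BLOCK_K // PACK_FACTOR_A and BLOCK_K // PACK_FACTOR_B rather than BLOCK_K.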

pid = tl.program_id(axis=0)

num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -502,73 +505,99 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale

# We assume 32 elements along K share the same scale.
SCALE_GROUP_SIZE: tl.constexpr = 32
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // SCALE_GROUP_SIZE

if preshuffle:
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
else:
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1

num_k_iter = tl.cdiv(K, BLOCK_K // 2)
# Create pointers for the first block of the A and B input matrices.
# BLOCK sizes are in elements; fp4 operands pack 2 elements per uint8 container.
offs_k = tl.arange(0, BLOCK_K // 2)
offs_k_split = offs_k
offs_ak = tl.arange(0, BLOCK_K // PACK_FACTOR_A)
offs_bk = tl.arange(0, BLOCK_K // PACK_FACTOR_B)
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# Create pointers for the first block of A and B scales
offs_asn = (pid_n *
(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)
offs_ks = tl.arange(0, MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE)

# B scales are N x K even though B operand is K x N.
b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
offs_asm = (pid_m *
(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
if a_scales_ptr is not None:
offs_asm = (pid_m *
(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
if b_scales_ptr is not None:
offs_asn = (pid_n *
(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0,
(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k in range(0, num_k_iter):
for k in range(0, tl.cdiv(K, BLOCK_K)):
if preshuffle:
# Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
if mfma_nonkdim == 32:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
if a_scales_ptr is not None:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
MX_SCALE_BLOCK_K // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
else:
a_scales = None
if b_scales_ptr is not None:
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
MX_SCALE_BLOCK_K // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
else:
b_scales = None
elif mfma_nonkdim == 16:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
if a_scales_ptr is not None:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_M, MX_SCALE_BLOCK_K)
else:
a_scales = None
if b_scales_ptr is not None:
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_N, MX_SCALE_BLOCK_K)
else:
b_scales = None
else:
a_scales = tl.load(a_scale_ptrs)
b_scales = tl.load(b_scale_ptrs)
if a_scales_ptr is not None:
a_scales = tl.load(a_scale_ptrs)
else:
a_scales = None
if b_scales_ptr is not None:
b_scales = tl.load(b_scale_ptrs)
else:
b_scales = None

a = tl.load(a_ptrs)
b = tl.load(b_ptrs, cache_modifier=None)

accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")
accumulator += tl.dot_scaled(a, a_scales, DTYPE_A, b, b_scales, DTYPE_B, fast_math=fast_math)
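# tl.dot_scaled applies the per-32-element e8m0 scales while accumulating; as exercised
# by this test, a None scale leaves that operand unscaled (mirroring run_torch below).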

# Advance the ptrs to the next K block.
a_ptrs += (BLOCK_K // 2) * stride_ak
b_ptrs += (BLOCK_K // 2) * stride_bk
a_ptrs += (BLOCK_K // PACK_FACTOR_A) * stride_ak
b_ptrs += (BLOCK_K // PACK_FACTOR_B) * stride_bk
if preshuffle:
a_scale_ptrs += BLOCK_K * stride_ask
b_scale_ptrs += BLOCK_K * stride_bsk
if a_scales_ptr is not None:
a_scale_ptrs += BLOCK_K * stride_ask
if b_scales_ptr is not None:
b_scale_ptrs += BLOCK_K * stride_bsk
else:
a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk
if a_scales_ptr is not None:
a_scale_ptrs += MX_SCALE_BLOCK_K * stride_ask
if b_scales_ptr is not None:
b_scale_ptrs += MX_SCALE_BLOCK_K * stride_bsk
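# Note on the advance sizes above: in the preshuffled layout the 32 non-K rows are folded
# into the scale K axis (see offs_ks), so one K block consumes MX_SCALE_BLOCK_K * 32 ==
# BLOCK_K entries along stride_*sk; without preshuffling only MX_SCALE_BLOCK_K scales per
# row are consumed.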

c = accumulator.to(c_ptr.type.element_ty)

@@ -583,11 +612,14 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scale

@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
@pytest.mark.parametrize("DTYPE_A, DTYPE_B, FAST_MATH", [("mxfp4", "mxfp4", False), ("fp16", "mxfp8e5", False),
("mxfp8e4", "bf16", False), ("bf16", "mxfp4", True)])
@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
@pytest.mark.parametrize("preshuffle", [True, False])
@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B, FAST_MATH, mfma_nonkdim,
preshuffle, device):
# This test primarily evaluates the correctness of efficient scale packing for MFMA-scaled instructions.
#
# Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
@@ -637,6 +669,12 @@ def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_no
if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
pytest.skip("Minimal tile size for preshuffling is 32x32x256")

if not (DTYPE_A.startswith("mx") or DTYPE_B.startswith("mx")):
pytest.skip("Requires at least 1 microscaling operand")

if is_cuda() and (DTYPE_A == "mxfp8e4" or DTYPE_B == "mxfp8e4"):
pytest.skip("Skip fp8e4 on NV backend")

def shuffle_scales_cdna4(scales: torch.Tensor):
if not preshuffle:
return scales
@@ -665,63 +703,77 @@ def run_torch(x, w, x_scales, w_scales, dtype):
x_f32 = x.to(torch.float32)
w_f32 = w.to(torch.float32)
# Next convert the e8m0 scales to f32.
x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
x_scales_f32 = e8m0_to_f32(x_scales)
x_f32 = x_f32 * x_scales_f32
w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
w_scales_f32 = e8m0_to_f32(w_scales)
w_f32 = w_f32 * w_scales_f32
if x_scales is not None:
x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
x_scales_f32 = e8m0_to_f32(x_scales)
x_f32 = x_f32 * x_scales_f32
if w_scales is not None:
w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
w_scales_f32 = e8m0_to_f32(w_scales)
w_f32 = w_f32 * w_scales_f32
return torch.mm(x_f32, w_f32.T).to(dtype)
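# Illustration only (not called by the test): a minimal sketch of the decode that
# e8m0_to_f32 above is assumed to perform. e8m0 is an exponent-only 8-bit format,
# so a raw byte s maps to 2.0 ** (s - 127), e.g. 127 -> 1.0 and 124 -> 0.125.
#
# def _e8m0_to_f32_sketch(scales: torch.Tensor) -> torch.Tensor:
#     return torch.exp2(scales.to(torch.float32) - 127.0)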

def generate_gemm_afp4wfp4_inputs(M, N, K):
dtype_to_torch_type = {
"fp16": torch.half, "bf16": torch.bfloat16, "mxfp8e5": torch.float8_e5m2, "mxfp8e4": torch.float8_e4m3fn
}

dtype_to_triton_type = {"fp16": "fp16", "bf16": "bf16", "mxfp8e5": "e5m2", "mxfp8e4": "e4m3", "mxfp4": "e2m1"}

def generate_gemm_input(dim0, dim1, dtype):
torch.manual_seed(5)
SCALE_GROUP_SIZE = 32

x = MXFP4Tensor(size=(M, K), device="cuda").random()
w = MXFP4Tensor(size=(N, K), device="cuda").random()

x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
x_scales = x_scales.T
w_scales = w_scales.T
x_scales_shuffled = shuffle_scales_cdna4(x_scales)
w_scales_shuffled = shuffle_scales_cdna4(w_scales)

return (
x,
w,
x_scales,
w_scales,
x_scales_shuffled,
w_scales_shuffled,
)

x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)

x = x_mxfp4.to_packed_tensor(dim=1)
w = w_mxfp4.to_packed_tensor(dim=1)

torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
M, K = x.shape
N, K = w.shape
if dtype == "mxfp4":
v = MXFP4Tensor(size=(dim0, dim1), device="cuda").random()
elif dtype == "mxfp8e5":
v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
elif dtype == "mxfp8e4":
v = torch.randint(20, 40, (dim0, dim1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
elif dtype in ("fp16", "bf16"):
v = torch.randn((dim0, dim1), device=device, dtype=dtype_to_torch_type[dtype])
else:
raise ValueError(f"Unsupported data type: {dtype}")

if dtype.startswith("mx"):
scales = torch.randint(124, 128, (dim0, dim1 // SCALE_GROUP_SIZE), dtype=torch.uint8, device=device)
scales_shuffled = shuffle_scales_cdna4(scales)
else:
scales = None
scales_shuffled = None

return (v, scales, scales_shuffled)

x, x_scales, x_scales_triton = generate_gemm_input(M, K, DTYPE_A)
w, w_scales, w_scales_triton = generate_gemm_input(N, K, DTYPE_B)

torch_out = run_torch(x, w, x_scales, w_scales, torch.float32)

if DTYPE_A == "mxfp4":
x = x.to_packed_tensor(dim=1)

if DTYPE_B == "mxfp4":
w = w.to_packed_tensor(dim=1)

w = w.T
triton_out = torch.empty((M, N), device=x.device)

x_scales_strides = x_scales_triton.stride() if x_scales is not None else (None, None)
w_scales_strides = w_scales_triton.stride() if w_scales is not None else (None, None)

kernel_kwargs = {}
if is_hip():
kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim

grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
k = _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton,
w_scales_triton, M, N, K, x.stride(0), x.stride(1),
w.stride(0), w.stride(1), 0, triton_out.stride(0),
triton_out.stride(1), x_scales_triton.stride(0),
x_scales_triton.stride(1), w_scales_triton.stride(0),
w_scales_triton.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
mfma_nonkdim, preshuffle, num_warps=8, num_stages=1,
**kernel_kwargs)
k = _gemm_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
x.stride(0), x.stride(1), w.stride(0), w.stride(1),
triton_out.stride(0), triton_out.stride(1), *x_scales_strides,
*w_scales_strides, dtype_to_triton_type[DTYPE_A],
dtype_to_triton_type[DTYPE_B], BLOCK_M, BLOCK_N, BLOCK_K,
mfma_nonkdim, preshuffle, fast_math=FAST_MATH, num_warps=8,
num_stages=1, **kernel_kwargs)
triton_out = triton_out.to(torch.float32)
torch.testing.assert_close(torch_out, triton_out)
torch.testing.assert_close(torch_out, triton_out, atol=2e-5, rtol=1e-4)
if is_hip() and preshuffle:
assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
assert "ds_read_u8" not in k.asm["amdgcn"]
@@ -738,7 +790,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
NUM_STAGES = min(NUM_STAGES, 2)
elif BLOCK_K == 256:
NUM_STAGES = min(NUM_STAGES, 3)
#since the block size are big we use num_warps = 8 to avoid pressure problems.
# Since the block sizes are big, we use num_warps = 8 to avoid pressure problems.
num_warps = 8
torch.manual_seed(42)
dtype_src_str = "float8e5"