242 changes: 242 additions & 0 deletions python/test/unit/language/test_matmul.py
@@ -475,6 +475,248 @@ def block_scale_mxfp_matmul( #
tl.store(output_ptrs, accumulator, mask=c_mask)


@triton.jit
def _gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4(a_ptr, b_ptr, c_ptr, a_scales_ptr, b_scales_ptr, M, N, K, stride_am,
stride_ak, stride_bk, stride_bn, stride_ck, stride_cm, stride_cn,
stride_asm, stride_ask, stride_bsn, stride_bsk,
# Meta-parameters
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
mfma_nonkdim: tl.constexpr, preshuffle: tl.constexpr):
"""Kernel for computing the matmul C = A x B.
A and B inputs are in the microscale fp4 (mxfp4) format.
A_scales and B_scales are in e8m0 format.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""

pid = tl.program_id(axis=0)

num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

# We assume 32 elements along K share the same scale.
SCALE_GROUP_SIZE: tl.constexpr = 32

if preshuffle:
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
else:
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 1

num_k_iter = tl.cdiv(K, BLOCK_K // 2)
# Create pointers for the first block of the A and B input matrices.
# The BLOCK sizes are given in elements; in fp4 we pack 2 elements per uint8 container.
offs_k = tl.arange(0, BLOCK_K // 2)
offs_k_split = offs_k
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# Create pointers for the first block of A and B scales
offs_asn = (pid_n *
(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE))) % N
offs_ks = tl.arange(0, BLOCK_K // SCALE_GROUP_SIZE * NON_K_PRESHUFFLE_BLOCK_SIZE)

# B scales are N x K even though B operand is K x N.
b_scale_ptrs = (b_scales_ptr + offs_asn[:, None] * stride_bsn + offs_ks[None, :] * stride_bsk)
offs_asm = (pid_m *
(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE) + tl.arange(0, (BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE))) % M
a_scale_ptrs = (a_scales_ptr + offs_asm[:, None] * stride_asm + offs_ks[None, :] * stride_ask)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k in range(0, num_k_iter):
if preshuffle:
# Here we "undo" the shuffle done in global memory (shuffle_scales_cdna4 function).
if mfma_nonkdim == 32:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 2, 32, 4,
1).permute(0, 3, 1, 4, 2,
5).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
elif mfma_nonkdim == 16:
a_scales = tl.load(a_scale_ptrs).reshape(BLOCK_M // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_M, BLOCK_K // SCALE_GROUP_SIZE)
b_scales = tl.load(b_scale_ptrs).reshape(BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE,
BLOCK_K // SCALE_GROUP_SIZE // 8, 4, 16, 2, 2,
1).permute(0, 5, 3, 1, 4, 2,
6).reshape(BLOCK_N, BLOCK_K // SCALE_GROUP_SIZE)
else:
a_scales = tl.load(a_scale_ptrs)
b_scales = tl.load(b_scale_ptrs)

a = tl.load(a_ptrs)
b = tl.load(b_ptrs, cache_modifier=None)

accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")

# Advance the ptrs to the next K block.
a_ptrs += (BLOCK_K // 2) * stride_ak
b_ptrs += (BLOCK_K // 2) * stride_bk
if preshuffle:
a_scale_ptrs += BLOCK_K * stride_ask
b_scale_ptrs += BLOCK_K * stride_bsk
else:
a_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_ask
b_scale_ptrs += (BLOCK_K // SCALE_GROUP_SIZE) * stride_bsk

c = accumulator.to(c_ptr.type.element_ty)

# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)
c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")


@pytest.mark.parametrize("M, N, K", [(1024, 1024, 1024)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 256), (64, 64, 512), [32, 32, 64]])
@pytest.mark.parametrize("mfma_nonkdim", [16, 32])
@pytest.mark.parametrize("preshuffle", [True, False])
@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 10, reason="Compilation bug for GB200.")
@pytest.mark.skipif(is_hip() and not is_hip_cdna4(), reason="Scaled dot is not emulated on other archs yet.")
def test_preshuffle_scale_mxfp_cdna4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, device):
# This test primarily checks the correctness of the efficient scale packing used for scaled MFMA instructions.
#
# Scales are stored as 8-bit tensors, where each element scales 32 values from the A or B operand tensors.
# Since MFMA instructions are wave-level instructions, each thread provides a fixed set of operand values to the MFMA instruction.
#
# For example, in an MFMA instruction with shape 16x16x128:
# - 4 threads contribute elements along the K dimension.
# - 16 threads contribute elements along the M or N dimension.
#
# From the perspective of the scales tensor, even if the K dimension is stored contiguously in LDS,
# each thread sees its elements along the K dim as strided due to interleaving with other threads.
# This striding limits the ability to load scale values using vectorized memory access.
#
# Our goal is to reorganize the scale tensor so that:
# 1. Each thread stores the 4 scale values it needs for 4 MFMA ops in contiguous memory.
# 2. Continuous threads access contiguous memory locations improving global memory coalescing when bypassing LDS,
# which is especially beneficial for "skinny" matmuls.
#
# We consider two MFMA cases: one with non-K dimension 16, and one with 32.
# In both, the minimum tile size for preshuffling is 32x32x256.
# For example, for a 32x256 operand tile, the corresponding scale tensor has shape 32x8,
# where each scale covers 32 elements along the K dimension.
#
# Each thread holds one scale per MFMA operation. We pack the 4 scale values (for 4 different MFMA ops)
# next to each other in memory.
#
# Case 1: mfma_scaled_16x16x128
#
# Packing order: mfma_op_0, mfma_op_2, mfma_op_1, mfma_op_3
#
# K = 128 K = 128
# +------------+ +------------+
# M=16| MFMA op 0 | | MFMA op 1 |
# +------------+ +------------+
# M=16| MFMA op 2 | | MFMA op 3 |
# +------------+ +------------+
#
# Case 2: mfma_scaled_32x32x64
#
# Packing order: mfma_op_0, mfma_op_1, mfma_op_2, mfma_op_3
#
# K=64 K=64 K=64 K=64
# +--------+ +--------+ +--------+ +--------+
# M=32| op 0 | | op 1 | | op 2 | | op 3 |
# +--------+ +--------+ +--------+ +--------+

if preshuffle and (BLOCK_M < 32 or BLOCK_N < 32 or BLOCK_K < 256):
pytest.skip("Minimal tile size for preshuffling is 32x32x256")

def shuffle_scales_cdna4(scales: torch.Tensor):
if not preshuffle:
return scales

scales_shuffled = scales.clone()

sm, sn = scales_shuffled.shape
if mfma_nonkdim == 32:
scales_shuffled = scales_shuffled.view(sm // 32, 32, sn // 8, 4, 2, 1)
scales_shuffled = scales_shuffled.permute(0, 2, 4, 1, 3, 5).contiguous()
elif mfma_nonkdim == 16:
scales_shuffled = scales_shuffled.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
scales_shuffled = scales_shuffled.permute(0, 3, 5, 2, 4, 1, 6).contiguous()

scales_shuffled = scales_shuffled.view(sm // 32, sn * 32)
return scales_shuffled
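For illustration, a standalone sketch of what the mfma_nonkdim == 32 branch above does to a single 32x8 e8m0 scale tile (the tile covers a 32x256 mxfp4 operand tile; the tile values and the row index below are arbitrary):

import torch

tile = torch.arange(32 * 8, dtype=torch.uint8).reshape(32, 8)  # (non-K, K // 32)
# Same view/permute as the mfma_nonkdim == 32 branch of shuffle_scales_cdna4.
flat = tile.view(1, 32, 1, 4, 2, 1).permute(0, 2, 4, 1, 3, 5).contiguous().view(1, 256)
m = 5  # arbitrary row
# Four of row m's scale bytes now sit in adjacent memory, which is the
# contiguity the packing described in the comment above is designed to provide.
assert torch.equal(flat[0, 4 * m:4 * m + 4], tile[m, 0:8:2])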

def e8m0_to_f32(x):
x_f32 = 2**((x - 127).to(torch.float32))
x_f32[x_f32 == 128] = float("nan")
return x_f32
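A quick check of the decode above, with the scales already cast to float32 as run_torch does below: the bias is 127, so 124 decodes to 0.125, 127 to 1.0, and 130 to 8.0.

vals = e8m0_to_f32(torch.tensor([124.0, 127.0, 130.0]))
# tensor([0.1250, 1.0000, 8.0000])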

def run_torch(x, w, x_scales, w_scales, dtype):
# First convert the x and w inputs to f32.
SCALE_GROUP_SIZE = 32
x_f32 = x.to(torch.float32)
w_f32 = w.to(torch.float32)
# Next convert the e8m0 scales to f32.
x_scales = x_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
x_scales_f32 = e8m0_to_f32(x_scales)
x_f32 = x_f32 * x_scales_f32
w_scales = w_scales.repeat_interleave(SCALE_GROUP_SIZE, dim=1).to(torch.float32)
w_scales_f32 = e8m0_to_f32(w_scales)
w_f32 = w_f32 * w_scales_f32
return torch.mm(x_f32, w_f32.T).to(dtype)

def generate_gemm_afp4wfp4_inputs(M, N, K):
torch.manual_seed(5)
SCALE_GROUP_SIZE = 32

x = MXFP4Tensor(size=(M, K), device="cuda").random()
w = MXFP4Tensor(size=(N, K), device="cuda").random()

x_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, M), dtype=torch.uint8, device="cuda")
w_scales = torch.randint(124, 128, (K // SCALE_GROUP_SIZE, N), dtype=torch.uint8, device="cuda")
x_scales = x_scales.T
w_scales = w_scales.T
x_scales_shuffled = shuffle_scales_cdna4(x_scales)
w_scales_shuffled = shuffle_scales_cdna4(w_scales)

return (
x,
w,
x_scales,
w_scales,
x_scales_shuffled,
w_scales_shuffled,
)

x_mxfp4, w_mxfp4, x_scales, w_scales, x_scales_triton, w_scales_triton = generate_gemm_afp4wfp4_inputs(M, N, K)

x = x_mxfp4.to_packed_tensor(dim=1)
w = w_mxfp4.to_packed_tensor(dim=1)

torch_out = run_torch(x_mxfp4, w_mxfp4, x_scales, w_scales, torch.float32)
M, K = x.shape
N, K = w.shape
w = w.T
triton_out = torch.empty((M, N), device=x.device)

kernel_kwargs = {}
if is_hip():
kernel_kwargs["matrix_instr_nonkdim"] = mfma_nonkdim

grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
_gemm_afp4_wfp4_kernel_preshuffled_scales_cdna4[grid](x, w, triton_out, x_scales_triton, w_scales_triton, M, N, K,
x.stride(0), x.stride(1), w.stride(0), w.stride(1), 0,
triton_out.stride(0), triton_out.stride(1),
x_scales_triton.stride(0), x_scales_triton.stride(1),
w_scales_triton.stride(0), w_scales_triton.stride(1), BLOCK_M,
BLOCK_N, BLOCK_K, mfma_nonkdim, preshuffle, num_warps=8,
num_stages=1, **kernel_kwargs)
triton_out = triton_out.to(torch.float32)
torch.testing.assert_close(torch_out, triton_out)


@pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
(128, 128, 256), (128, 256, 256)])
68 changes: 56 additions & 12 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -439,8 +439,15 @@ struct DotOpMFMAConversionHelper {
results = b.zext(i32_ty, b.bitcast(vec, i8_ty));
}
}

if (2 == kBase)
// This case can occur during scale tensor packing when there aren't
// enough elements to fill all 4 opSel slots. For example, with an A
// tensor of size 16x256 and using 16x16x128 block sizes, we end up with
// only 2 elements to pack, resulting in a kBase of 2.
results = b.zext(i32_ty, b.bitcast(vec, i16_ty));
Collaborator:

Can you add some comments to explain this case?
Also, the comments at lines 429 and 446 need to be updated.

if (4 == kBase)
// This is for int8 on pre-CDNA3 GPUs
// This is for int8 on pre-CDNA3 GPUs and scale tensors on CDNA4 GPUs
results = b.bitcast(vec, i32_ty);
if (8 == kBase)
results = b.bitcast(vec, i64_ty);
@@ -465,6 +472,11 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
auto elems = unpackLLElements(loc, value, rewriter);
// number of kBase-element vectors
int numVecInKBase = kRepInKWidth * kWidth / kBase;
if (numVecInKBase == 0) {
numVecInKBase = 1;
nonKRep /= kBase / (kRepInKWidth * kWidth);
assert(nonKRep > 0 && "nonKrep too small");
Collaborator:

Do we still need this assert?
This if can only happen for scales, and scale's kBase is bounded by numRepK * numRepM.

}
ValueTable dotOpVals;

SmallVector<int64_t> strides =
@@ -544,17 +556,19 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {

Value generateScaledMFMAOp(StringRef intrinsicName, Value valA, Value valB,
Value valC, Value valScaleA, Value valScaleB,
Type elemTypeA, Type elemTypeB) const {
Type elemTypeA, Type elemTypeB, int opSelA,
int opSelB) const {
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto resType = valC.getType();
Value zeroFlag = b.i32_val(0);
Value valOpSelA = b.i32_val(opSelA);
Value valOpSelB = b.i32_val(opSelB);
OperationState loweredOp(loc, intrinsicName);
int32_t cbsz = getMfmaF8F6F4MatrixFormat(elemTypeA);
int32_t blgp = getMfmaF8F6F4MatrixFormat(elemTypeB);
assert((cbsz != -1) && (blgp != -1));
loweredOp.addTypes(resType);
loweredOp.addOperands({valA, valB, valC, b.i32_val(cbsz), b.i32_val(blgp),
zeroFlag, valScaleA, zeroFlag, valScaleB});
valOpSelA, valScaleA, valOpSelB, valScaleB});
return rewriter.create(loweredOp)->getResult(0);
}

@@ -636,8 +650,6 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
// better way to get it when adapting other data types. Similar to
// scaleKBase
constexpr int scaleKWidth = 1;
constexpr int scaleKBase = 1;

Value loadedA = adaptor.getA();
Value loadedB = adaptor.getB();
Value loadedAScale = adaptor.getAScale();
@@ -650,6 +662,27 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
auto numRepB = repA[0];
assert(repA[0] == repB[0]);

// Scaled MFMA instructions expect scale operands as 32-bit values,
// even though each individual scale is only 8 bits. To reduce register
// usage, we pack 4 scales into a single 32-bit value and use the opSel
// field to select the appropriate byte during execution. Packing is done
// along the K dimension first; if there aren’t enough values in K, we
// continue along the non-K dimension.
// TODO: Support opSel selection for constant scales stored in SGPRs.
const int scaleAKBase =
isAScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK * numRepM));
const int scaleBKBase =
isBScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK * numRepN));

int akPackedVals =
isAScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK));
int bkPackedVals =
isBScaleConstant ? 1 : std::min(4, static_cast<int>(numRepK));

assert(scaleAKBase % akPackedVals == 0 && scaleBKBase % bkPackedVals == 0);
int aNonKPackedVals = scaleAKBase / akPackedVals;
int bNonKPackedVals = scaleBKBase / bkPackedVals;
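// Worked example (illustrative values): with numRepK = 2 and numRepM = 4,
// akPackedVals = 2 and scaleAKBase = 4, so aNonKPackedVals = 2. The four A
// scales at (m, k), (m, k + 1), (m + 1, k), (m + 1, k + 1) for even m and k
// then share one 32-bit value, and opSelA = (m * numRepK + k) % 4 selects
// the byte in the MFMA loop below.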

auto operandA = getValuesFromDotOperandLayoutStruct(
loadedA, numRepB, numRepM, numRepK, aKWidth, aKBase,
aTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
@@ -664,13 +697,13 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
if (existBothScales) {
auto aScaleTensorTy = cast<RankedTensorType>(aScale.getType());
operandAScale = getValuesFromDotOperandLayoutStruct(
loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase,
loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleAKBase,
aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
isAScaleConstant);

auto bScaleTensorTy = cast<RankedTensorType>(bScale.getType());
operandBScale = getValuesFromDotOperandLayoutStruct(
loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase,
loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleBKBase,
bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false,
isBScaleConstant);
}
@@ -731,18 +764,29 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
for (innerK = 0; innerK < innerKBound; innerK++) {
int k = is2Step ? outerK : innerK;
if (existBothScales) {
int akScale = k / akPackedVals;
int bkScale = k / bkPackedVals;
int opSelA = 0, opSelB = 0;

int mScale = m / aNonKPackedVals;
int nScale = n / bNonKPackedVals;
opSelA = (m * numRepK + k) % (aNonKPackedVals * akPackedVals);
opSelB = (n * numRepK + k) % (bNonKPackedVals * bkPackedVals);

if (mfmaLayout.getIsTransposed()) {
acc = generateScaledMFMAOp(
intrinsicName, operandB[{b, n, k}], operandA[{b, m, k}],
acc, operandBScale[{b, n, k}], operandAScale[{b, m, k}],
acc, operandBScale[{b, nScale, bkScale}],
operandAScale[{b, mScale, akScale}],
maybeMfmaIntrinsic->bElementType,
maybeMfmaIntrinsic->aElementType);
maybeMfmaIntrinsic->aElementType, opSelB, opSelA);
} else {
acc = generateScaledMFMAOp(
intrinsicName, operandA[{b, m, k}], operandB[{b, n, k}],
acc, operandAScale[{b, m, k}], operandBScale[{b, n, k}],
acc, operandAScale[{b, mScale, akScale}],
operandBScale[{b, nScale, bkScale}],
maybeMfmaIntrinsic->aElementType,
maybeMfmaIntrinsic->bElementType);
maybeMfmaIntrinsic->bElementType, opSelA, opSelB);
}
} else {
if (mfmaLayout.getIsTransposed()) {