
Commit 2b29c3d

[triton_kernels] decouple split-k reduction from inter-expert reductions in matmul (#8483)
1 parent 3c2e6f8 commit 2b29c3d

11 files changed, +262 −398 lines


python/triton_kernels/bench/distributed.py

Lines changed: 2 additions & 1 deletion
@@ -277,7 +277,8 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
 
     # precision configs
     pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
-    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
+    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
+                          (1.0, 1.0))
     pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
     pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
     if rank == 0:
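For reference, the hunk above reflects the API change made throughout this commit: the output-reduction factor moves out of the FusedActivation constructor and into FnSpecs. A minimal before/after sketch, using only names that appear in this diff:

    # before: reduction factor passed as a trailing positional argument to FusedActivation
    act = FusedActivation(
        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
        (1.0, 1.0),  # fn_args: (alpha, limit)
        2,           # reduction factor
    )

    # after: the factor is carried by the function spec itself
    act = FusedActivation(
        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
        (1.0, 1.0),  # fn_args: (alpha, limit)
    )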

python/triton_kernels/tests/test_matmul.py

Lines changed: 4 additions & 4 deletions
@@ -130,8 +130,8 @@ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode
         ) if weight_use_flexpoint else InFlexData(),
         out_data=OutFlexData(
             dtype=out_dtype,
-            expected_scale=make(4.00, 5.00, mode == "batched" or expt_is_inner),
-            actual_scale=make(0, 0, mode == "batched" or expt_is_inner),
+            expected_scale=make_scalar(4.00),
+            actual_scale=make_scalar(0),
             checksum_scale=None,
         ) if act_use_flexpoint else OutFlexData(),
     )
@@ -776,8 +776,8 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
                               precision_config=SwiGLUPrecisionConfig(swiglu_limit))
        b = matmul_ogs(
            x, w, bias, rdata, gindx, sindx, precision_opt,
-            fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
-                                             (swiglu_alpha, swiglu_limit), 2))
+            fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
+                                             (swiglu_alpha, swiglu_limit)))
    except opt_flags.InapplicableConstraint:
        pytest.skip("inapplicable constraint")
 
python/triton_kernels/tests/test_reduce.py

Lines changed: 9 additions & 6 deletions
@@ -5,6 +5,7 @@
 from triton_kernels.numerics_details.mxfp import upcast_from_mxfp_torch, downcast_to_mxfp_torch
 from triton_kernels.numerics import InFlexData, OutFlexData
 import triton
+import triton.language as tl
 
 
 def init_mask(mask_mode, B, M, N, device):
@@ -30,8 +31,9 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
 
 
 @triton.jit
-def plus_a(x, a):
-    return x + a
+def plus_a_reduce(x, a):
+    y = x + a
+    return tl.sum(y.reshape([x.shape[0], x.shape[1] // 2, 2]), axis=2)
 
 
 @pytest.mark.parametrize("B, M, N, postprocess_fn", [
@@ -84,14 +86,15 @@ def test_op(B, M, N, dtype_str, dim, mask_mode, postprocess_fn):
         reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale)
         return
     if postprocess_fn == "plus_ten":
-        postprocess_fn_tri = PostprocessFn(specs=FnSpecs("plus_a", plus_a, ("a", )), fn_args=(10, ))
-        postprocess_fn_ref = lambda x: x + 10
+        postprocess_fn_tri = PostprocessFn(specs=FnSpecs("plus_a", plus_a_reduce, ("a", ), reduction_n=2),
+                                           fn_args=(10, ))
+        postprocess_fn_ref = lambda x: (x + 10).reshape([x.shape[0], x.shape[1] // 2, 2]).sum(dim=2)
     else:
         postprocess_fn_tri = postprocess_fn_ref = None
     y_tri, y_tri_mxscale = reduce(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_tri,
-                                  postprocess_fn=postprocess_fn_tri)
+                                  postprocess_fn1=postprocess_fn_tri)
     y_ref, y_ref_mxscale = reduce_torch(x, dim=dim, mask=mask, x_mxscale=x_mscale, x_flex=x_flex, y_flex=y_flex_ref,
-                                        postprocess_fn=postprocess_fn_ref)
+                                        postprocess_fn1=postprocess_fn_ref)
     if is_mx:
         y_ref = upcast_from_mxfp_torch(y_ref, y_ref_mxscale, torch.float16, axis=-1)
         y_tri = upcast_from_mxfp_torch(y_tri, y_tri_mxscale, torch.float16, axis=-1)
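The reference lambda above spells out what reduction_n=2 means for the postprocess hook: the elementwise function is applied first, then adjacent pairs along the last dimension are summed, halving N. A standalone PyTorch sketch of that reference behaviour, assuming a 2D input as in the test's non-batched case:

    import torch

    def plus_a_reduce_ref(x: torch.Tensor, a: float) -> torch.Tensor:
        # elementwise op followed by a pairwise sum over the last dimension,
        # mirroring postprocess_fn_ref in the test above
        y = x + a
        return y.reshape([y.shape[0], y.shape[1] // 2, 2]).sum(dim=2)

    x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    print(plus_a_reduce_ref(x, 10.0).shape)  # torch.Size([2, 2])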

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 97 additions & 185 deletions
Large diffs are not rendered by default.

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 1 addition & 7 deletions
@@ -238,18 +238,12 @@ def matmul_launch_metadata(grid, kernel, args):
     fM = M if M is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * K * (1 if expt_is_inner else batch_size)
 
-    dst = args.get("GatherDstIndx", None)
     # sindx = args.get("WriteBackIndx", None)
     n_x_bytes = X.numel() * X.element_size()
     n_y_bytes = Y.numel() * Y.element_size()
     if hist is not None:
         assert n_tokens is not None
-        n_expts_act = args["N_EXPTS_ACT"]
-
-        if (dst is not None) and launch_metadata_allow_sync():
-            n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
-        else:
-            n_read_rows = n_tokens
+        n_read_rows = n_tokens
 
         if expt_is_inner:
             n_x_bytes = n_read_rows * X.shape[-2] * X.element_size()

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 5 additions & 4 deletions
@@ -71,7 +71,7 @@ def _matmul_ogs(
             # epilogue transform
             EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
             # MoE config
-            N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+            N_EXPTS_TOT: tl.constexpr,
             # precision config
             MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
             FLEXPOINT_SATURATE_INF: tl.constexpr,
@@ -81,6 +81,7 @@ def _matmul_ogs(
             # optimization config
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+            INIT_OUTPUT_TO_ZERO: tl.constexpr,
             # One of ["HOPPER", "BLACKWELL", None]
             SWIZZLE_MX_VALUE: tl.constexpr,
             # One of ["HOPPER", "BLACKWELL", None]
@@ -198,7 +199,7 @@ def _matmul_ogs(
     # We are tiling Y here, so the tiling is independent of matmul (where we
     # tile X & W and scatter to different rows of Y).
     # TODO: refactor (same code in _p_matmul_ogs)
-    if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
+    if HAS_FUSED_SCATTER and INIT_OUTPUT_TO_ZERO:
         tl.device_assert(batch_size == 1)
     pid_mnk = pid
     if XCD_SWIZZLE != 1:
@@ -241,7 +242,7 @@ def _matmul_ogs(
         else:
             GatherIndx += start_m
             # no needs to bounds-check here because `offs_x_m` wraps around M dim
-            offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
+            offs_x_m = tl.load(GatherIndx + offs_x_m)
         offs_k = off_k_x + tl.arange(0, BLOCK_K)
         XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
 
@@ -455,7 +456,7 @@ def _matmul_ogs(
                 YActualScale += start_m * stride_y_mx_m
                 YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
             else:
-                YActualScalePtrs = YActualScale + (offs_y_m - num_idxs // N_EXPTS_ACT).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+                YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
             tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
         else:
             if PER_BATCH_OUT_SCALE:
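A note on the gather-index change above: the kernel previously divided each loaded GatherIndx entry by N_EXPTS_ACT to recover the source row, i.e. the indices encoded one entry per (token, expert-slot) pair; after this commit the kernel consumes the entries as row indices directly, so any such division has to happen before the kernel is launched. A purely illustrative, hypothetical sketch of that convention shift (toy values, not code from this commit):

    import torch

    # hypothetical example: 3 tokens, 2 active experts per token (n_expts_act = 2)
    n_tokens, n_expts_act = 3, 2
    slot_indices = torch.arange(n_tokens * n_expts_act)  # one entry per (token, slot)
    # old convention: kernel recovered rows via `slot_indices // n_expts_act` at load time
    # new convention: row indices are materialized up front and loaded as-is
    row_indices = slot_indices // n_expts_act
    print(row_indices.tolist())  # [0, 0, 1, 1, 2, 2]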

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 5 additions & 5 deletions
@@ -80,7 +80,7 @@ def _p_matmul_ogs(
             # epilogue transform
             EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
             # MoE config
-            N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+            N_EXPTS_TOT: tl.constexpr,
             # precision config
             MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
             FLEXPOINT_SATURATE_INF: tl.constexpr,
@@ -90,6 +90,7 @@ def _p_matmul_ogs(
             # optimization config
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
+            INIT_OUTPUT_TO_ZERO: tl.constexpr,
             # NYI: Must be None
             SWIZZLE_MX_VALUE: tl.constexpr,
             # One of ["BLACKWELL", None]
@@ -172,7 +173,7 @@ def _p_matmul_ogs(
     yN = N // ACTIVATION_REDUCTION_N
 
     # set masked out rows to 0
-    if HAS_SCATTER and N_EXPTS_ACT == 1:
+    if HAS_SCATTER and INIT_OUTPUT_TO_ZERO:
         # Iterate with reversed pids so that later pids will get more tiles if the number of
         # tiles isn't evenly divisible by the number of SMs.
         # The main loop after this iterates in the forward direction such that earlier
@@ -233,15 +234,14 @@ def _p_matmul_ogs(
             offs_x_m += start_z * (stride_x_z // stride_x_m)
             offs_x_m = tl.where(mask_m, offs_x_m, -1)
         else:
-            offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
-                               mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
+            offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m, other=-1)
     elif X_TMA_MODE is None or is_x_microscaled:
         offs_m = off_m + tl.arange(0, BLOCK_M)
         offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
         # no needs to bounds-check here because `offs_m` wraps around M dim
         if GatherIndx is not None:
             tl.static_assert(HAS_GATHER)
-            offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
+            offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m)
         offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
 
     if is_x_microscaled:

python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py

Lines changed: 0 additions & 102 deletions
This file was deleted.

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 0 additions & 4 deletions
@@ -28,10 +28,6 @@ class OptFlags:
     arch: str
     target_kernel_kwargs: dict
 
-    def __post_init__(self):
-        if self.fused_scatter and self.split_k != 1:
-            raise ValueError("Not supported")
-
 
 def max_allowable_mn(
     max_mn: int,
