Skip to content

Commit 515ef4f

Browse files
authored
Fuse routed scaling factor in topk_reduce kernel (#6220)
1 parent f5599ef commit 515ef4f

File tree

10 files changed

+331
-9
lines changed

10 files changed

+331
-9
lines changed
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
from triton.testing import do_bench
5+
6+
7+
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    # Sum `input` (token_num, topk_num, hidden_dim) over the topk axis,
    # multiply by routed_scaling_factor, and write the (token_num, hidden_dim)
    # result to `output` — i.e. the routed scaling is fused into the reduction.
    #
    # Grid: axis 0 tiles tokens in chunks of BLOCK_M, axis 1 tiles the hidden
    # dimension in chunks of BLOCK_DIM.

    # Widen strides to int64 so pointer arithmetic cannot overflow 32 bits
    # for large tensors.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    # Half-open [token_start, token_end) range handled by this program;
    # the min() clamps the last block at the tensor edge.
    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        # Accumulate in fp32 regardless of the input dtype for precision.
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        # NUM_STAGE controls software pipelining of the loads over experts.
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            # Mask clips lanes past dim_end (only relevant in the last
            # hidden-dim block, where dim_end == hidden_dim).
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        # Fused routed scaling (constexpr, so folded at compile time).
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        # Cast back to the input element type on store.
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )
55+
56+
57+
def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Reduce ``input`` over its topk axis into ``output``, scaled in-kernel.

    ``input`` is (token_num, topk_num, hidden_dim); ``output`` is
    (token_num, hidden_dim). Both must be contiguous. The multiplication by
    ``routed_scaling_factor`` is fused into the Triton reduction kernel.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    num_tokens, num_topk, dim = input.shape
    assert output.shape[0] == num_tokens and output.shape[1] == dim

    # Fixed launch configuration: one token row per program on grid axis 0,
    # 2048-wide tiles over the hidden dimension on axis 1.
    block_m = 1
    block_dim = 2048
    num_stage = 1
    warps = 8

    launch_grid = (
        triton.cdiv(num_tokens, block_m),
        triton.cdiv(dim, block_dim),
    )

    _moe_sum_reduce_kernel[launch_grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=num_tokens,
        topk_num=num_topk,
        hidden_dim=dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=block_m,
        BLOCK_DIM=block_dim,
        NUM_STAGE=num_stage,
        num_warps=warps,
    )
    return
91+
92+
93+
def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    """Eager reference: sum ``x`` over dim 1 into ``out``, then scale in place."""
    torch.sum(x, dim=1, out=out)
    out *= routed_scaling_factor
    return out
99+
100+
101+
@torch.compile
102+
def compute_sum_scaled_compiled(
103+
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
104+
) -> torch.Tensor:
105+
torch.sum(x * routed_scaling_factor, dim=1, out=out)
106+
return out
107+
108+
109+
def get_benchmark():
    """Build a ``triton.testing.perf_report`` benchmark comparing the three
    implementations (eager baseline, torch.compile, Triton kernel) across
    token counts 1..4096 on CUDA. Returns the decorated benchmark callable.
    """
    token_counts = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=token_counts,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Bind each version label to a zero-arg runner over the shared tensors;
        # any label other than baseline/compiled falls through to the Triton path,
        # matching the original if/elif/else.
        runners = {
            "baseline": lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
            "compiled": lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
        }
        run = runners.get(version, lambda: moe_sum_reduce(x, out, scaling_factor))

        # Warmup
        for _ in range(3):
            run()

        # Benchmark
        ms, min_ms, max_ms = do_bench(run, quantiles=[0.5, 0.2, 0.8])

        # Convert ms -> us; do_bench returns (median, low, high) quantiles.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark
163+
164+
165+
def verify_correctness(num_tokens=1024):
    """Run all three implementations on the same random bf16 CUDA input and
    print whether their outputs agree within a loose bf16 tolerance."""
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    # Loose tolerances: bf16 accumulation order differs between paths.
    compiled_ok = torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
    triton_ok = torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
    if compiled_ok and triton_ok:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
188+
189+
190+
if __name__ == "__main__":
    # Check correctness first, then run the perf comparison.
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    bench = get_benchmark()
    bench.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )

0 commit comments

Comments
 (0)