
Commit d503e7d

ElizaWszola authored and LucasWilkinson committed
[Kernel] CUTLASS grouped gemm fp8 MoE kernel (vllm-project#13972)
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent cc52a9d commit d503e7d

22 files changed: +2317 −15 lines changed


CMakeLists.txt

Lines changed: 27 additions & 0 deletions

@@ -461,6 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()
 
+  #
+  # CUTLASS MoE kernels
+
+  # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
+  # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
+  # to compile MoE kernels that use its output.
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
+             "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
+    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
+                     "if you intend on running FP8 quantized MoE models on Hopper.")
+    else()
+      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
   #
   # Machete kernels
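The new block only compiles the grouped GEMM when the CUDA toolkit is at least 12.3 and a Hopper (9.0a) target is in the architecture list. As a rough illustration of the same requirements from the Python side (this is not vLLM's own dispatch logic; could_use_cutlass_moe_sm90 is a hypothetical helper, and torch.version.cuda reports the toolkit PyTorch was built with, which may differ from the one used to build these kernels):

import torch

def could_use_cutlass_moe_sm90() -> bool:
    # Mirrors the CMake gate above: Hopper (compute capability 9.0) and CUDA >= 12.3.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    cap = torch.cuda.get_device_capability()
    cuda = tuple(int(v) for v in torch.version.cuda.split(".")[:2])
    return cap == (9, 0) and cuda >= (12, 3)

print("CUTLASS MoE SM90 prerequisites met:", could_use_cutlass_moe_sm90())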

Lines changed: 340 additions & 0 deletions

@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

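# bench_run (below) builds random activations and per-expert weights for one
# (m, k, n) shape, quantizes the weights to fp8 with a single scale per expert,
# and times four variants: the Triton fused_experts path and the CUTLASS
# grouped-GEMM path, each run eagerly and via CUDA graph replay.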
def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)
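    # The *_notransp clones keep the row-major layout used by the Triton
    # fused_experts baseline, while w1_q / w2_q are transposed for the CUTLASS
    # grouped GEMM; ab_strides*/c_strides* above hold the per-expert strides
    # (k, 2*n, n, k) that are passed to cutlass_moe_fp8.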

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a,
                            w1,
                            w2,
                            w1_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            ab_strides1,
                            c_strides1,
                            ab_strides2,
                            c_strides2,
                            a1_scale=a_scale)

    def run_cutlass_from_graph(
            a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a,
                                   w1_q,
                                   w2_q,
                                   w1_scale,
                                   w2_scale,
                                   topk_weights,
                                   topk_ids,
                                   ab_strides1,
                                   c_strides1,
                                   ab_strides2,
                                   c_strides2,
                                   a1_scale=a_scale)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1, c_strides1,
                               ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()
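    # Each CUDA graph captures one full forward of its implementation;
    # replay_graph re-launches the captured work, so the *_cuda_graphs timings
    # below largely exclude per-call Python and kernel-launch overhead.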

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }
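    # The stmt strings below are evaluated by torch.utils.benchmark.Timer with
    # this dict as their namespace; blocked_autorange() keeps sampling until at
    # least min_run_time seconds of measurements have been collected.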

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
                    num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
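For reference, the benchmark can also be driven programmatically instead of through its CLI flags. A minimal sketch, assuming the definitions above (main, WEIGHT_SHAPES_MOE, etc.) are in scope and a compatible GPU and vLLM build are available; main() only reads the models, tp_sizes, limit_k and limit_n fields:

from argparse import Namespace

args = Namespace(
    models=["nm-testing/deepseekv2-lite"],  # any key of WEIGHT_SHAPES_MOE
    tp_sizes=[1],
    limit_k=[],  # empty list means "benchmark every K"
    limit_n=[],  # empty list means "benchmark every N"
)
main(args)  # prints a benchmark.Compare table over all four variants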

benchmarks/kernels/benchmark_shapes.py

Lines changed: 16 additions & 0 deletions

@@ -75,3 +75,19 @@
         [7168, 8192],
     ],
 }
+
+WEIGHT_SHAPES_MOE = {
+    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
+        [8, 2, 4096, 28672],
+        [8, 2, 14336, 4096],
+    ],
+    "nm-testing/deepseekv2-lite": [
+        [64, 6, 2048, 1408],
+    ],
+    "ibm-granite/granite-3.0-1b-a400m": [
+        [32, 8, 1024, 1024],
+    ],
+    "ibm-granite/granite-3.0-3b-a800m": [
+        [40, 8, 1024, 1536],
+    ],
+}
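Each entry is consumed by the MoE benchmark above as [num_experts, topk, size_k, size_n], with size_n divided by the tensor-parallel size. A small sketch of that decoding (mirroring main() in the benchmark; the loop below is illustrative only):

from benchmark_shapes import WEIGHT_SHAPES_MOE

tp = 1
for layer in WEIGHT_SHAPES_MOE["nm-testing/deepseekv2-lite"]:
    num_experts, topk = layer[0], layer[1]
    size_k, size_n = layer[2], layer[3] // tp
    for size_m in (1, 16, 64):  # a few batch sizes
        print(num_experts, topk, (size_m, size_k, size_n))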
