[Kernel] CUTLASS grouped gemm fp8 MoE kernel #13972
Merged: robertgshaw2-redhat merged 68 commits into vllm-project:main from neuralmagic:grouped-gemm-with-group-id on Mar 27, 2025.
Commits (68 total; changes shown from 60 commits), all authored by ElizaWszola:

1825ef8 Cutlass grouped gemm files
5fd48e5 runs, bad result
d5942cf A little closer to working
c570c69 Working for identical sizes
6ed63f2 Grouped gemm working
e2b1fc0 Small cleanup
dd163f5 Merge branch 'main' into grouped-gemm-with-group-id
acfd3ef Benchmark grouped cutlass against bfloat16 torch.mm
c6231b6 Merge branch 'main' into grouped-gemm-with-group-id
f1a5666 Start working on fused moe cutlass implementation
6414e31 Working halfway
67e2dd4 working mul test but the topk_weights are not yet included in kernel
6523529 cleaned up cutlass moe test, fixes
b302d98 benchmark fused
342d1a4 pass input as one tensor with an array of offsets rather than a list …
7549e3d Using tensors rather than tensor lists works with test_cutlass test
64c2a68 Merge branch 'main' into grouped-gemm-with-group-id
1ea7874 cleanup, add import
d608164 working fused op
286f6c8 benchmark, create strides directly on device, small name refactor
b6867bb works with cuda graphs
df04bc0 move stride tensor creation outside c++ code, cleanup
88c7134 cleanup benchmark
02e1d4e profile
1d9c429 tuned shapes, fix
b824ad2 Merge branch 'main' into grouped-gemm-with-group-id
ae90eee Performance, add channelwise scales everywhere
f191b35 name fix
22d4f7b Merge branch 'main' into grouped-gemm-with-group-id
51941ff perf improvements in data preparation
d3cf1db Integrate with deepseek v2
175ecdd cudagraphs fix
3d7a487 Merge branch 'main' into grouped-gemm-with-group-id
ec0cb94 larger index type to support very large batches
6dd6d48 update benchmarks
716d8c0 Faster data preparation kernels, bring back correct benchmark shapes
975ab5f enable cutlass grouped gemm only on sm90
e83910e Merge branch 'main' into grouped-gemm-with-group-id
89f2d1c Move arch detection to CompressedTensorsMoEMethod, cleanup, bring bac…
4d2f62f Merge branch 'main' into grouped-gemm-with-group-id
8fddd4f Fix merge, cleanup imports
583f749 fix benchmark precommit hooks
10f5a97 Various cleanups
5e85587 precommit hook fix
63f6733 Merge branch 'main' into grouped-gemm-with-group-id
8f5ac77 Post-merge fix, fallback to triton if not yet implemented features ar…
3a01616 Lots of minor feedback changes, self-commenting names
3159141 format
baa503d Decide whether to use cutlass or triton in compressed tensors method …
ed673cb Docs, remove redundant args
5287681 Changed CUDA version error message, added tp TODO to benchmark
42dc92c Add tp argument to benchmarks
53ab07a Merge branch 'main' into grouped-gemm-with-group-id
d8de3c9 Merge branch 'main' into grouped-gemm-with-group-id
83f7084 Add bfloat16 type to the kernel
be83180 Rename groups to num_experts in kernel, make group starts kernel more…
e6481c8 format
f0c2f06 format
8d0e700 format 3
84dbc2a Add hack for accepting int input in weak_ref_tensors
5ad4b0b Fixes
41eb522 format utils.py
f5b5c7d Merge branch 'main' into grouped-gemm-with-group-id
c6076b3 Make handling of both input scales consistent in the code
c8f1567 Fix handling optional vals
96296cb feedback: version checks, file structure
3977d67 Change cmake flag, remove unused code
83ee170 update kernel run conditions in scaled_mm_entry.cu
New file, 324 added lines:
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)

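# bench_run times one (model, shape) configuration: it quantizes the
# activations and per-expert weights to fp8, builds per-expert stride
# tensors, and benchmarks the Triton and CUTLASS fused-MoE paths both
# eagerly and under CUDA graph replay.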
def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    a_q, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

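    # Per-expert stride tensors for the grouped GEMMs. In the first GEMM the
    # (m, k) activations and the (2n, k) w1 slices both have a leading stride
    # of k and produce (m, 2n) outputs with stride 2n; in the second GEMM the
    # inputs and the (k, n) w2 slices have stride n and the outputs stride k.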
    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

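    # Quantize each expert's weights with its own per-tensor scale, then
    # transpose to the layout the CUTLASS grouped GEMM kernel presumably
    # expects; the Triton baseline keeps untransposed copies.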
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

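    # Eager-mode benchmark loops: run each fused-MoE implementation
    # num_repeats times back to back on the current stream.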
    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a, a_scale, w1, w2, w1_scale, w2_scale,
                            topk_weights, topk_ids, ab_strides1, c_strides1,
                            ab_strides2, c_strides2)

    def run_cutlass_from_graph(
            a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale,
                                   w2_scale, topk_weights, topk_ids,
                                   ab_strides1, c_strides1, ab_strides2,
                                   c_strides2)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

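    # Capture one call of each implementation into a CUDA graph on a
    # dedicated stream; replaying the graph times the kernels without
    # per-call Python and launch overhead.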
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1, c_strides1,
                               ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

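    # benchmark.Timer executes its stmt strings by name, so everything they
    # reference is provided through this globals dict.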
    globals = {
        # Baseline params
        "a": a,
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_q": a_q,
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
                    num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

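# main sweeps the requested models and TP sizes over the MoE layer shapes in
# WEIGHT_SHAPES_MOE, applying the optional K/N filters before benchmarking.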
def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        # Note: this sweeps DEFAULT_BATCH_SIZES directly; the
                        # --batch-sizes argument is parsed but not consulted
                        # here.
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped GEMM fp8 MoE across "
        "specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
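For reference, a typical invocation of this benchmark would look like the line below. The file name is illustrative, since the diff header above does not show the path, but the script must sit next to benchmark_shapes.py, which it imports:

python benchmark_grouped_gemm_cutlass.py --models nm-testing/deepseekv2-lite --tp-sizes 1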