Skip to content
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
1825ef8
Cutlass grouped gemm files
ElizaWszola Dec 6, 2024
5fd48e5
runs, bad result
ElizaWszola Dec 9, 2024
d5942cf
A little closer to working
ElizaWszola Dec 10, 2024
c570c69
Working for identical sizes
ElizaWszola Dec 11, 2024
6ed63f2
Grouped gemm working
ElizaWszola Dec 17, 2024
e2b1fc0
Small cleanup
ElizaWszola Dec 17, 2024
dd163f5
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 8, 2025
acfd3ef
Benchmark grouped cutlass against bfloat16 torch.mm
ElizaWszola Jan 13, 2025
c6231b6
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 13, 2025
f1a5666
Start working on fused moe cutlass implementation
ElizaWszola Jan 17, 2025
6414e31
Working halfway
ElizaWszola Jan 20, 2025
67e2dd4
working mul test but the topk_weights are not yet included in kernel
ElizaWszola Jan 23, 2025
6523529
cleaned up cutlass moe test, fixes
ElizaWszola Jan 23, 2025
b302d98
benchmark fused
ElizaWszola Jan 23, 2025
342d1a4
pass input as one tensor with an array of offsets rather than a list …
ElizaWszola Jan 24, 2025
7549e3d
Using tensors rather than tensor lists works with test_cutlass test
ElizaWszola Jan 28, 2025
64c2a68
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 28, 2025
1ea7874
cleanup, add import
ElizaWszola Jan 28, 2025
d608164
working fused op
ElizaWszola Jan 29, 2025
286f6c8
benchmark, create strides directly on device, small name refactor
ElizaWszola Jan 29, 2025
b6867bb
works with cuda graphs
ElizaWszola Jan 31, 2025
df04bc0
move stride tensor creation outside c++ code, cleanup
ElizaWszola Jan 31, 2025
88c7134
cleanup benchmark
ElizaWszola Jan 31, 2025
02e1d4e
profile
ElizaWszola Feb 4, 2025
1d9c429
tuned shapes, fix
ElizaWszola Feb 14, 2025
b824ad2
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 14, 2025
ae90eee
Performance, add channelwise scales everywhere
ElizaWszola Feb 18, 2025
f191b35
name fix
ElizaWszola Feb 20, 2025
22d4f7b
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 20, 2025
51941ff
perf improvements in data preparation
ElizaWszola Feb 20, 2025
d3cf1db
Integrate with deepseek v2
ElizaWszola Feb 24, 2025
175ecdd
cudagraphs fix
ElizaWszola Feb 24, 2025
3d7a487
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 25, 2025
ec0cb94
larger index type to support very large batches
ElizaWszola Feb 25, 2025
6dd6d48
update benchmarks
ElizaWszola Feb 25, 2025
716d8c0
Faster data preparation kernels, bring back correct benchmark shapes
ElizaWszola Feb 27, 2025
975ab5f
enable cutlass grouped gemm only on sm90
ElizaWszola Feb 28, 2025
e83910e
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 4, 2025
89f2d1c
Move arch detection to CompressedTensorsMoEMethod, cleanup, bring bac…
ElizaWszola Mar 5, 2025
4d2f62f
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 5, 2025
8fddd4f
Fix merge, cleanup imports
ElizaWszola Mar 5, 2025
583f749
fix benchmark precommit hooks
ElizaWszola Mar 5, 2025
10f5a97
Various cleanups
ElizaWszola Mar 5, 2025
5e85587
precommit hook fix
ElizaWszola Mar 5, 2025
63f6733
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 12, 2025
8f5ac77
Post-merge fix, fallback to triton if not yet implemented features ar…
ElizaWszola Mar 12, 2025
3a01616
Lots of minor feedback changes, self-commenting names
ElizaWszola Mar 17, 2025
3159141
format
ElizaWszola Mar 17, 2025
baa503d
Decide whether to use cutlass or triton in compressed tensors method …
ElizaWszola Mar 17, 2025
ed673cb
Docs, remove redundant args
ElizaWszola Mar 18, 2025
5287681
Changed CUDA version error message, added tp TODO to benchmark
ElizaWszola Mar 18, 2025
42dc92c
Add tp argument to benchmarks
ElizaWszola Mar 18, 2025
53ab07a
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 18, 2025
d8de3c9
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 18, 2025
83f7084
Add bfloat16 type to the kernel
ElizaWszola Mar 18, 2025
be83180
Rename groups to num_experts in kernel, make group starts kernel more…
ElizaWszola Mar 19, 2025
e6481c8
format
ElizaWszola Mar 19, 2025
f0c2f06
format
ElizaWszola Mar 19, 2025
8d0e700
format 3
ElizaWszola Mar 19, 2025
84dbc2a
Add hack for accepting int input in weak_ref_tensors
ElizaWszola Mar 21, 2025
5ad4b0b
Fixes
ElizaWszola Mar 24, 2025
41eb522
format utils.py
ElizaWszola Mar 24, 2025
f5b5c7d
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 24, 2025
c6076b3
Make handling of both input scales consistent in the code
ElizaWszola Mar 26, 2025
c8f1567
Fix handling optional vals
ElizaWszola Mar 26, 2025
96296cb
feedback: version checks, file structure
ElizaWszola Mar 26, 2025
3977d67
Change cmake flag, remove unused code
ElizaWszola Mar 26, 2025
83ee170
update kernel run conditions in scaled_mm_entry.cu
ElizaWszola Mar 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
Expand Down
324 changes: 324 additions & 0 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
fused_experts,
fused_topk)
from vllm.utils import FlexibleArgumentParser

# Models benchmarked by default; each must have an entry in WEIGHT_SHAPES_MOE.
DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
# Batch sizes (the m dimension) swept for every model/shape combination.
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]

# Quantization-granularity sweeps are currently pinned to the per-tensor
# (False) case only; re-enable the commented alternatives to sweep both.
PER_ACT_TOKEN_OPTS = [False]  # [False, True]
PER_OUT_CH_OPTS = [False]  # [False, True]


def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    """Benchmark fused-MoE kernels for one (model, shape) configuration.

    Appends four Measurements to ``results``: the triton ``fused_experts``
    path, the CUTLASS grouped-gemm path (``cutlass_moe``), and CUDA-graph
    replays of both.

    Args:
        results: accumulator that timing Measurements are appended to.
        model: model name, used only for labeling the output.
        num_experts: number of experts (gemm groups) in the MoE layer.
        topk: number of experts selected per token.
        per_act_token: recorded in the label only — the quantization below is
            always per-tensor for activations.  TODO confirm intended sweep.
        per_out_ch: recorded in the label only (see note above).
        mkn: (m, k, n) problem sizes — presumably batch, hidden and
            intermediate dims; verify against WEIGHT_SHAPES_MOE.
    """
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    # Random activations and expert weights; the /10 keeps magnitudes small
    # so fp8 quantization below does not saturate.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    # Per-tensor fp8 quantization of the activations.
    a_q, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    # One scalar scale per expert (per-tensor weight scales).
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    # Per-expert stride tensors consumed by the CUTLASS grouped gemm
    # (one int64 entry per expert), created on-device up front so CUDA-graph
    # capture sees stable pointers.
    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

    # Quantize each expert's weights with its own per-tensor scale.
    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    # The triton path consumes the original layout; the CUTLASS path takes
    # the (1, 2)-transposed weights — keep an untransposed copy for triton.
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        # Repeat inside the timed region to amortize launch overhead.
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        m: int, n: int, k: int, num_experts: int,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe(a, a_scale, w1, w2, w1_scale, w2_scale, topk_weights,
                        topk_ids, m, n, k, num_experts, ab_strides1,
                        c_strides1, ab_strides2, c_strides2)

    def run_cutlass_from_graph(
            a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int,
            k: int, e: int, ab_strides1: torch.Tensor,
            c_strides1: torch.Tensor, ab_strides2: torch.Tensor,
            c_strides2: torch.Tensor):
        # A current vllm config is required during CUDA-graph capture —
        # presumably for op dispatch; verify against vllm.config docs.
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, m, n, k, e, ab_strides1,
                               c_strides1, ab_strides2, c_strides2)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        # Replays a captured CUDA graph num_repeats times, then syncs once.
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    # Capture each kernel path into its own CUDA graph on a side stream.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, m, n, k, num_experts,
                               ab_strides1, c_strides1, ab_strides2,
                               c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    # Namespace handed to benchmark.Timer — every name used in a stmt string
    # below must appear here.
    globals = {
        # Baseline params
        "a": a,
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_q": a_q,
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "m": m,
        "n": n,
        "k": k,
        "num_experts": num_experts,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, m, n, k, num_experts, ab_strides1, c_strides1,
                    ab_strides2, c_strides2, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    """Run the benchmark sweep over every requested model and shape.

    For each model listed in ``args.models``, iterates its layer shapes from
    ``WEIGHT_SHAPES_MOE`` (filtered by ``--limit-k`` / ``--limit-n``), then
    sweeps the quantization options and batch sizes, calling ``bench_run``
    for each combination and printing a comparison table at the end.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for layer in WEIGHT_SHAPES_MOE[model]:
            num_experts = layer[0]
            topk = layer[1]
            size_k = layer[2]
            size_n = layer[3]

            # Empty limit lists mean "no filtering".
            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for per_act_token in PER_ACT_TOKEN_OPTS:
                for per_out_ch in PER_OUT_CH_OPTS:
                    # Fix: honor the parsed --batch-sizes argument instead of
                    # always using DEFAULT_BATCH_SIZES (its default), which
                    # silently ignored the flag.
                    for size_m in args.batch_sizes:
                        mkn = (size_m, size_k, size_n)
                        bench_run(results, model, num_experts, topk,
                                  per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    # Fix: the description said "Benchmark Marlin" — a copy-paste leftover
    # from the Marlin benchmark; this script benchmarks the CUTLASS
    # grouped-gemm MoE kernels against the triton fused-MoE path.
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped GEMM MoE across specified "
        "models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    # Empty defaults mean "no filtering" in main().
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A --tp-sizes argument with the same behavior as in benchmarks/cutlass_benchmarks/w8a8_benchmarks.py would be very nice to have, especially to compare and contrast performance of the kernel in the EP vs the TP case.

16 changes: 16 additions & 0 deletions benchmarks/kernels/benchmark_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,19 @@
[7168, 8192],
],
}

# MoE layer shapes per model. Each entry is a list of
# [num_experts, topk, size_k, size_n] rows, matching how
# benchmarks/kernels/benchmark_grouped_gemm_cutlass.py unpacks them
# (layer[0]=num_experts, layer[1]=topk, layer[2]=k, layer[3]=n).
WEIGHT_SHAPES_MOE = {
    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
        [8, 2, 4096, 28672],
        [8, 2, 14336, 4096],
    ],
    "nm-testing/deepseekv2-lite": [
        [64, 6, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 8, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 8, 1024, 1536],
    ],
}
Loading