Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,20 @@
from sglang.srt.layers.moe.topk import TopKOutput

if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sgl_kernel import scaled_fp4_quant

try:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be great to have a env var to allow user to control which to use, thus if one gets some bug we have another

from flashinfer import mm_fp4 as fp4_gemm

enable_flashinfer_fp4_gemm = True
except ImportError:
if is_cuda():
from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
else:
fp4_gemm = None
enable_flashinfer_fp4_gemm = False

try:
from flashinfer import fp4_quantize as fp4_quantize
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
flashinfer_cutlass_fused_moe = None
Expand Down Expand Up @@ -683,11 +693,16 @@ def apply(
assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32

out = cutlass_scaled_fp4_mm(
w = layer.weight
w_scale_interleaved = layer.weight_scale_interleaved
if enable_flashinfer_fp4_gemm:
w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T
out = fp4_gemm(
x_fp4,
layer.weight,
w,
x_scale_interleaved,
layer.weight_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
)
Expand Down
210 changes: 210 additions & 0 deletions sgl-kernel/benchmark/bench_fp4_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import argparse
import copy
import csv
import itertools

import pytest
import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def get_weight_shapes(args):
models_tps = args.tp_sizes

if models_tps == [4]:
return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]

if models_tps == [8]:
return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
return [
[1024, 3584],
[7168, 256],
[7168, 2304],
[9216, 3584],
[512, 3584],
[7168, 128],
[7168, 1152],
[4608, 3584],
]


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
],
# x_vals = [64],
x_log=False,
line_arg="provider",
line_vals=["cutlass", "cudnn", "trtllm"],
line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
ylabel="latency (ms)",
plot_name="fp4_gemm_benchmark",
args={},
)
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
M = batch_size
packed_k = K
K = 2 * packed_k
a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)

alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
# print("a_fp4", a_fp4)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

quantiles = [0.5, 0.2, 0.8]
if provider == "cutlass":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
),
quantiles=quantiles,
)
if provider == "cudnn":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
),
quantiles=quantiles,
)
if provider == "trtllm":
a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="trtllm",
),
quantiles=quantiles,
)
if correctness:
res_cutlass = cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="cudnn",
)
assert torch.allclose(
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
), "cudnn fp4 doesn't match cutlass fp4"
mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="trtllm",
)
assert torch.allclose(
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
), "trtllm fp4 doesn't match cutlass fp4"

if csv_file:
with open(csv_file, "a", newline="") as f:
writer = csv.writer(f)
writer.writerow([provider, M, N, K, ms])

return ms, min_ms, max_ms


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
parser.add_argument(
"--dtype",
type=torch.dtype,
default=torch.bfloat16,
help="Data type",
)
parser.add_argument(
"--correctness",
action="store_true",
help="Check correctness",
)
parser.add_argument(
"--csv",
type=str,
default="results_cutlass_cudnn.csv",
help="CSV file to save results",
)
args = parser.parse_args()

if args.csv:
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"])

NKs = get_weight_shapes(args)
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp4_res",
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)

print("Benchmark finished!")
Loading