
Commit 00e2441

support blockwise fp8 matmul kernel
1 parent 455bfe8 commit 00e2441

File tree

11 files changed: +1367 −0 lines changed

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def get_weight_shapes(args):
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): The weight shapes below only work for DeepSeek-V3. Modify them if you tune for a different model.
    # cannot TP
    total = [
        # (512 + 64, 7168), # this weight is not supported by current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only supports DeepSeek-V3
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")

sgl-kernel/setup.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def _get_version():
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "3rdparty/flashinfer/csrc/activation.cu",

sgl-kernel/src/sgl-kernel/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
     bmm_fp8,
     custom_dispose,
     custom_reduce,
+    fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
@@ -29,6 +30,7 @@
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_blockwise_scaled_mm",
     "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
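
With the export in place, a call into the new op from Python would look roughly like the sketch below. This is only an illustration derived from the benchmark in this commit (the (1, 128) / (128, 128) scale group shapes, the column-major scale layout, and the float16 output dtype come from there); the shapes and the clamp-based dummy quantization are placeholders, not a real quantization routine.

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

M, N, K = 64, 4096, 7168  # one of the DeepSeek-V3 weight shapes used by the benchmark
finfo = torch.finfo(torch.float8_e4m3fn)

# Dummy fp8 operands; a real model would quantize activations and weights blockwise.
a_fp8 = torch.randn(M, K, device="cuda").clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
b_fp8 = torch.randn(N, K, device="cuda").clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn).t()

# Scales: one per (1, 128) activation group and one per (128, 128) weight block.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(K // 128, N // 128, device="cuda", dtype=torch.float32)
# The benchmark passes the scales in a column-major layout.
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()

out = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16)  # (M, N) float16
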
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_buildler.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
