Commit b10e2e3

LyrisZhong, Faqin Zhong, and mgoin authored and committed
[Perf] SM100 - add swap AB optimization to CUTLASS FP8 GEMM (vllm-project#27284)
Signed-off-by: Faqin Zhong <[email protected]>
Co-authored-by: Faqin Zhong <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent c11db07 commit b10e2e3

File tree

2 files changed: +233 −52 lines changed

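For readers new to the optimization (this paragraph and the identity below are a summary, not text from the commit): "swap AB" runs the GEMM on the transposed problem with the operands exchanged, so a kernel tiled wide along M stays well utilized even when the real M is tiny. It relies on

\[
  D = AB \quad\Longleftrightarrow\quad D^{\top} = B^{\top}A^{\top},
\]

so the kernel computes an (N, M, K) problem with B as the left operand and A as the right one, and the epilogue writes C and D through transposed layouts instead of materializing a separate transpose. This is also why, in the dispatch below, the swap-AB paths pass b_scales before a_scales and use c3x::ScaledEpilogueColumnBias: a per-row bias on D becomes a per-column bias on the transposed output.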

csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu

Lines changed: 4 additions & 5 deletions
@@ -1,6 +1,5 @@
 #include "scaled_mm_kernels.hpp"
 #include "scaled_mm_sm100_fp8_dispatch.cuh"
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 
 namespace vllm {
 
@@ -13,11 +12,11 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
   if (bias) {
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
-    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
-        out, a, b, a_scales, b_scales, *bias);
+    return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
+                                                      b_scales, *bias);
   } else {
-    return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
+                                                       b_scales);
   }
 }
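Aside (not part of the commit): the call sites above now select the epilogue with a compile-time bool instead of a template-template parameter, which lets this file drop the c3x epilogue include entirely. A minimal, self-contained C++17 sketch of that dispatch pattern, with stand-in epilogue types of my own naming:

#include <iostream>
#include <type_traits>

// Stand-ins for the two epilogue types selected by the bool flag.
struct ScaledEpilogue { static constexpr const char* name = "scale-only"; };
struct ScaledEpilogueBias { static constexpr const char* name = "scale+bias"; };

// EnableBias is resolved at compile time, so each instantiation bakes in
// exactly one concrete epilogue type, the same way the patch instantiates
// cutlass_scaled_mm_sm100_fp8_epilogue<true> / <false>.
template <bool EnableBias>
void run_epilogue() {
  using Epilogue =
      std::conditional_t<EnableBias, ScaledEpilogueBias, ScaledEpilogue>;
  std::cout << Epilogue::name << "\n";
}

int main() {
  run_epilogue<true>();   // prints "scale+bias"
  run_epilogue<false>();  // prints "scale-only"
}

The payoff is that the bool-to-epilogue mapping lives in one place inside the dispatch header, where the swap-AB configs can also substitute the column-bias variant.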

csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh

Lines changed: 229 additions & 47 deletions
@@ -2,6 +2,7 @@
 
 #include "scaled_mm.cuh"
 #include "cutlass_gemm_caller.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 
 /**
  * This file defines Gemm kernel configurations for SM100 (fp8) based on the
@@ -12,8 +13,88 @@ namespace vllm {
 
 using c3x::cutlass_gemm_caller;
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename ElementAB_, typename ElementD_,
+          template <typename, typename, typename> typename Epilogue_,
+          typename TileShape, typename ClusterShape, typename KernelSchedule,
+          typename EpilogueSchedule, bool swap_ab_ = false>
+struct cutlass_3x_gemm_sm100_fp8 {
+  using ElementAB = ElementAB_;
+  using ElementC = ElementD_;
+  using ElementD = ElementD_;
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+
+  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
+
+  using EVTCompute = typename Epilogue::EVTCompute;
+
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD =
+      128 / cutlass::sizeof_bits<ElementD>::value;
+
+  // Compile-time swap_ab flag
+  static constexpr bool swap_ab = swap_ab_;
+
+  // -----------------------------------------------------------
+  // Layout definitions
+  // -----------------------------------------------------------
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_T = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
+
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_T = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+
+  using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+
+  using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
+
+  // -----------------------------------------------------------
+  // Collective epilogue (conditionally swap operands and layouts)
+  // -----------------------------------------------------------
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
+          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAcc, float, ElementC,
+          conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, AlignmentCD,
+          ElementD, conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
+          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
+
+  static constexpr size_t CEStorageSize =
+      sizeof(typename CollectiveEpilogue::SharedStorage);
+
+  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
+      static_cast<int>(CEStorageSize)>;
+
+  // -----------------------------------------------------------
+  // Collective mainloop (conditionally swap operands and layouts)
+  // -----------------------------------------------------------
+  using CollectiveMainloop = conditional_t<
+      swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutB_T, AlignmentAB,  // Swapped B (as A)
+          ElementAB, LayoutA_T, AlignmentAB,  // Swapped A (as B)
+          ElementAcc, TileShape, ClusterShape, Stages,
+          KernelSchedule>::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc,
+          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>;
+
+  // -----------------------------------------------------------
+  // Kernel definition
+  // -----------------------------------------------------------
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+};
+
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_default {
   // M in (256, inf)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
@@ -22,12 +103,16 @@ struct sm100_fp8_config_default {
   using TileShape = Shape<_256, _128, _128>;
   using ClusterShape = Shape<_2, _2, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_M256 {
   // M in (64, 256]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
@@ -36,100 +121,197 @@ struct sm100_fp8_config_M256 {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
+template <typename InType, typename OutType, bool EnableBias>
+struct sm100_fp8_config_M64_swap_ab {
+  // This config is for M in (16, 64] and K >= 4096
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _64, _256>;
+  using ClusterShape = Shape<_4, _1, _1>;
+
+  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
+  // AB
+  using Cutlass3xGemm = conditional_t<
+      EnableBias,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
+                                ClusterShape, KernelSchedule, EpilogueSchedule,
+                                true>>;
+};
+
+template <typename InType, typename OutType, bool EnableBias>
 struct sm100_fp8_config_M64 {
-  // M in (16, 64]
+  // This config is for M = 64 and K < 4096 (do not enable swap AB in such case)
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
   using TileShape = Shape<_64, _64, _128>;
   using ClusterShape = Shape<_1, _1, _1>;
+
   using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      conditional_t<EnableBias,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>,
+                    cutlass_3x_gemm_sm100_fp8<
+                        InType, OutType, c3x::ScaledEpilogue, TileShape,
+                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm100_fp8_config_M16 {
+template <typename InType, typename OutType, bool EnableBias>
+struct sm100_fp8_config_M16_swap_ab {
   // M in [1, 16]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
-  using TileShape = Shape<_64, _64, _128>;
-  using ClusterShape = Shape<_1, _4, _1>;
-  using Cutlass3xGemm =
-      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+  using TileShape = Shape<_128, _32, _128>;
+  using ClusterShape = Shape<_4, _1, _1>;
+
+  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing swap
+  // AB
+  using Cutlass3xGemm = conditional_t<
+      EnableBias,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
+                                TileShape, ClusterShape, KernelSchedule,
+                                EpilogueSchedule, true>,
+      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
                                ClusterShape, KernelSchedule, EpilogueSchedule,
+                                true>>;
 };
 
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
+template <typename Gemm, typename... EpilogueArgs>
+void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   EpilogueArgs&&... epilogue_params) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+  using GemmKernel = typename Gemm::GemmKernel;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+  auto prob_shape =
+      swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
+
+  StrideA a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  StrideB b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  StrideC c_stride = cutlass::make_cute_packed_stride(
+      StrideC{},
+      swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+
+  typename GemmKernel::MainloopArguments mainloop_args =
+      swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr,
+                                                       a_stride}
              : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr,
+                                                       b_stride};
+
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      Gemm::Epilogue::prepare_args(
+          std::forward<EpilogueArgs>(epilogue_params)...),
+      c_ptr, c_stride, c_ptr, c_stride};
+
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename InType, typename OutType, bool EnableBias,
           typename... EpilogueArgs>
 inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                             torch::Tensor const& a,
                                             torch::Tensor const& b,
+                                            torch::Tensor const& a_scales,
+                                            torch::Tensor const& b_scales,
                                             EpilogueArgs&&... args) {
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   using Cutlass3xGemmDefault =
       typename sm100_fp8_config_default<InType, OutType,
-                                        Epilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM16 =
-      typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
+                                        EnableBias>::Cutlass3xGemm;
+  using Cutlass3xGemmM16SwapAB =
+      typename sm100_fp8_config_M16_swap_ab<InType, OutType,
+                                            EnableBias>::Cutlass3xGemm;
+  using Cutlass3xGemmM64SwapAB =
+      typename sm100_fp8_config_M64_swap_ab<InType, OutType,
+                                            EnableBias>::Cutlass3xGemm;
   using Cutlass3xGemmM64 =
-      typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+      typename sm100_fp8_config_M64<InType, OutType, EnableBias>::Cutlass3xGemm;
+
   using Cutlass3xGemmM256 =
-      typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
+      typename sm100_fp8_config_M256<InType, OutType,
+                                     EnableBias>::Cutlass3xGemm;
 
   uint32_t const m = a.size(0);
-  uint32_t const mp2 =
-      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
+  uint32_t const k = a.size(1);
 
-  if (mp2 <= 16) {
+  if (m <= 16) {
     // m in [1, 16]
-    return cutlass_gemm_caller<Cutlass3xGemmM16>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 64) {
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM16SwapAB>(
+        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
+  } else if (m <= 64) {
     // m in (16, 64]
-    return cutlass_gemm_caller<Cutlass3xGemmM64>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 256) {
+    if (m == 64 && k < 4096) {
+      // do not enable swap AB
+      return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64>(
+          out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
+    }
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64SwapAB>(
+        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
+
+  } else if (m <= 256) {
     // m in (64, 256]
-    return cutlass_gemm_caller<Cutlass3xGemmM256>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM256>(
+        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
   } else {
     // m in (256, inf)
-    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-        out, a, b, std::forward<EpilogueArgs>(args)...);
+    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmDefault>(
+        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
   }
 }
 
-template <template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
+template <bool EnableBias, typename... EpilogueArgs>
 void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
+                                          torch::Tensor const& a_scales,
+                                          torch::Tensor const& b_scales,
                                           EpilogueArgs&&... epilogue_args) {
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
-                                           cutlass::bfloat16_t, Epilogue>(
-        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
+                                           cutlass::bfloat16_t, EnableBias>(
+        out, a, b, a_scales, b_scales,
+        std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
-                                           cutlass::half_t, Epilogue>(
-        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
+                                           cutlass::half_t, EnableBias>(
+        out, a, b, a_scales, b_scales,
+        std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }
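To make the new dispatch policy easy to see at a glance, here is a small standalone C++ sketch mirroring the m/k thresholds above (the Choice struct and select_config helper are illustrative names of mine, not part of the patch):

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical mirror of the thresholds in cutlass_gemm_sm100_fp8_dispatch:
// which config family runs, and whether A/B (and their scales) are swapped.
struct Choice {
  std::string config;
  bool swap_ab;
};

Choice select_config(uint32_t m, uint32_t k) {
  if (m <= 16) return {"M16_swap_ab", true};         // m in [1, 16]
  if (m <= 64) {                                     // m in (16, 64]
    if (m == 64 && k < 4096) return {"M64", false};  // small-K exception
    return {"M64_swap_ab", true};
  }
  if (m <= 256) return {"M256", false};              // m in (64, 256]
  return {"default", false};                         // m in (256, inf)
}

int main() {
  for (uint32_t m : {8u, 32u, 64u, 128u, 512u}) {
    Choice c = select_config(m, 2048);
    std::cout << "m=" << m << " -> " << c.config
              << (c.swap_ab ? " (scales swapped)" : "") << "\n";
  }
}

Note that every branch returning a swap-AB config also swaps the scale arguments at the call site (b_scales before a_scales), matching the operand swap inside the kernel.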
