22
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
56
/*
 * This file defines Gemm kernel configurations for SM100 (fp8) based on the
@@ -12,8 +13,88 @@ namespace vllm {
1213
1314using c3x::cutlass_gemm_caller;
1415
/**
 * CUTLASS 3.x GEMM definition for SM100 fp8 scaled matmul.
 *
 * Template parameters:
 *   ElementAB_       : element type of the A and B operands
 *   ElementD_        : element type of the C/D (output) operand
 *   Epilogue_        : epilogue template instantiated as
 *                      Epilogue_<ElementAcc, ElementD, TileShape>
 *   TileShape        : CTA tile shape
 *   ClusterShape     : thread-block cluster shape
 *   KernelSchedule   : mainloop schedule policy
 *   EpilogueSchedule : epilogue schedule policy
 *   swap_ab_         : when true, the mainloop is built with B fed as the
 *                      A operand and A fed as the B operand, and C/D use
 *                      transposed layouts. The launcher must then also swap
 *                      the problem shape and operand pointers (see
 *                      cutlass_gemm_caller_sm100_fp8).
 */
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, bool swap_ab_ = false>
struct cutlass_3x_gemm_sm100_fp8 {
  using ElementAB = ElementAB_;
  using ElementC = ElementD_;
  using ElementD = ElementD_;
  // int8 inputs accumulate in int32; everything else (fp8 here) accumulates
  // in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  using EVTCompute = typename Epilogue::EVTCompute;

  // Alignments in elements for 128-bit (16-byte) vectorized accesses.
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  // Compile-time swap_ab flag (re-exported so the launcher can read it).
  static constexpr bool swap_ab = swap_ab_;

  // -----------------------------------------------------------
  // Layout definitions
  // -----------------------------------------------------------
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutA_T = typename cutlass::layout::LayoutTranspose<LayoutA>::type;

  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutB_T = typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using LayoutD = cutlass::layout::RowMajor;
  using LayoutD_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutD>::type;

  using LayoutC = LayoutD;
  using LayoutC_Transpose = LayoutD_Transpose;

  // -----------------------------------------------------------
  // Collective epilogue (conditionally swap operands and layouts)
  // -----------------------------------------------------------
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC,
          conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, AlignmentCD,
          ElementD, conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Carve the epilogue's shared-memory footprint out of the mainloop's
  // stage budget.
  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);

  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // -----------------------------------------------------------
  // Collective mainloop (conditionally swap operands and layouts)
  // -----------------------------------------------------------
  using CollectiveMainloop = conditional_t<
      swap_ab,
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutB_T, AlignmentAB,             // Swapped B (as A)
          ElementAB, LayoutA_T, AlignmentAB,  // Swapped A (as B)
          ElementAcc, TileShape, ClusterShape, Stages,
          KernelSchedule>::CollectiveOp,
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentAB, ElementAB, LayoutB, AlignmentAB, ElementAcc,
          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>;

  // -----------------------------------------------------------
  // Kernel definition
  // -----------------------------------------------------------
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
96+
// Default kernel configuration, used for M in (256, inf).
template <typename InType, typename OutType, bool EnableBias>
struct sm100_fp8_config_default {
  // M in (256, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  // NOTE(review): these two schedule aliases were hidden in the diff context;
  // assumed Auto to match the sibling configs — confirm against the full file.
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_256, _128, _128>;
  using ClusterShape = Shape<_2, _2, _1>;
  // Select the bias-fused epilogue only when a bias tensor is supplied.
  using Cutlass3xGemm =
      conditional_t<EnableBias,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogue, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
};
28114
// Kernel configuration for M in (64, 256].
template <typename InType, typename OutType, bool EnableBias>
struct sm100_fp8_config_M256 {
  // M in (64, 256]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  // NOTE(review): these two schedule aliases were hidden in the diff context;
  // assumed Auto to match the sibling configs — confirm against the full file.
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  // Select the bias-fused epilogue only when a bias tensor is supplied.
  using Cutlass3xGemm =
      conditional_t<EnableBias,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogue, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
};
42132
// Kernel configuration for M in (16, 64] with K >= 4096, built with the
// A/B operands swapped (swap_ab_ = true).
template <typename InType, typename OutType, bool EnableBias>
struct sm100_fp8_config_M64_swap_ab {
  // This config is for M in (16, 64] and K >= 4096
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_128, _64, _256>;
  using ClusterShape = Shape<_4, _1, _1>;

  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing
  // swap AB
  using Cutlass3xGemm = conditional_t<
      EnableBias,
      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
                                TileShape, ClusterShape, KernelSchedule,
                                EpilogueSchedule, true>,
      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
                                ClusterShape, KernelSchedule, EpilogueSchedule,
                                true>>;
};
153+
// Kernel configuration for M = 64 with K < 4096 (swap AB is NOT enabled in
// this case; see the dispatch logic in cutlass_gemm_sm100_fp8_dispatch).
template <typename InType, typename OutType, bool EnableBias>
struct sm100_fp8_config_M64 {
  // This config is for M = 64 and K < 4096 (do not enable swap AB in such case)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _1, _1>;

  // Select the bias-fused epilogue only when a bias tensor is supplied.
  using Cutlass3xGemm =
      conditional_t<EnableBias,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogueBias, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>,
                    cutlass_3x_gemm_sm100_fp8<
                        InType, OutType, c3x::ScaledEpilogue, TileShape,
                        ClusterShape, KernelSchedule, EpilogueSchedule>>;
};
56172
// Kernel configuration for M in [1, 16], built with the A/B operands swapped
// (swap_ab_ = true).
template <typename InType, typename OutType, bool EnableBias>
struct sm100_fp8_config_M16_swap_ab {
  // M in [1, 16]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_128, _32, _128>;
  using ClusterShape = Shape<_4, _1, _1>;

  // Use ScaledEpilogueColumnBias instead of ScaledEpilogueBias when doing
  // swap AB
  using Cutlass3xGemm = conditional_t<
      EnableBias,
      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogueColumnBias,
                                TileShape, ClusterShape, KernelSchedule,
                                EpilogueSchedule, true>,
      cutlass_3x_gemm_sm100_fp8<InType, OutType, c3x::ScaledEpilogue, TileShape,
                                ClusterShape, KernelSchedule, EpilogueSchedule,
                                true>>;
};
70193
/**
 * Builds the CUTLASS problem shape, strides and argument structs for one
 * SM100 fp8 GEMM and hands them to the shared c3x launcher.
 *
 * When Gemm::swap_ab is true, the problem is launched as (n, m, k) with the
 * B pointer/stride in the A slot and vice versa; the output strides are
 * likewise built for an (n, m) problem. The caller is responsible for
 * passing epilogue args (scales/bias) in the matching swapped order.
 *
 * out : (m, n) output tensor; a : (m, k) row-major; b : (k, n) column-major.
 */
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   EpilogueArgs&&... epilogue_params) {
  static constexpr bool swap_ab = Gemm::swap_ab;
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
  // Swapped kernels compute the (n, m) problem.
  auto prob_shape =
      swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);

  // Strides are always built from the tensors' true extents; only the
  // MainloopArguments slot they occupy changes under swap_ab.
  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC c_stride = cutlass::make_cute_packed_stride(
      StrideC{},
      swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename GemmKernel::MainloopArguments mainloop_args =
      swap_ab ? typename GemmKernel::MainloopArguments{b_ptr, b_stride, a_ptr,
                                                       a_stride}
              : typename GemmKernel::MainloopArguments{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  // C and D alias the same buffer: the epilogue reads and writes `out`.
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}
237+
/**
 * Shape-based dispatch for SM100 fp8 scaled matmul: picks a tile/cluster
 * configuration from M (and K at the M<=64 boundary) and launches it.
 *
 * For swap-AB configurations the scale tensors are deliberately forwarded
 * in swapped order (b_scales before a_scales) to match the swapped operand
 * roles inside cutlass_gemm_caller_sm100_fp8.
 */
template <typename InType, typename OutType, bool EnableBias,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm100_fp8_config_default<InType, OutType,
                                        EnableBias>::Cutlass3xGemm;
  using Cutlass3xGemmM16SwapAB =
      typename sm100_fp8_config_M16_swap_ab<InType, OutType,
                                            EnableBias>::Cutlass3xGemm;
  using Cutlass3xGemmM64SwapAB =
      typename sm100_fp8_config_M64_swap_ab<InType, OutType,
                                            EnableBias>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm100_fp8_config_M64<InType, OutType, EnableBias>::Cutlass3xGemm;

  using Cutlass3xGemmM256 =
      typename sm100_fp8_config_M256<InType, OutType,
                                     EnableBias>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const k = a.size(1);

  if (m <= 16) {
    // m in [1, 16]
    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM16SwapAB>(
        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);
  } else if (m <= 64) {
    // m in (16, 64]
    if (m == 64 && k < 4096) {
      // do not enable swap AB
      return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64>(
          out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
    }
    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM64SwapAB>(
        out, a, b, b_scales, a_scales, std::forward<EpilogueArgs>(args)...);

  } else if (m <= 256) {
    // m in (64, 256]
    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmM256>(
        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (256, inf)
    return cutlass_gemm_caller_sm100_fp8<Cutlass3xGemmDefault>(
        out, a, b, a_scales, b_scales, std::forward<EpilogueArgs>(args)...);
  }
}
114293
/**
 * Entry point for the SM100 fp8 scaled-mm path: selects the output element
 * type (bf16 or fp16) from `out`'s dtype and forwards everything to the
 * shape-based dispatch above.
 *
 * Raises (TORCH_CHECK) if a/b are not fp8 e4m3 or out is not bf16/fp16.
 */
template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
                                           cutlass::bfloat16_t, EnableBias>(
        out, a, b, a_scales, b_scales,
        std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
                                           cutlass::half_t, EnableBias>(
        out, a, b, a_scales, b_scales,
        std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
135317
0 commit comments