#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.

   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
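/*
   A minimal sketch of that contract (illustrative only; "MyEpilogue" is a
   placeholder name, and the concrete epilogues below are the real
   definitions):

     template <typename ElementD, typename OutputTileThreadMap>
     struct MyEpilogue {
       using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<...>;
       using ArgumentType = typename EVTCompute::Arguments;
       static ArgumentType prepare_args(torch::Tensor const& ...);
     };
*/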

namespace vllm::c2x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
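  // Note on the strides above: column broadcasts (ColLoad / ColOrScalarLoad)
  // carry one value per output row, i.e. per token, while row broadcasts
  // (RowLoad / RowOrScalarLoad / RowOrZeroLoad) carry one value per output
  // column, i.e. per channel.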

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // RowOrZeroLoad would technically work here, but there is no use case
      // for it in this overload since data_ptr is never nullptr; the
      // c10::optional overload below handles that descriptor instead.
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }
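  // For example, ScaledEpilogue below calls
  //   args_from_tensor<ScaleA, float>(a_scales);
  // for a ColOrScalarLoad descriptor, which yields an Arguments struct holding
  // the scale pointer and a flag that is true when a_scales is a per-row
  // vector rather than a single scalar.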

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor,
   or per-row (for A) / per-column (for B).
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};
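// Example usage (illustrative only; OutputTileThreadMap and the surrounding
// GEMM kernel are assumed to be defined elsewhere):
//   using Epilogue = vllm::c2x::ScaledEpilogue<cutlass::half_t,
//                                              OutputTileThreadMap>;
//   auto evt_args = Epilogue::prepare_args(a_scales, b_scales);
//   // evt_args is then passed as the epilogue (EVT) argument of the kernel,
//   // producing D = a_scales * (b_scales * Accum).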

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
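// The fused expression computed by this epilogue (with numpy-style
// broadcasting, and the final multiply-add done in a single visitor node) is:
//   D = scale_a * (scale_b * Accum) + bias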
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
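// The fused expression computed by this epilogue is:
//   D = scale_a * (scale_b * float(Accum - azp_adj)) + bias
// where the subtraction is performed in int32 and then converted to float.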
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
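// The fused expression computed by this epilogue is:
//   D = scale_a * (scale_b * float(Accum - azp * azp_adj)) + bias
// where azp is the per-token (per-row) zero point of A and azp_adj = J @ B,
// so the O(m*n) correction term azp @ azp_adj is never materialized.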
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

};  // namespace vllm::c2x