
Commit 98425c4

LucasWilkinson authored and weilong.yu committed
[Kernel] Initial Machete W4A8 support + Refactors (vllm-project#9855)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 5f076f1 commit 98425c4

28 files changed: +2616, -1694 lines

benchmarks/kernels/benchmark_machete.py

Lines changed: 385 additions & 134 deletions
Large diffs are not rendered by default.

benchmarks/kernels/graph_machete_bench.py

Lines changed: 3 additions & 2 deletions
@@ -20,10 +20,11 @@
     args = parser.parse_args()
 
     with open(args.filename, 'rb') as f:
-        data: List[TMeasurement] = pickle.load(f)
+        data = pickle.load(f)
+        raw_results: List[TMeasurement] = data["results"]
 
     results = defaultdict(lambda: list())
-    for v in data:
+    for v in raw_results:
         result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
         if result is not None:
             KN = result.group(1)
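
The change above reflects a new on-disk format for the benchmark results: the pickle now stores a dict whose "results" key holds the measurement list, rather than the bare list itself. A minimal sketch of a loader that tolerates both layouts (the filename and the dict-vs-list probe are illustrative, not part of the diff):

import pickle
from typing import List

from torch.utils.benchmark import Measurement as TMeasurement

with open("machete_bench.pkl", "rb") as f:  # hypothetical filename
    data = pickle.load(f)

# New layout: {"results": [...]}; old layout: a bare list of TMeasurement.
raw_results: List[TMeasurement] = (
    data["results"] if isinstance(data, dict) else data
)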

benchmarks/kernels/weight_shapes.py

Lines changed: 6 additions & 0 deletions
@@ -40,4 +40,10 @@
         ([8192, 57344], 1),
         ([28672, 8192], 0),
     ],
+    "meta-llama/Llama-3.1-405b-hf": [
+        ([16384, 18432], 1),
+        ([16384, 16384], 0),
+        ([16384, 106496], 1),
+        ([53248, 16384], 0),
+    ],
 }
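
Reading these entries as ([K, N], tp_split_dim), where tp_split_dim is the dimension sharded across tensor-parallel ranks (an assumption; the fields are not labeled in this excerpt), the per-rank shard shapes follow directly:

# Sketch, assuming each entry is ([K, N], tp_split_dim) with tp_split_dim
# naming the dimension divided across tensor-parallel ranks.
def per_rank_shape(kn, tp_split_dim, tp_size):
    shape = list(kn)
    assert shape[tp_split_dim] % tp_size == 0
    shape[tp_split_dim] //= tp_size
    return shape

per_rank_shape([16384, 106496], 1, 8)  # -> [16384, 13312]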

csrc/cutlass_extensions/cute_utils.cuh

Lines changed: 2 additions & 2 deletions
@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
 // is the layout f(x) = x
 template <typename Layout>
 CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
-  if constexpr (std::is_same_v<Layout, void>)
+  if constexpr (std::is_same_v<Layout, void>) {
     return true;
-  else {
+  } else {
     constexpr auto coalesced_layout = coalesce(Layout{});
     if constexpr (rank(coalesced_layout) == 1 &&
                   stride<0>(coalesced_layout) == 1) {

csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp renamed to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
 // clang-format off
 
 #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
+#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cute/tensor.hpp"
 
 namespace cutlass::epilogue::threadblock {

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp

Lines changed: 317 additions & 0 deletions
@@ -0,0 +1,317 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.

   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm::c2x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

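To make the load descriptors concrete: ColLoad/ColOrScalarLoad (stride (1,0,0)) broadcast an (m,1) vector across columns, RowLoad/RowOrScalarLoad (stride (0,1,0)) broadcast a (1,n) vector across rows, and the *OrScalar variants fall back to a single scalar when the tensor has one element (the numel() != 1 flag above). A NumPy sketch of the same broadcasting semantics (illustrative only; shapes and values are made up):

import numpy as np

m, n = 4, 6
accum = np.arange(m * n, dtype=np.float32).reshape(m, n)  # stand-in for Accum

col = np.full((m, 1), 0.5, dtype=np.float32)   # ColLoad: one value per row of D
row = np.full((1, n), 2.0, dtype=np.float32)   # RowLoad: one value per column of D
scalar = np.float32(0.5)                       # *OrScalarLoad fallback (numel == 1)

# Elementwise application with numpy-style broadcasting, as the visitors do:
out = accum * col * row
out_scalar = accum * scalar   # scalar path: same value applied everywhere
assert out.shape == (m, n)
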
/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

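The EVT tree therefore computes ScaleA * (ScaleB * Accum). That is equivalent to scaling A and B before the matmul, because per-row and per-column scales commute out of the product: diag(a_scales) (A B) diag(b_scales) = (diag(a_scales) A)(B diag(b_scales)). A NumPy sketch checking the identity (shapes and dtypes illustrative):

import numpy as np

rng = np.random.default_rng(0)
m, k, n = 8, 16, 4
a = rng.integers(-128, 128, (m, k)).astype(np.int32)   # int8-like A
b = rng.integers(-128, 128, (k, n)).astype(np.int32)   # int8-like B
a_scales = rng.random((m, 1), dtype=np.float32)        # per-row (per-token)
b_scales = rng.random((1, n), dtype=np.float32)        # per-column (per-channel)

reference = (a_scales * a) @ (b_scales * b)    # D = (a_scales * A)(b_scales * B)
accum = (a @ b).astype(np.float32)             # int32 GEMM accumulator
epilogue = a_scales * (b_scales * accum)       # ScaleA * (ScaleB * Accum)
assert np.allclose(reference, epilogue)
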
/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

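Relative to ScaledEpilogue, the only structural change is the final node: Compute1 is a multiply_add instead of a multiply, so the per-channel bias is fused into the same visitor that applies a_scales. In NumPy terms (illustrative sketch):

import numpy as np

rng = np.random.default_rng(0)
m, k, n = 8, 16, 4
a = rng.integers(-128, 128, (m, k)).astype(np.int32)
b = rng.integers(-128, 128, (k, n)).astype(np.int32)
a_scales = rng.random((m, 1), dtype=np.float32)
b_scales = rng.random((1, n), dtype=np.float32)
bias = rng.random((1, n), dtype=np.float32)    # per-output-channel bias

accum = (a @ b).astype(np.float32)
# multiply_add(ScaleA, EVTCompute0, Bias): one fused node in the EVT tree.
d = a_scales * (b_scales * accum) + bias
assert np.allclose((a_scales * a) @ (b_scales * b) + bias, d)
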
/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

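The correction is valid because subtracting a per-tensor zero point from A distributes over the matmul: (Aq - azp * J) @ B = Aq @ B - azp * (J @ B), with J the all-ones matrix. The (1,n) row azp * (J @ B) is exactly what the host precomputes and passes as azp_adj. A NumPy sketch of that identity (values illustrative):

import numpy as np

rng = np.random.default_rng(0)
m, k, n = 8, 16, 4
azp = 7                                                 # per-tensor zero point
aq = rng.integers(0, 256, (m, k)).astype(np.int32)      # asymmetric uint8-like A
b = rng.integers(-128, 128, (k, n)).astype(np.int32)

azp_adj = azp * (np.ones((1, k), dtype=np.int32) @ b)   # azp * J @ B, shape (1,n)
accum = aq @ b
corrected = accum - azp_adj                             # what EVTComputeAzp does
assert np.array_equal(corrected, (aq - azp) @ b)
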
/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

};  // namespace vllm::c2x
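
The per-token case is the same algebra with a vector of zero points: (Aq - azp * 1^T) @ B = Aq @ B - azp * (J @ B). The correction azp * azp_adj is a rank-1 outer product, so storing azp (m,1) and azp_adj (1,n) costs O(m+n) instead of the O(m*n) a materialized term would need, as the comment notes. A NumPy sketch (values illustrative):

import numpy as np

rng = np.random.default_rng(0)
m, k, n = 8, 16, 4
azp = rng.integers(0, 16, (m, 1)).astype(np.int32)     # per-token zero points
aq = rng.integers(0, 256, (m, k)).astype(np.int32)
b = rng.integers(-128, 128, (k, n)).astype(np.int32)

azp_adj = np.ones((1, k), dtype=np.int32) @ b          # J @ B, shape (1,n)
accum = aq @ b
corrected = accum - azp * azp_adj                      # rank-1 update, O(m+n) extra data
assert np.array_equal(corrected, (aq - azp) @ b)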
