
Commit 1c3b5ea

Add streamK for block-quantized CUTLASS kernels
Signed-off-by: leoneo <[email protected]>
1 parent 24700c3 commit 1c3b5ea

2 files changed: 106 additions, 5 deletions

csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

Lines changed: 32 additions & 0 deletions

@@ -53,6 +53,38 @@ void cutlass_gemm_caller(torch::Device device,
   CUTLASS_CHECK(status);
 }
 
+template <typename GemmKernel>
+void cutlass_gemm_caller_streamK(torch::Device device,
+                                 cute::Shape<int, int, int, int> prob_shape,
+                                 typename GemmKernel::MainloopArguments mainloop_args,
+                                 typename GemmKernel::EpilogueArguments epilogue_args) {
+  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
+                                      prob_shape, mainloop_args, epilogue_args};
+
+  // add args for StreamK
+  using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+  using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+  args.scheduler.decomposition_mode = DecompositionMode::StreamK;
+  args.scheduler.reduction_mode = ReductionMode::Nondeterministic;
+
+  // Launch the CUTLASS GEMM kernel.
+  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  GemmOp gemm_op;
+  CUTLASS_CHECK(gemm_op.can_implement(args));
+
+  size_t workspace_size = gemm_op.get_workspace_size(args);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(device);
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  auto stream = at::cuda::getCurrentCUDAStream(device.index());
+
+  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
+  CUTLASS_CHECK(status);
+}
+
 template <typename Gemm, typename... EpilogueArgs>
 void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                          torch::Tensor const& b,
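Side note (not part of the commit): CUTLASS's SM90 stream-K tile scheduler also exposes a deterministic reduction mode, trading speed for bitwise-reproducible output. A minimal sketch of how the same scheduler arguments could be toggled, assuming the PersistentTileSchedulerSm90StreamKParams type used above also provides ReductionMode::Deterministic; the helper name set_stream_k_modes is illustrative only:

template <typename GemmKernel>
void set_stream_k_modes(typename GemmKernel::Arguments& args,
                        bool deterministic) {
  // Same scheduler parameter type as in the hunk above (assumed available).
  using Params =
      cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams;
  // Request stream-K decomposition of the K loop across CTAs.
  args.scheduler.decomposition_mode = Params::DecompositionMode::StreamK;
  // Deterministic serializes the cross-CTA accumulation (reproducible but
  // typically slower); the commit chooses Nondeterministic for speed.
  args.scheduler.reduction_mode =
      deterministic ? Params::ReductionMode::Deterministic
                    : Params::ReductionMode::Nondeterministic;
}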

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

Lines changed: 74 additions & 5 deletions

@@ -22,7 +22,7 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
+template <typename SchedulerType, typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
           int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;

@@ -84,7 +84,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};

@@ -154,15 +154,84 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                     epilogue_args);
 }
 
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  auto prob_shape = c3x::get_problem_shape(a, b);
+  int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
+          k = get<2>(prob_shape);
+
+  int64_t lda = a.stride(0);
+  int64_t ldb = b.stride(1);
+  int64_t ldc = out.stride(0);
+
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  using StrideB = Stride<int64_t, Int<1>, int64_t>;
+  using StrideC = typename Gemm::StrideC;
+
+  StrideA a_stride{lda, Int<1>{}, 0};
+  StrideB b_stride{ldb, Int<1>{}, 0};
+  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  // Check whether t is contiguous and is 1D, or 2D with one of the dimensions
+  // being 1 (i.e. a row or column vector).
+  auto is_contiguous_vector = [](const torch::Tensor& t) {
+    auto t_sizes = t.sizes();
+    return t.is_contiguous() &&
+           (t.dim() == 1 ||
+            (t.dim() == 2 &&
+             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
+  };
+
+  // TODO(lucas): let's clean up the kernel so that we pass in strides and
+  // don't have to deal with enforcing implicit layouts.
+  TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
+  TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
+  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
+              "a_scales must be M major");
+  TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
+  TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
+  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
+              "b_scales must be K major");
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
+
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+
+  c3x::cutlass_gemm_caller_streamK<GemmKernel>(a.device(), prob_shape,
+                                               mainloop_args, epilogue_args);
+}
+
 template <typename OutType>
 void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a_scales.size(1);
+  auto n = b_scales.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise_streamK<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 } // namespace vllm
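For clarity, the shape-based dispatch above can be read as a standalone heuristic: per the TORCH_CHECKs, a_scales.size(1) equals K/GroupSizeK and b_scales.size(1) equals N/GroupSizeN, so with both group sizes set to 128 the condition is simply K > 3N. Presumably the intent is to enable stream-K only for reduction-heavy shapes, where splitting the K loop across SMs can improve utilization. A self-contained sketch with a hypothetical helper name prefer_stream_k (not part of the commit):

#include <cstdint>

// Hypothetical helper: returns true when the stream-K path is chosen, using
// the same k > 3 * n threshold as the dispatch above. Inputs are in units of
// quantization groups; since GroupSizeK == GroupSizeN == 128 here, the
// comparison is equivalent to K > 3 * N in actual dimensions.
inline bool prefer_stream_k(int64_t k_groups, int64_t n_groups) {
  return k_groups > 3 * n_groups;
}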
