@@ -22,7 +22,7 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
+template <typename SchedulerType, typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
           int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
@@ -84,7 +84,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
 
@@ -154,15 +154,84 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
       epilogue_args);
 }
 
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  auto prob_shape = c3x::get_problem_shape(a, b);
+  int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
+          k = get<2>(prob_shape);
+
+  int64_t lda = a.stride(0);
+  int64_t ldb = b.stride(1);
+  int64_t ldc = out.stride(0);
+
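+  // Note: the middle (K) mode of each stride below is a static 1, i.e. A and
+  // B are laid out K-major, and the trailing batch stride is 0 (single batch).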
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  using StrideB = Stride<int64_t, Int<1>, int64_t>;
+  using StrideC = typename Gemm::StrideC;
+
+  StrideA a_stride{lda, Int<1>{}, 0};
+  StrideB b_stride{ldb, Int<1>{}, 0};
+  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  // Check whether t is contiguous and is 1D, or 2D with one of the dimensions
+  // being 1 (i.e. a row or column vector)
+  auto is_contiguous_vector = [](const torch::Tensor& t) {
+    auto t_sizes = t.sizes();
+    return t.is_contiguous() &&
+           (t.dim() == 1 ||
+            (t.dim() == 2 &&
+             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
+  };
+
+  // TODO(lucas): let's clean up the kernel so that we pass in Strides so
+  // we don't have to deal with enforcing implicit layouts
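+  // Shape expectations: a_scales holds one float per GroupSizeM x GroupSizeK
+  // block of A, and b_scales one float per GroupSizeK x GroupSizeN block of B.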
+  TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
+  TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
+  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
+              "a_scales must be M major");
+  TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
+  TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
+  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
+              "b_scales must be K major");
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
+
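+  // The leading {} default-constructs the epilogue's fusion arguments; C and
+  // D both alias the output tensor, so the epilogue writes `out` in place.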
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+
+  c3x::cutlass_gemm_caller_streamK<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                               epilogue_args);
+}
+
 template <typename OutType>
 void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a_scales.size(1);
+  auto n = b_scales.size(1);
+
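+  // Heuristic (a reading of this change, not stated in the source): k and n
+  // here are scale-group counts proportional to K and N, so when the reduction
+  // dimension K greatly exceeds N, stream-K can split the K loop across SMs
+  // and is preferred over the persistent tile scheduler.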
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise_streamK<
+        cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<
+        cutlass_3x_gemm_fp8_blockwise<cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(out, a, b, a_scales, b_scales);
+  }
 }
 
 }  // namespace vllm