@@ -22,8 +22,9 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
-          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
+template <typename SchedulerType, typename OutType, int GroupSizeM_,
+          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
+          class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
   using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
 
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};
 
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
+                      cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args);
+                                       epilogue_args, scheduler);
 }
 
 template <typename OutType>
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                            b_scales);
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 } // namespace vllm
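
For reference, below is a minimal standalone sketch of the dispatch heuristic this change introduces: the persistent (data-parallel) tile scheduler is used by default, and the Stream-K scheduler is selected only when K is more than three times larger than N. The helper name use_stream_k_scheduler is hypothetical and not part of the vLLM sources; only the k > 3 * n threshold is taken from the diff above.

// Illustrative sketch only (not vLLM code): isolates the scheduler-selection
// heuristic from cutlass_gemm_blockwise_sm90_fp8_dispatch. The k > 3 * n
// threshold comes from the diff; the function name is hypothetical.
#include <cstdint>
#include <cstdio>

// Stream-K decomposition is expected to balance K-dominant problems across
// SMs better than the persistent data-parallel scheduler, so it is chosen
// only when K is much larger than N.
bool use_stream_k_scheduler(int64_t k, int64_t n) { return k > 3 * n; }

int main() {
  // k corresponds to a.size(1) and n to b.size(1) in the dispatch function.
  std::printf("k=4096,  n=4096 -> %s\n",
              use_stream_k_scheduler(4096, 4096) ? "StreamK" : "Persistent");
  std::printf("k=16384, n=512  -> %s\n",
              use_stream_k_scheduler(16384, 512) ? "StreamK" : "Persistent");
  return 0;
}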