diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
index 9ac7eee7204e..69a3f64cb0b0 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
 }
 
 template <typename GemmKernel>
-void cutlass_gemm_caller(torch::Device device,
-                         cute::Shape<int, int, int, int> prob_shape,
-                         typename GemmKernel::MainloopArguments mainloop_args,
-                         typename GemmKernel::EpilogueArguments epilogue_args) {
+void cutlass_gemm_caller(
+    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
+    typename GemmKernel::MainloopArguments mainloop_args,
+    typename GemmKernel::EpilogueArguments epilogue_args,
+    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
+  cutlass::KernelHardwareInfo hw_info;
   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
-                                      prob_shape, mainloop_args, epilogue_args};
+                                      prob_shape,
+                                      mainloop_args,
+                                      epilogue_args,
+                                      hw_info,
+                                      scheduler};
 
   // Launch the CUTLASS GEMM kernel.
 using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
index fb7a82b80ee6..e089c3d4be2c 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
@@ -22,8 +22,9 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
-          int TileSizeM_ = 128>
+template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
+          int TileSizeM_ = 128,
+          typename SchedulerType = cutlass::gemm::PersistentScheduler>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
   using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
 
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};
 
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
+                      cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args);
+                                       epilogue_args, scheduler);
 }
 
 template <typename OutType>
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, 1, 128, 128, cutlass::gemm::StreamKScheduler>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, 1, 128, 128, cutlass::gemm::PersistentScheduler>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 }  // namespace vllm
\ No newline at end of file