From 1c3b5ea67db1a6d9aeedd4d0d875e9887aa7e73c Mon Sep 17 00:00:00 2001 From: leoneo <1320612015@qq.com> Date: Sun, 9 Feb 2025 14:47:57 +0800 Subject: [PATCH 1/4] Add streamK for block-quantized CUTLASS kernels Signed-off-by: leoneo <1320612015@qq.com> --- .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 32 ++++++++ .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 79 +++++++++++++++++-- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 9ac7eee7204e..787095dfa73b 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -53,6 +53,38 @@ void cutlass_gemm_caller(torch::Device device, CUTLASS_CHECK(status); } +template +void cutlass_gemm_caller_streamK(torch::Device device, + cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { + + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // add args for StreamK + using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; + args.scheduler.decomposition_mode = DecompositionMode::StreamK; + args.scheduler.reduction_mode = ReductionMode::Nondeterministic; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + template void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index fb7a82b80ee6..a0b16c2ffa2a 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -22,7 +22,7 @@ namespace vllm { using namespace cute; -template > struct cutlass_3x_gemm_fp8_blockwise { using GroupSizeM = Int; @@ -84,7 +84,7 @@ struct cutlass_3x_gemm_fp8_blockwise { using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + SchedulerType>>; struct GemmKernel : public KernelType {}; @@ -154,15 +154,84 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, epilogue_args); } +template +void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + auto prob_shape = c3x::get_problem_shape(a, b); + int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), + k = get<2>(prob_shape); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + // Check is the t is contiguous and is 1D or 2D with one of the dimensions + // being 1 (i.e. a row or column vector) + auto is_contiguous_vector = [](const torch::Tensor& t) { + auto t_sizes = t.sizes(); + return t.is_contiguous() && + (t.dim() == 1 || + (t.dim() == 2 && + *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); + }; + + // TODO(lucas): lets clean-up the kernel so that we pass in Strides so + // we don't have to deal with enforcing implicit layouts + TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); + TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), + "a_scales must be M major"); + TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); + TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), + "b_scales must be K major"); + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + template void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - cutlass_gemm_caller_blockwise< - cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, - b_scales); + auto k = a_scales.size(1); + auto n = b_scales.size(1); + + if (k > 3 * n) { + cutlass_gemm_caller_blockwise_streamK< + cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, b_scales); + } else{ + cutlass_gemm_caller_blockwise< + cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, b_scales); + } } } // namespace vllm \ No newline at end of file From 68f18b418a545c743bc95471f6a7c4b664d5bc7e Mon Sep 17 00:00:00 2001 From: leoneo <1320612015@qq.com> Date: Sun, 9 Feb 2025 17:24:29 +0800 Subject: [PATCH 2/4] Add streamK for block-quantized CUTLASS kernels Signed-off-by: leoneo <1320612015@qq.com> --- .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 32 +++++++ .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 85 +++++++++++++++++-- 2 files changed, 111 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 9ac7eee7204e..82ae6bef6c0b 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -53,6 +53,38 @@ void cutlass_gemm_caller(torch::Device device, CUTLASS_CHECK(status); } +template +void cutlass_gemm_caller_streamK( + torch::Device device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // add args for StreamK + using DecompositionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; + args.scheduler.decomposition_mode = DecompositionMode::StreamK; + args.scheduler.reduction_mode = ReductionMode::Nondeterministic; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(device); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + template void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index fb7a82b80ee6..5cce90411552 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -22,8 +22,9 @@ namespace vllm { using namespace cute; -template > +template > struct cutlass_3x_gemm_fp8_blockwise { using GroupSizeM = Int; using GroupSizeN = Int; @@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise { using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + SchedulerType>>; struct GemmKernel : public KernelType {}; @@ -154,15 +155,87 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, epilogue_args); } +template +void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + auto prob_shape = c3x::get_problem_shape(a, b); + int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), + k = get<2>(prob_shape); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + // Check is the t is contiguous and is 1D or 2D with one of the dimensions + // being 1 (i.e. a row or column vector) + auto is_contiguous_vector = [](const torch::Tensor& t) { + auto t_sizes = t.sizes(); + return t.is_contiguous() && + (t.dim() == 1 || + (t.dim() == 2 && + *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); + }; + + // TODO(lucas): lets clean-up the kernel so that we pass in Strides so + // we don't have to deal with enforcing implicit layouts + TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); + TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), + "a_scales must be M major"); + TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); + TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); + TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), + "b_scales must be K major"); + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, + mainloop_args, epilogue_args); +} + template void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - cutlass_gemm_caller_blockwise< - cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, - b_scales); + auto k = a_scales.size(1); + auto n = b_scales.size(1); + + if (k > 3 * n) { + cutlass_gemm_caller_blockwise_streamK>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); + } } } // namespace vllm \ No newline at end of file From 2dff762a1677364f2d6396e6ef8c338bc340c799 Mon Sep 17 00:00:00 2001 From: leoneo <1320612015@qq.com> Date: Sun, 9 Feb 2025 17:32:12 +0800 Subject: [PATCH 3/4] [Kernel]Add streamK for block-quantized CUTLASS kernels Signed-off-by: leoneo <1320612015@qq.com> --- .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 33 ++++++++--- .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 59 +++++++++++++++---- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 787095dfa73b..d0c496c76d8f 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -54,18 +54,33 @@ void cutlass_gemm_caller(torch::Device device, } template -void cutlass_gemm_caller_streamK(torch::Device device, - cute::Shape prob_shape, - typename GemmKernel::MainloopArguments mainloop_args, - typename GemmKernel::EpilogueArguments epilogue_args) { - - +<<<<<<< HEAD +void cutlass_gemm_caller_streamK( + torch::Device device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // add args for StreamK + using DecompositionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; +======= +void cutlass_gemm_caller_streamK( + torch::Device device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args) { typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, prob_shape, mainloop_args, epilogue_args}; - + // add args for StreamK - using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; + using DecompositionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; +>>>>>>> 68f18b41 (Add streamK for block-quantized CUTLASS kernels) args.scheduler.decomposition_mode = DecompositionMode::StreamK; args.scheduler.reduction_mode = ReductionMode::Nondeterministic; diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index a0b16c2ffa2a..170fd25c3982 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -22,8 +22,15 @@ namespace vllm { using namespace cute; -template > +<<<<<<< HEAD +template > +======= +template > +>>>>>>> 68f18b41 (Add streamK for block-quantized CUTLASS kernels) struct cutlass_3x_gemm_fp8_blockwise { using GroupSizeM = Int; using GroupSizeN = Int; @@ -155,10 +162,19 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +<<<<<<< HEAD +void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { +======= +void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { +>>>>>>> 68f18b41 (Add streamK for block-quantized CUTLASS kernels) using GemmKernel = typename Gemm::GemmKernel; using ElementAB = typename Gemm::ElementAB; @@ -212,8 +228,13 @@ void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, torch::Tensor con typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; - c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, mainloop_args, - epilogue_args); +<<<<<<< HEAD + c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, + mainloop_args, epilogue_args); +======= + c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, + mainloop_args, epilogue_args); +>>>>>>> 68f18b41 (Add streamK for block-quantized CUTLASS kernels) } template @@ -226,11 +247,23 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, auto n = b_scales.size(1); if (k > 3 * n) { - cutlass_gemm_caller_blockwise_streamK< - cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, b_scales); - } else{ - cutlass_gemm_caller_blockwise< - cutlass_3x_gemm_fp8_blockwise>(out, a, b, a_scales, b_scales); +<<<<<<< HEAD + cutlass_gemm_caller_blockwise_streamK>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); +======= + cutlass_gemm_caller_blockwise_streamK>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise>( + out, a, b, a_scales, b_scales); +>>>>>>> 68f18b41 (Add streamK for block-quantized CUTLASS kernels) } } From b986c01404cc185ad61917d2a6eabb33834a70c5 Mon Sep 17 00:00:00 2001 From: leoneo <1320612015@qq.com> Date: Tue, 11 Feb 2025 13:44:13 +0800 Subject: [PATCH 4/4] fix some nits Signed-off-by: leoneo <1320612015@qq.com> --- .../cutlass_w8a8/c3x/cutlass_gemm_caller.cuh | 45 +++-------- .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 81 ++++--------------- 2 files changed, 26 insertions(+), 100 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index ea26dd89286e..69a3f64cb0b0 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -30,45 +30,18 @@ static inline cute::Shape get_problem_shape( } template -void cutlass_gemm_caller(torch::Device device, - cute::Shape prob_shape, - typename GemmKernel::MainloopArguments mainloop_args, - typename GemmKernel::EpilogueArguments epilogue_args) { - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(device); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(device.index()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template -void cutlass_gemm_caller_streamK( +void cutlass_gemm_caller( torch::Device device, cute::Shape prob_shape, typename GemmKernel::MainloopArguments mainloop_args, - typename GemmKernel::EpilogueArguments epilogue_args) { + typename GemmKernel::EpilogueArguments epilogue_args, + typename GemmKernel::TileSchedulerArguments scheduler = {}) { + cutlass::KernelHardwareInfo hw_info; typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // add args for StreamK - using DecompositionMode = cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using ReductionMode = cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::ReductionMode; - - args.scheduler.decomposition_mode = DecompositionMode::StreamK; - args.scheduler.reduction_mode = ReductionMode::Nondeterministic; + prob_shape, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index 5cce90411552..e089c3d4be2c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -151,71 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; - c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, - epilogue_args); -} - -template -void cutlass_gemm_caller_blockwise_streamK(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - using GemmKernel = typename Gemm::GemmKernel; - - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - auto prob_shape = c3x::get_problem_shape(a, b); - int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), - k = get<2>(prob_shape); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + typename GemmKernel::TileSchedulerArguments scheduler; - // Check is the t is contiguous and is 1D or 2D with one of the dimensions - // being 1 (i.e. a row or column vector) - auto is_contiguous_vector = [](const torch::Tensor& t) { - auto t_sizes = t.sizes(); - return t.is_contiguous() && - (t.dim() == 1 || - (t.dim() == 2 && - *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); - }; + static constexpr bool UsesStreamKScheduler = + cute::is_same_v; - // TODO(lucas): lets clean-up the kernel so that we pass in Strides so - // we don't have to deal with enforcing implicit layouts - TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); - TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), - "a_scales must be M major"); - TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); - TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), - "b_scales must be K major"); - typename GemmKernel::MainloopArguments mainloop_args{ - a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + if constexpr (UsesStreamKScheduler) { + using DecompositionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using ReductionMode = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm90StreamKParams::ReductionMode; - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - {}, c_ptr, c_stride, c_ptr, c_stride}; + scheduler.decomposition_mode = DecompositionMode::StreamK; + scheduler.reduction_mode = ReductionMode::Nondeterministic; + } - c3x::cutlass_gemm_caller_streamK(a.device(), prob_shape, - mainloop_args, epilogue_args); + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args, scheduler); } template @@ -224,11 +177,11 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - auto k = a_scales.size(1); - auto n = b_scales.size(1); + auto k = a.size(1); + auto n = b.size(1); if (k > 3 * n) { - cutlass_gemm_caller_blockwise_streamK>( out, a, b, a_scales, b_scales); } else {