From 1825ef883bbfdb8f2546a02a25609a37ac2ecdc5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Dec 2024 14:36:05 +0000 Subject: [PATCH 01/58] Cutlass grouped gemm files Signed-off-by: ElizaWszola --- CMakeLists.txt | 9 +- csrc/cpu/torch_bindings.cpp | 7 + csrc/ops.h | 8 + .../cutlass_w8a8/grouped_gemm_test.cu | 397 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 20 + csrc/torch_bindings.cpp | 8 + tests/kernels/test_cutlass.py | 68 +++ 7 files changed, 514 insertions(+), 3 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 5acbd762ee95..9d6185e75633 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -209,13 +209,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + # GIT_TAG v3.5.1 + GIT_TAG dbdae514e03f83968f8b7dd4fb064071b9bfbdd1 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(cutlass) @@ -261,7 +262,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 03beefbc6de7..d6c32322ff59 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,6 +118,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); +// CUTLASS w8a8 grouped GEMM // TODO complete this + ops.def( + "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " + " Tensor b_scales, Tensor problem_sizes, " + " Tensor out_offsets, Tensor a_offsets, " + " Tensor b_offsets) -> ()"); + ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. ops.def( diff --git a/csrc/ops.h b/csrc/ops.h index 672e608e9c47..fce4346fa421 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,6 +145,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); +void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu new file mode 100644 index 000000000000..8e46b9a33cea --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -0,0 +1,397 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" + +// TODO let's see which of these we'll need + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +#include "common.hpp" + +// get rid of these? +// #include "helper.h" +// using namespace cute; + +using namespace cute; + +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +namespace { + + // A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementAB_Type = cutlass::float_e4m3_t; // Element type for A matrix operand +// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ElementC_Type = cutlass::half_t; + +// // A matrix configuration +// using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +// constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// // B matrix configuration +// using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +// constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// // C/D matrix configuration +// using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +// constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +// using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Different configs for pingpong/cooperative +// struct CooperativeConfig { +// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; +// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative; +// using TileShape = cute::Shape; +// using ClusterShape = cute::Shape; +// }; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::ColumnMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_group_gemm { + + using ElementAB = ElementAB_; + using ElementC = ElementC_; + using ElementAccumulator = float; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, + ElementC, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideC = cute::remove_pointer_t, cute::Int<0>>>; + + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + // the orig hat cutlass::epilogue::fusion::LinearCombination + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, 4, + ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute + >::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementAB, LayoutA*, 16, + ElementAB, LayoutB*, 16, + ElementAccumulator, + TileShape, ClusterShape, + Stages, KernelSchedule + >::CollectiveOp; + + using KernelType = enable_sm90_or_later>; + + struct GemmKernel : public KernelType {}; +}; + +template +struct ItemDeleter { + void operator()(T* ptr) { + cudaFree(ptr); // noexcept + } +}; + +template +void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + // using ElementC = typename Gemm::ElementC; + using ElementC = typename Gemm::ElementC; + using ElementAcc = float; + + int groups = problem_sizes.size(0); + std::vector a_ptrs_host(groups); + std::vector b_ptrs_host(groups); + std::vector c_ptrs_host(groups); + std::vector d_ptrs_host(groups); + + for (int g = 0; g < groups; ++g) { + a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();// + a_offsets[g].item(); + b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();// + b_offsets[g].item(); + c_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + d_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + } + + // int32_t groups = a.size(0); + // int32_t m = a.size(1); + // int32_t n = b.size(2); + // int32_t k = a.size(2); + + // int64_t lda = a.stride(1); + // int64_t ldb = b.stride(2); + // int64_t ldc = out.stride(1); + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // StrideA stride_A{lda, cute::Int<1>{}, 0}; + // StrideB stride_B{ldb, cute::Int<1>{}, 0}; + // StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}}; + + // this should be vector of A ptrs + // auto ptr_A = static_cast(a.data_ptr()); + // auto ptr_B = static_cast(b.data_ptr()); + // auto ptr_C = static_cast(out.data_ptr()); + + cutlass::platform::unique_ptr stride_A; + cutlass::platform::unique_ptr stride_B; + cutlass::platform::unique_ptr stride_C; + cutlass::platform::unique_ptr stride_D; + + cutlass::platform::unique_ptr ptr_A; + cutlass::platform::unique_ptr ptr_B; + cutlass::platform::unique_ptr ptr_C; + cutlass::platform::unique_ptr ptr_D; + + using GemmKernel = typename Gemm::GemmKernel; + + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; + + std::vector problem_sizes_host; + problem_sizes_host.reserve(groups); + for (int32_t g = 0; g < groups; ++g) { + int32_t m = problem_sizes[g][0].item(); + int32_t n = problem_sizes[g][1].item(); + int32_t k = problem_sizes[g][2].item(); + problem_sizes_host.push_back({m, n, k}); + } + + SingleProblemShape* problem_sizes_device; + int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); + cudaMalloc(&problem_sizes_device, problem_sizes_size); + cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), groups, + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> problem_sizes_ptr( + problem_sizes_device); + ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; + + const ElementAB** a_ptrs_device; + cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> a_ptrs_ptr( + a_ptrs_device + ); + + const ElementAB** b_ptrs_device; + cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> b_ptrs_ptr( + b_ptrs_device + ); + + const ElementC** c_ptrs_device; + cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> c_ptrs_ptr( + c_ptrs_device + ); + + ElementC** d_ptrs_device; + cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> d_ptrs_ptr( + d_ptrs_device + ); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()}; + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + prob_shape, + mainloop_args, + epilogue_args, + hw_info + }; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + // // auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); + + // #if defined(ENABLE_SM90_KERNEL_LEVEL) + // printf("did run through\n"); + cutlass::Status status = gemm_op.run(); + CUTLASS_CHECK(status); + // #endif + +} + +// typedef InType = cutlass::float_e4m3_t; +// typedef OutType = torch::half; +// typedef Epilogue = ScaledEpilogueBias; + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +} + +// TODO hardcode types here? +void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets) { + + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + // int32_t m = a.size(1); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config_default::Cutlass3xGemm; + // using Cutlass3xGemmM64 = + // typename sm90_fp8_config_M64::Cutlass3xGemm; + // using Cutlass3xGemmM128 = + // typename sm90_fp8_config_M128::Cutlass3xGemm; + + + // // uint32_t const m = a.size(0); + // uint32_t const mp2 = + // std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + // if (mp2 <= 64) { + // // m in [1, 64] + // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // } else if (mp2 <= 128) { + // // m in (64, 128] + // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // } else { + // // m in (128, inf) + cutlass_group_gemm_caller(out, a, b, problem_sizes, + out_offsets, a_offsets, b_offsets, a_scales, b_scales); + // } + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e..78225f9b0db0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -27,6 +27,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); + +void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets); + #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -151,6 +160,17 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } +void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets) { + cutlass_grouped_mm_sm90(out, a, b, a_scales, b_scales, problem_sizes, + out_offsets, a_offsets, b_offsets); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e4cc7ec95184..a10c661b22a6 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -311,6 +311,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); + // CUTLASS w8a8 grouped GEMM // TODO complete this + ops.def( + "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " + " Tensor b_scales, Tensor problem_sizes, " + " Tensor out_offsets, Tensor a_offsets, " + " Tensor b_offsets) -> ()"); + ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index afe53797322f..6228c908545d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -6,6 +6,7 @@ import pytest import torch +import random from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -453,3 +454,70 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + +# TODO fix scales +@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) +@pytest.mark.parametrize("num_groups", [10]) +@pytest.mark.parametrize("per_act_token", [False])# [True, False]) +@pytest.mark.parametrize("per_out_ch", [True])# [True, False]) +@pytest.mark.parametrize("use_bias", [False])# [True, False]) +@pytest.mark.skipif(not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, + per_act_token: bool, + per_out_ch: bool, use_bias: bool): + + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + device = "cuda" + out_dtype = torch.half + + alignment = 16 # 128 // 8 + problem_sizes = torch.empty((num_groups, 3), device="cpu") + offsets_a = torch.empty((num_groups), device="cpu") + offsets_b = torch.empty((num_groups), device="cpu") + offsets_c = torch.empty((num_groups), device="cpu") + tot_a = 0 + tot_b = 0 + tot_c = 0 + for g in range(num_groups): + m = alignment * random.randint(1, 64) + n = alignment * random.randint(1, 64) + k = alignment * random.randint(1, 64) + tot_a += m * k + tot_b += k * n + tot_c += m * n + offsets_a[g] = m * k + offsets_b[g] = k * n + offsets_c[g] = m * n + problem_sizes[g][0] = m + problem_sizes[g][1] = n + problem_sizes[g][2] = k + + a = to_fp8(torch.randn((tot_a), device=device)) + b = to_fp8(torch.randn((tot_b), device=device).t()) + c = torch.zeros((tot_c), device=device).to(out_dtype) + + m_a_scales = m if per_act_token else 1 + n_b_scales = n if per_out_ch else 1 + + scale_a = (torch.randn((m_a_scales, 1), device=device, + dtype=torch.float32)) + scale_b = (torch.randn((1, n_b_scales), device=device, + dtype=torch.float32)) + if use_bias: + bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 + else: + bias = None + + # TODO strides we can get later the same way as in scaled_mm_c3x.cu + torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, + offsets_c, offsets_a, offsets_b) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + print(c) + + # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) + + # opcheck(torch.ops._C.cutlass_scaled_mm, + # (out, a, b, scale_a, scale_b, bias)) From 5fd48e5b4270cc43428f149eb731ec117b2afec8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 9 Dec 2024 12:20:50 +0000 Subject: [PATCH 02/58] runs, bad result Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_gemm_test.cu | 107 ++++++++---------- tests/kernels/test_cutlass.py | 34 +++--- 2 files changed, 68 insertions(+), 73 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 8e46b9a33cea..004599c2b5d2 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -58,35 +58,14 @@ using ElementAB_Type = cutlass::float_e4m3_t; // using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand using ElementC_Type = cutlass::half_t; -// // A matrix configuration -// using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -// constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) - -// // B matrix configuration -// using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand -// constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) - -// // C/D matrix configuration -// using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands -// constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) - // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -// using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size - -// Different configs for pingpong/cooperative -// struct CooperativeConfig { -// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; -// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative; -// using TileShape = cute::Shape; -// using ClusterShape = cute::Shape; -// }; -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::ColumnMajor; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::ColumnMajor; template typename Epilogue_, @@ -107,8 +86,8 @@ struct cutlass_3x_group_gemm { using StrideC = cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; // the orig hat cutlass::epilogue::fusion::LinearCombination @@ -172,34 +151,25 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();// + a_offsets[g].item(); - b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();// + b_offsets[g].item(); - c_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); - d_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item(); + a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item(); + b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item(); + c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); + d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); } - // int32_t groups = a.size(0); - // int32_t m = a.size(1); - // int32_t n = b.size(2); - // int32_t k = a.size(2); - - // int64_t lda = a.stride(1); - // int64_t ldb = b.stride(2); - // int64_t ldc = out.stride(1); - using StrideA = typename Gemm::GemmKernel::InternalStrideA; using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideC = typename Gemm::GemmKernel::InternalStrideC; using StrideD = typename Gemm::GemmKernel::InternalStrideD; - // StrideA stride_A{lda, cute::Int<1>{}, 0}; - // StrideB stride_B{ldb, cute::Int<1>{}, 0}; - // StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}}; + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); - // this should be vector of A ptrs - // auto ptr_A = static_cast(a.data_ptr()); - // auto ptr_B = static_cast(b.data_ptr()); - // auto ptr_C = static_cast(out.data_ptr()); + std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}}); + std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}}); + // TODO fix + std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}}); cutlass::platform::unique_ptr stride_A; cutlass::platform::unique_ptr stride_B; @@ -212,7 +182,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, cutlass::platform::unique_ptr ptr_D; using GemmKernel = typename Gemm::GemmKernel; - + cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish // to use a GPU other than that with device ID 0. @@ -241,38 +211,60 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, const ElementAB** a_ptrs_device; cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> a_ptrs_ptr( a_ptrs_device ); const ElementAB** b_ptrs_device; cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> b_ptrs_ptr( b_ptrs_device ); const ElementC** c_ptrs_device; cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> c_ptrs_ptr( c_ptrs_device ); + // TODO if we start with empty values here, no need to copy ElementC** d_ptrs_device; cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> d_ptrs_ptr( d_ptrs_device ); + StrideA* a_stride_device; + cudaMalloc(&a_stride_device, groups * sizeof(StrideA*)); + cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> a_stride_ptr( + a_stride_device + ); + + StrideB* b_stride_device; + cudaMalloc(&b_stride_device, groups * sizeof(StrideB*)); + cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> b_stride_ptr( + b_stride_device + ); + + StrideC* c_stride_device; + cudaMalloc(&c_stride_device, groups * sizeof(StrideC*)); + cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> c_stride_ptr( + c_stride_device + ); + typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()}; + a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()}; + c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, @@ -296,11 +288,8 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); - // #if defined(ENABLE_SM90_KERNEL_LEVEL) - // printf("did run through\n"); - cutlass::Status status = gemm_op.run(); - CUTLASS_CHECK(status); - // #endif + cutlass::Status status = gemm_op.run(); + CUTLASS_CHECK(status); } @@ -367,7 +356,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - // int32_t m = a.size(1); + // int32_t m = a.size(1); using Cutlass3xGemmDefault = typename sm90_fp8_config_default Date: Tue, 10 Dec 2024 15:24:17 +0000 Subject: [PATCH 03/58] A little closer to working Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 2 +- .../cutlass_w8a8/grouped_gemm_test.cu | 305 +++++++++--------- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- tests/kernels/test_cutlass.py | 87 +++-- 4 files changed, 224 insertions(+), 182 deletions(-) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index d6c32322ff59..80a326cdc5ef 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,7 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); -// CUTLASS w8a8 grouped GEMM // TODO complete this + // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " " Tensor b_scales, Tensor problem_sizes, " diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 004599c2b5d2..db86bd1a4b46 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -33,12 +33,12 @@ using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 -#define ENABLE_SM90_KERNEL_LEVEL 1 + #define ENABLE_SM90_KERNEL_LEVEL 1 #endif namespace { - // A wrapper for the GEMM kernel that is used to guard against compilation on +// A wrapper for the GEMM kernel that is used to guard against compilation on // architectures that will never use the kernel. The purpose of this is to // reduce the size of the compiled binary. // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef @@ -47,32 +47,36 @@ template struct enable_sm90_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); - #endif +#endif } }; -using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group -using ElementAB_Type = cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // + // per group +using ElementAB_Type = + cutlass::float_e4m3_t; // Element type for A matrix operand +// using ElementB = cutlass::float_e4m3_t; // +// Element type for B matrix operand using ElementC_Type = cutlass::half_t; // Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::ColumnMajor; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; using ElementC = ElementC_; using ElementAccumulator = float; @@ -84,42 +88,36 @@ struct cutlass_3x_group_gemm { using Epilogue = Epilogue_; - using StrideC = cute::remove_pointer_t, cute::Int<0>>>; + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; - // the orig hat cutlass::epilogue::fusion::LinearCombination - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, - ElementC, LayoutC*, 4, - ElementC, LayoutC*, 4, - EpilogueSchedule, EVTCompute - >::CollectiveOp; + // the orig hat cutlass::epilogue::fusion::LinearCombination + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementAB, LayoutA*, 16, - ElementAB, LayoutB*, 16, - ElementAccumulator, - TileShape, ClusterShape, - Stages, KernelSchedule - >::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, + 16, ElementAccumulator, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp; using KernelType = enable_sm90_or_later>; + ProblemShape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; }; @@ -127,20 +125,19 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder template struct ItemDeleter { void operator()(T* ptr) { - cudaFree(ptr); // noexcept + cudaFree(ptr); // noexcept } }; template void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { + torch::Tensor const& b, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets, + EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; - // using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC; using ElementAcc = float; @@ -151,43 +148,48 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item(); - b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item(); - c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); - d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); + a_ptrs_host.at(g) = + static_cast(a.data_ptr()) + a_offsets[g].item(); + b_ptrs_host.at(g) = + static_cast(b.data_ptr()) + b_offsets[g].item(); + c_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + d_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + printf("%d %d %d\n", a_offsets[g].item(), + b_offsets[g].item(), out_offsets[g].item()); } - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}}); - // TODO fix - std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}}); + using GemmKernel = typename Gemm::GemmKernel; - cutlass::platform::unique_ptr stride_A; - cutlass::platform::unique_ptr stride_B; - cutlass::platform::unique_ptr stride_C; - cutlass::platform::unique_ptr stride_D; + using StrideA = typename GemmKernel::InternalStrideA; + using StrideB = typename GemmKernel::InternalStrideB; + using StrideC = typename GemmKernel::InternalStrideC; + // using StrideD = typename GemmKernel::InternalStrideD; - cutlass::platform::unique_ptr ptr_A; - cutlass::platform::unique_ptr ptr_B; - cutlass::platform::unique_ptr ptr_C; - cutlass::platform::unique_ptr ptr_D; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - using GemmKernel = typename Gemm::GemmKernel; + for (int g = 0; g < groups; ++g) { + int32_t m = problem_sizes[g][0].item(); + int32_t n = problem_sizes[g][1].item(); + int32_t k = problem_sizes[g][2].item(); + a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, + // row + b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, + // col + c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, + // row + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; @@ -203,76 +205,83 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, SingleProblemShape* problem_sizes_device; int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); cudaMalloc(&problem_sizes_device, problem_sizes_size); - cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), groups, - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> problem_sizes_ptr( - problem_sizes_device); - ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; + cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), + groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + problem_sizes_ptr(problem_sizes_device); + ProblemShape prob_shape{groups, problem_sizes_ptr.get(), + problem_sizes_host.data()}; + + // ElementAB* a_host_print; + // int numel = a.numel(); + // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); + // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* + // sizeof(ElementAB), cudaMemcpyDeviceToHost); + // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* + // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); const ElementAB** a_ptrs_device; cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> a_ptrs_ptr( - a_ptrs_device - ); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + a_ptrs_ptr(a_ptrs_device); const ElementAB** b_ptrs_device; cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> b_ptrs_ptr( - b_ptrs_device - ); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + b_ptrs_ptr(b_ptrs_device); const ElementC** c_ptrs_device; cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> c_ptrs_ptr( - c_ptrs_device - ); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + c_ptrs_ptr(c_ptrs_device); - // TODO if we start with empty values here, no need to copy ElementC** d_ptrs_device; cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> d_ptrs_ptr( - d_ptrs_device - ); + d_ptrs_device); StrideA* a_stride_device; - cudaMalloc(&a_stride_device, groups * sizeof(StrideA*)); - cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&a_stride_device, groups * sizeof(StrideA)); + cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> a_stride_ptr( - a_stride_device - ); + a_stride_device); StrideB* b_stride_device; - cudaMalloc(&b_stride_device, groups * sizeof(StrideB*)); - cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&b_stride_device, groups * sizeof(StrideB)); + cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> b_stride_ptr( - b_stride_device - ); + b_stride_device); StrideC* c_stride_device; - cudaMalloc(&c_stride_device, groups * sizeof(StrideC*)); - cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&c_stride_device, groups * sizeof(StrideC)); + cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> c_stride_ptr( - c_stride_device - ); + c_stride_device); typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; + a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), + b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; + c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), + c_stride_ptr.get()}; typename GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - prob_shape, - mainloop_args, - epilogue_args, - hw_info - }; + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args, hw_info}; // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; @@ -284,18 +293,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); - // // auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); - - cutlass::Status status = gemm_op.run(); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); - } // typedef InType = cutlass::float_e4m3_t; // typedef OutType = torch::half; -// typedef Epilogue = ScaledEpilogueBias; template typename Epilogue> @@ -304,12 +309,13 @@ struct sm90_fp8_config_default { static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; -} +} // namespace // TODO hardcode types here? -void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); // int32_t m = a.size(1); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_fp8_config_default< + ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; // using Cutlass3xGemmM64 = - // typename sm90_fp8_config_M64::Cutlass3xGemm; + // typename sm90_fp8_config_M64::Cutlass3xGemm; // using Cutlass3xGemmM128 = - // typename sm90_fp8_config_M128::Cutlass3xGemm; - + // typename sm90_fp8_config_M128::Cutlass3xGemm; // // uint32_t const m = a.size(0); // uint32_t const mp2 = @@ -373,14 +378,16 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, // if (mp2 <= 64) { // // m in [1, 64] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else if (mp2 <= 128) { // // m in (64, 128] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else { // // m in (128, inf) - cutlass_group_gemm_caller(out, a, b, problem_sizes, - out_offsets, a_offsets, b_offsets, a_scales, b_scales); + cutlass_group_gemm_caller( + out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, + b_scales); // } - } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 78225f9b0db0..961437893dee 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,13 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets); +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); #endif diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index a97c8f307df3..563a3f433d98 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,11 +2,11 @@ Run `pytest tests/kernels/test_cutlass.py`. """ +import random from typing import Optional, Type import pytest import torch -import random from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -455,41 +455,43 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + # TODO fix scales @pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) @pytest.mark.parametrize("num_groups", [10]) -@pytest.mark.parametrize("per_act_token", [False])# [True, False]) -@pytest.mark.parametrize("per_out_ch", [True])# [True, False]) -@pytest.mark.parametrize("use_bias", [False])# [True, False]) +@pytest.mark.parametrize("per_act_token", [False]) # [True, False]) +@pytest.mark.parametrize("per_out_ch", [True]) # [True, False]) +@pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, - per_out_ch: bool, use_bias: bool): + per_act_token: bool, per_out_ch: bool, + use_bias: bool): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. device = "cuda" out_dtype = torch.half - alignment = 16 # 128 // 8 + alignment = 16 # 128 // 8 problem_sizes = torch.empty((num_groups, 3), device="cpu") - offsets_a = torch.empty((num_groups), device="cpu") - offsets_b = torch.empty((num_groups), device="cpu") - offsets_c = torch.empty((num_groups), device="cpu") + offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32) tot_a = 0 tot_b = 0 tot_c = 0 + m = alignment * random.randint(1, 64) + n = alignment * random.randint(1, 64) + k = alignment * random.randint(1, 64) for g in range(num_groups): - m = alignment * random.randint(1, 64) - n = alignment * random.randint(1, 64) - k = alignment * random.randint(1, 64) tot_a += m tot_b += k tot_c += m - offsets_a[g] = m * k - offsets_b[g] = k * n - offsets_c[g] = m * n + print(m, n, k) + offsets_a[g] = g * m * k + offsets_b[g] = g * k * n + offsets_c[g] = g * m * n problem_sizes[g][0] = m problem_sizes[g][1] = n problem_sizes[g][2] = k @@ -497,32 +499,67 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, a = to_fp8(torch.randn((tot_a, k), device=device)) b = to_fp8(torch.randn((tot_b, n), device=device).t()) c = torch.zeros((tot_c, n), device=device).to(out_dtype) + baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) - print(tot_a, tot_b, tot_c) + # print(a) + # print(b) - print(a.stride(), b.stride(), c.stride()) + # print(offsets_a) + # print(offsets_b) + # print(offsets_c) + # print(tot_a, tot_b, tot_c) + + # print(a.stride(), b.stride(), c.stride()) # m_a_scales = m if per_act_token else 1 # n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn((tot_a if per_act_token else num_groups), - device=device, - dtype=torch.float32)) - scale_b = (torch.randn((tot_b if per_act_token else num_groups), - device=device, - dtype=torch.float32)) + # scale_a = (torch.randn((tot_a if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + # scale_b = (torch.randn((tot_b if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + + scale_a = (torch.ones((tot_a if per_act_token else num_groups), + device=device, + dtype=torch.float32)) + scale_b = (torch.ones((tot_b if per_act_token else num_groups), + device=device, + dtype=torch.float32)) + # if use_bias: # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 # else: # bias = None + print(a) + # TODO strides we can get later the same way as in scaled_mm_c3x.cu torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, offsets_c, offsets_a, offsets_b) - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) + # print(a.dtype) + # print(a) + + # torch.set_printoptions(profile='full') + # # print(c[2*m:3*m]) + # print(torch.max(c, dim=1)) + # print(torch.max(c, dim=0)) print(c) + for g in range(num_groups): + baseline[g * m:(g + 1) * m] = baseline_scaled_mm( + a[g * m:(g + 1) * m], + b.t()[g * k:(g + 1) * k], + scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], + scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g], + out_dtype, None) + print(baseline[g * m:(g + 1) * m]) + print(c[g * m:(g + 1) * m]) + print("*") + # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) # opcheck(torch.ops._C.cutlass_scaled_mm, From c570c69ed80d0f7e2a2be27ef1f931497bc3e589 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 11 Dec 2024 14:41:46 +0000 Subject: [PATCH 04/58] Working for identical sizes Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_gemm_test.cu | 167 +++++++++--------- tests/kernels/test_cutlass.py | 62 ++++--- 2 files changed, 118 insertions(+), 111 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index db86bd1a4b46..03d23c773969 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -129,6 +129,31 @@ struct ItemDeleter { } }; +template +cutlass::platform::unique_ptr> make_device_ptr( + std::vector& data_host) { + T* data_device; + int count = data_host.size(); + cudaMalloc(&data_device, count * sizeof(T)); + cudaMemcpy(data_device, data_host.data(), count * sizeof(T), + cudaMemcpyHostToDevice); + return cutlass::platform::unique_ptr>(data_device); +} + +/////////////// +template +void print(const TupType& _tup, std::index_sequence) { + std::cout << "("; + (..., (std::cout << (I == 0 ? "" : ", ") << std::get(_tup))); + std::cout << ")\n"; +} + +template +void print(const std::tuple& _tup) { + print(_tup, std::make_index_sequence()); +} +//////////// + template void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, @@ -142,46 +167,67 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, using ElementAcc = float; int groups = problem_sizes.size(0); - std::vector a_ptrs_host(groups); - std::vector b_ptrs_host(groups); - std::vector c_ptrs_host(groups); + std::vector a_ptrs_host(groups); + std::vector b_ptrs_host(groups); + std::vector c_ptrs_host(groups); std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = - static_cast(a.data_ptr()) + a_offsets[g].item(); - b_ptrs_host.at(g) = - static_cast(b.data_ptr()) + b_offsets[g].item(); - c_ptrs_host.at(g) = - static_cast(out.data_ptr()) + out_offsets[g].item(); + a_ptrs_host.at(g) = static_cast(a.data_ptr()) + + a_offsets[g].item(); + b_ptrs_host.at(g) = static_cast(b.data_ptr()) + + b_offsets[g].item(); + c_ptrs_host.at(g) = static_cast(out.data_ptr()) + + out_offsets[g].item(); d_ptrs_host.at(g) = static_cast(out.data_ptr()) + out_offsets[g].item(); - printf("%d %d %d\n", a_offsets[g].item(), + printf("off: %d %d %d\n", a_offsets[g].item(), b_offsets[g].item(), out_offsets[g].item()); } using GemmKernel = typename Gemm::GemmKernel; - using StrideA = typename GemmKernel::InternalStrideA; - using StrideB = typename GemmKernel::InternalStrideB; - using StrideC = typename GemmKernel::InternalStrideC; - // using StrideD = typename GemmKernel::InternalStrideD; + // using StrideA = typename GemmKernel::InternalStrideA; + // using StrideB = typename GemmKernel::InternalStrideB; + // using StrideC = typename GemmKernel::InternalStrideC; + // // using StrideD = typename GemmKernel::InternalStrideD; - std::vector a_stride_host(groups); - std::vector b_stride_host(groups); - std::vector c_stride_host(groups); + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); - for (int g = 0; g < groups; ++g) { - int32_t m = problem_sizes[g][0].item(); - int32_t n = problem_sizes[g][1].item(); - int32_t k = problem_sizes[g][2].item(); - a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, - // row - b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, - // col - c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, - // row - } + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = + typename GemmKernel::InternalStrideC; // typename Gemm::StrideC; + + // StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + // StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + // StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + std::vector a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}}); + std::vector b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}}); + std::vector c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}}); + + printf("a: "); + print(a_stride_host[0]); + printf("\nb: "); + print(b_stride_host[0]); + printf("\nc: "); + print(c_stride_host[0]); + printf("\n"); + + // for (int g = 0; g < groups; ++g) { + // int32_t m = problem_sizes[g][0].item(); + // int32_t n = problem_sizes[g][1].item(); + // int32_t k = problem_sizes[g][2].item(); + // a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, + // // row + // b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, + // // col + // c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, + // // row + // } cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with @@ -200,16 +246,11 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, int32_t n = problem_sizes[g][1].item(); int32_t k = problem_sizes[g][2].item(); problem_sizes_host.push_back({m, n, k}); + printf("mnk: %d, %d, %d\n", m, n, k); } - SingleProblemShape* problem_sizes_device; - int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); - cudaMalloc(&problem_sizes_device, problem_sizes_size); - cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), - groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - problem_sizes_ptr(problem_sizes_device); + auto problem_sizes_ptr = + make_device_ptr(problem_sizes_host); ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; @@ -221,54 +262,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); - const ElementAB** a_ptrs_device; - cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - a_ptrs_ptr(a_ptrs_device); - - const ElementAB** b_ptrs_device; - cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - b_ptrs_ptr(b_ptrs_device); - - const ElementC** c_ptrs_device; - cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> - c_ptrs_ptr(c_ptrs_device); - - ElementC** d_ptrs_device; - cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> d_ptrs_ptr( - d_ptrs_device); - - StrideA* a_stride_device; - cudaMalloc(&a_stride_device, groups * sizeof(StrideA)); - cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> a_stride_ptr( - a_stride_device); + auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - StrideB* b_stride_device; - cudaMalloc(&b_stride_device, groups * sizeof(StrideB)); - cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> b_stride_ptr( - b_stride_device); - - StrideC* c_stride_device; - cudaMalloc(&c_stride_device, groups * sizeof(StrideC)); - cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC), - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> c_stride_ptr( - c_stride_device); + auto a_stride_ptr = make_device_ptr(a_stride_host); + auto b_stride_ptr = make_device_ptr(b_stride_host); + auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 563a3f433d98..1532feba47d6 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -62,6 +62,7 @@ def baseline_scaled_mm(a: torch.Tensor, scale_b: torch.Tensor, out_dtype: Type[torch.dtype], bias: Optional[torch.Tensor] = None) -> torch.Tensor: + print(a.shape, b.shape, scale_a.shape, scale_b.shape) output = (scale_a * (scale_b * (torch.mm( a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) if bias is not None: @@ -458,9 +459,9 @@ def test_cutlass_support_opcheck(): # TODO fix scales @pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) -@pytest.mark.parametrize("num_groups", [10]) -@pytest.mark.parametrize("per_act_token", [False]) # [True, False]) -@pytest.mark.parametrize("per_out_ch", [True]) # [True, False]) +@pytest.mark.parametrize("num_groups", [1, 4, 10]) +@pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) @pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") @@ -486,7 +487,7 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, k = alignment * random.randint(1, 64) for g in range(num_groups): tot_a += m - tot_b += k + tot_b += n tot_c += m print(m, n, k) offsets_a[g] = g * m * k @@ -497,7 +498,13 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, problem_sizes[g][2] = k a = to_fp8(torch.randn((tot_a, k), device=device)) - b = to_fp8(torch.randn((tot_b, n), device=device).t()) + + b_float = torch.randn((tot_b, k), device=device) + # for g in range(num_groups): + # b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1) + # print(b_float) + + b = to_fp8(b_float.t()) c = torch.zeros((tot_c, n), device=device).to(out_dtype) baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) @@ -511,29 +518,19 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, # print(a.stride(), b.stride(), c.stride()) - # m_a_scales = m if per_act_token else 1 - # n_b_scales = n if per_out_ch else 1 - - # scale_a = (torch.randn((tot_a if per_act_token else num_groups), - # device=device, - # dtype=torch.float32)) - # scale_b = (torch.randn((tot_b if per_act_token else num_groups), - # device=device, - # dtype=torch.float32)) - - scale_a = (torch.ones((tot_a if per_act_token else num_groups), - device=device, - dtype=torch.float32)) - scale_b = (torch.ones((tot_b if per_act_token else num_groups), - device=device, - dtype=torch.float32)) + scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)), + device=device, + dtype=torch.float32)) + scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)), + device=device, + dtype=torch.float32)) # if use_bias: # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 # else: # bias = None - print(a) + # print(a) # TODO strides we can get later the same way as in scaled_mm_c3x.cu torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, @@ -547,20 +544,29 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, # # print(c[2*m:3*m]) # print(torch.max(c, dim=1)) # print(torch.max(c, dim=0)) - print(c) + # print(c) for g in range(num_groups): + print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape) baseline[g * m:(g + 1) * m] = baseline_scaled_mm( a[g * m:(g + 1) * m], - b.t()[g * k:(g + 1) * k], - scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], - scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g], - out_dtype, None) + b[:, g * n:(g + 1) * n], + # scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], + # # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g], + # scale_b[g], + scale_a, + scale_b, + out_dtype, + None) print(baseline[g * m:(g + 1) * m]) print(c[g * m:(g + 1) * m]) print("*") - # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) + # print(baseline) + # print(c) + + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) # opcheck(torch.ops._C.cutlass_scaled_mm, # (out, a, b, scale_a, scale_b, bias)) From 6ed63f2ebae2d2d6742cc08c855ac8e5b6eb7cd1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Dec 2024 16:41:45 +0000 Subject: [PATCH 05/58] Grouped gemm working Co-authored-by: Lucas Wilkinson Signed-off-by: ElizaWszola --- .../broadcast_load_epilogue_array_c3x.hpp | 464 ++++++++++++++++++ .../epilogue/broadcast_load_epilogue_c3x.hpp | 5 + .../epilogue/scaled_mm_epilogues_c3x.hpp | 64 +++ csrc/ops.h | 12 +- .../cutlass_w8a8/grouped_gemm_test.cu | 224 ++++----- .../cutlass_w8a8/scaled_mm_entry.cu | 26 +- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_cutlass.py | 155 ++---- 8 files changed, 704 insertions(+), 254 deletions(-) create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 000000000000..e652179718c9 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,464 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + if (threadIdx.x ==128){ + printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + } + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + // if (threadIdx.x ==128){ + // printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + // } + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 58b1e8ff159f..9f049efd07b4 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -422,6 +422,11 @@ struct Sm90ColOrScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + if (threadIdx.x ==128){ + printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79..ad7c45a076e6 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,4 +1,5 @@ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" /* This file defines custom epilogues for fusing channel scales, token scales, @@ -45,6 +46,16 @@ struct ScaledEpilogueBase { 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + template + using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. @@ -72,6 +83,15 @@ struct ScaledEpilogueBase { std::is_same_v>); return Arguments{data_ptr}; } + + template + static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } + }; /* @@ -312,4 +332,48 @@ struct ScaledEpilogueBiasAzpToken } }; +/* +TODO document +This is an epilogue with ptr arrays to a_scales and b_scales +*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(const float* const* a_scales_ptr, + const float* const* b_scales_ptr, + bool a_col_broadcast, + bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor(a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor(b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + }; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index fce4346fa421..b655d3bfab58 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,13 +145,11 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets); +void cutlass_grouped_mm(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 03d23c773969..c9d299c11130 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -38,11 +38,6 @@ using namespace cute; namespace { -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. template struct enable_sm90_or_later : Kernel { template @@ -54,19 +49,13 @@ struct enable_sm90_or_later : Kernel { }; using ProblemShape = - cutlass::gemm::GroupProblemShape>; // - // per group -using ElementAB_Type = - cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // -// Element type for B matrix operand + cutlass::gemm::GroupProblemShape>; +using ElementAB_Type = cutlass::float_e4m3_t; using ElementC_Type = cutlass::half_t; -// Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -154,129 +143,109 @@ void print(const std::tuple& _tup) { } //////////// -template -void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { +template +void cutlass_group_gemm_caller(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; - using ElementAcc = float; - int groups = problem_sizes.size(0); + int groups = (int)a_tensors.size(); + TORCH_CHECK((int)b_tensors.size() == groups, + "Number of B tensors must match number of groups."); + TORCH_CHECK((int)out_tensors.size() == groups, + "Number of output tensors must match number of groups."); + std::vector a_ptrs_host(groups); std::vector b_ptrs_host(groups); std::vector c_ptrs_host(groups); std::vector d_ptrs_host(groups); + std::vector a_scales_ptrs_host(groups); + std::vector b_scales_ptrs_host(groups); + + std::vector problem_sizes_host; + problem_sizes_host.reserve(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = static_cast(a.data_ptr()) + - a_offsets[g].item(); - b_ptrs_host.at(g) = static_cast(b.data_ptr()) + - b_offsets[g].item(); - c_ptrs_host.at(g) = static_cast(out.data_ptr()) + - out_offsets[g].item(); - d_ptrs_host.at(g) = - static_cast(out.data_ptr()) + out_offsets[g].item(); - printf("off: %d %d %d\n", a_offsets[g].item(), - b_offsets[g].item(), out_offsets[g].item()); + a_ptrs_host[g] = + reinterpret_cast(a_tensors[g].data_ptr()); + b_ptrs_host[g] = + reinterpret_cast(b_tensors[g].data_ptr()); + c_ptrs_host[g] = + reinterpret_cast(out_tensors[g].data_ptr()); + d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); + a_scales_ptrs_host[g] = + reinterpret_cast(a_scales[g].data_ptr()); + b_scales_ptrs_host[g] = + reinterpret_cast(b_scales[g].data_ptr()); + + int64_t m = a_tensors[g].size(0); + int64_t k = a_tensors[g].size(1); + + int64_t k_b = b_tensors[g].size(0); + int64_t n = b_tensors[g].size(1); + + TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, + " while B has k=", k_b); + + // Optionally, verify output shape matches (m,n) + TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, + "Output tensor shape does not match m,n from A,B: ", "Got ", + out_tensors[g].sizes(), " expected (", m, ", ", n, ")"); + + problem_sizes_host.push_back({(int)m, (int)n, (int)k}); } using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; - // using StrideA = typename GemmKernel::InternalStrideA; - // using StrideB = typename GemmKernel::InternalStrideB; - // using StrideC = typename GemmKernel::InternalStrideC; - // // using StrideD = typename GemmKernel::InternalStrideD; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); + for (int32_t g = 0; g < groups; ++g) { + int64_t lda = a_tensors[g].stride(0); // row-major (m x k) + int64_t ldb = b_tensors[g].stride(1); // column-major (k x n) + int64_t ldc = out_tensors[g].stride(0); // row-major (m x n) - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; - using StrideC = - typename GemmKernel::InternalStrideC; // typename Gemm::StrideC; - - // StrideA a_stride{lda, Int<1>{}, Int<0>{}}; - // StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; - // StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - std::vector a_stride_host(groups, StrideA{lda, Int<1>{}, Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, Int<1>{}, Int<0>{}}); - std::vector c_stride_host(groups, StrideC{ldc, Int<1>{}, Int<0>{}}); - - printf("a: "); - print(a_stride_host[0]); - printf("\nb: "); - print(b_stride_host[0]); - printf("\nc: "); - print(c_stride_host[0]); - printf("\n"); - - // for (int g = 0; g < groups; ++g) { - // int32_t m = problem_sizes[g][0].item(); - // int32_t n = problem_sizes[g][1].item(); - // int32_t k = problem_sizes[g][2].item(); - // a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, - // // row - // b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, - // // col - // c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, - // // row - // } + a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; + b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; + c_stride_host[g] = StrideC{ldc, Int<1>{}, Int<0>{}}; + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; - - std::vector problem_sizes_host; - problem_sizes_host.reserve(groups); - for (int32_t g = 0; g < groups; ++g) { - int32_t m = problem_sizes[g][0].item(); - int32_t n = problem_sizes[g][1].item(); - int32_t k = problem_sizes[g][2].item(); - problem_sizes_host.push_back({m, n, k}); - printf("mnk: %d, %d, %d\n", m, n, k); - } - - auto problem_sizes_ptr = - make_device_ptr(problem_sizes_host); + auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; - // ElementAB* a_host_print; - // int numel = a.numel(); - // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); - // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* - // sizeof(ElementAB), cudaMemcpyDeviceToHost); - // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* - // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); + auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); - auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); - auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); + auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); + auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); - auto a_stride_ptr = make_device_ptr(a_stride_host); - auto b_stride_ptr = make_device_ptr(b_stride_host); - auto c_stride_ptr = make_device_ptr(c_stride_host); + auto a_stride_ptr = make_device_ptr(a_stride_host); + auto b_stride_ptr = make_device_ptr(b_stride_host); + auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), + a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), + a_scales[0].numel() != 1, b_scales[0].numel() != 1), c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; @@ -284,30 +253,26 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, epilogue_args, hw_info}; - // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; + // std::cout << "gemm_op.can_implement(args): " + // << (int)gemm_op.can_implement(args) << std::endl; CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors[0].device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - + auto stream = at::cuda::getCurrentCUDAStream(a_tensors[0].device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } -// typedef InType = cutlass::float_e4m3_t; -// typedef OutType = torch::half; - template typename Epilogue> struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); + static_assert(std::is_same_v); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = @@ -354,18 +319,23 @@ struct sm90_fp8_config_M64 { } // namespace -// TODO hardcode types here? -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - // int32_t m = a.size(1); +void cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + TORCH_CHECK(a_tensors.size() > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size() > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size() > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors[0].dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); using Cutlass3xGemmDefault = typename sm90_fp8_config_default< - ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; + ElementAB_Type, ElementC_Type, + vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; // using Cutlass3xGemmM64 = // typename sm90_fp8_config_M64::Cutlass3xGemm; @@ -388,7 +358,5 @@ void cutlass_grouped_mm_sm90( // } else { // // m in (128, inf) cutlass_group_gemm_caller( - out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, - b_scales); - // } + out_tensors, a_tensors, b_tensors, a_scales, b_scales); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 961437893dee..eb5d09a6de7b 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,11 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90( - torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); +void cutlass_grouped_mm_sm90(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales); #endif @@ -158,15 +158,13 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - cutlass_grouped_mm_sm90(out, a, b, a_scales, b_scales, problem_sizes, - out_offsets, a_offsets, b_offsets); +void cutlass_grouped_mm(c10::List const& out_tensors, + c10::List const& a_tensors, + c10::List const& b_tensors, + c10::List const& a_scales, + c10::List const& b_scales) { + cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, + b_scales); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a10c661b22a6..22a1a1a4ae08 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -313,10 +313,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( - "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " - " Tensor b_scales, Tensor problem_sizes, " - " Tensor out_offsets, Tensor a_offsets, " - " Tensor b_offsets) -> ()"); + "cutlass_grouped_mm(Tensor![] out_tensors," + " Tensor[] a_tensors," + " Tensor[] b_tensors, Tensor[] a_scales, " + " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); // Mamba selective scan kernel diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 1532feba47d6..4c909669aa5d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -457,116 +457,69 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) -# TODO fix scales -@pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) -@pytest.mark.parametrize("num_groups", [1, 4, 10]) +@pytest.mark.parametrize("num_groups", [8]) @pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) @pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, per_out_ch: bool, - use_bias: bool): +def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): - # Test for a cutlass kernel with per-token activation quantization - # and per-output channel weight quantization. + # Device and dtype setup device = "cuda" out_dtype = torch.half - alignment = 16 # 128 // 8 - problem_sizes = torch.empty((num_groups, 3), device="cpu") - offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32) - offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32) - offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32) - tot_a = 0 - tot_b = 0 - tot_c = 0 - m = alignment * random.randint(1, 64) - n = alignment * random.randint(1, 64) - k = alignment * random.randint(1, 64) - for g in range(num_groups): - tot_a += m - tot_b += n - tot_c += m - print(m, n, k) - offsets_a[g] = g * m * k - offsets_b[g] = g * k * n - offsets_c[g] = g * m * n - problem_sizes[g][0] = m - problem_sizes[g][1] = n - problem_sizes[g][2] = k - - a = to_fp8(torch.randn((tot_a, k), device=device)) - - b_float = torch.randn((tot_b, k), device=device) - # for g in range(num_groups): - # b_float[g * k:(g + 1) * k] = torch.full((k, n), g + 1) - # print(b_float) - - b = to_fp8(b_float.t()) - c = torch.zeros((tot_c, n), device=device).to(out_dtype) - baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) - - # print(a) - # print(b) - - # print(offsets_a) - # print(offsets_b) - # print(offsets_c) - # print(tot_a, tot_b, tot_c) - - # print(a.stride(), b.stride(), c.stride()) - - scale_a = (torch.randn(((m, 1) if per_act_token else (1, 1)), - device=device, - dtype=torch.float32)) - scale_b = (torch.randn(((1, n) if per_out_ch else (1, 1)), - device=device, - dtype=torch.float32)) - - # if use_bias: - # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 - # else: - # bias = None - - # print(a) - - # TODO strides we can get later the same way as in scaled_mm_c3x.cu - torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, - offsets_c, offsets_a, offsets_b) - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) - - # print(a.dtype) - # print(a) - - # torch.set_printoptions(profile='full') - # # print(c[2*m:3*m]) - # print(torch.max(c, dim=1)) - # print(torch.max(c, dim=0)) - # print(c) + # Create separate A, B, C tensors for each group + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + out_tensors = [] + baseline_tensors = [] + alignment = 16 # 128 // 8 + # For variation, each group g has dimensions + # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) for g in range(num_groups): - print(a[g * m:(g + 1) * m].shape, b[:, g * n:(g + 1) * n].shape) - baseline[g * m:(g + 1) * m] = baseline_scaled_mm( - a[g * m:(g + 1) * m], - b[:, g * n:(g + 1) * n], - # scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], - # # scale_b[:, g * n:(g + 1) * n] if per_out_ch else scale_b[:, g], - # scale_b[g], - scale_a, - scale_b, - out_dtype, - None) - print(baseline[g * m:(g + 1) * m]) - print(c[g * m:(g + 1) * m]) + m_g = alignment * random.randint(1, 64) + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + + m_a_scales = m_g if per_act_token else 1 + n_b_scales = n_g if per_out_ch else 1 + + print(m_g, n_g, k_g) + + # Create group-specific A and B (FP8) and output (FP16/FP32) + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) + # Set up A/B scales + scale_a = torch.randn((m_a_scales, 1), + device=device, + dtype=torch.float32) + scale_b = torch.randn((1, n_b_scales), + device=device, + dtype=torch.float32) + + a_tensors.append(a_g) + b_tensors.append(b_g) + out_tensors.append(c_g) + a_scales_tensors.append(scale_a) + b_scales_tensors.append(scale_b) + + # Compute baseline result for this group + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, + None) + baseline_tensors.append(baseline_g) + + torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors, b_tensors, + a_scales_tensors, b_scales_tensors) + + # Validate each group's result against the baseline + for c_g, baseline_g in zip(out_tensors, baseline_tensors): + print(baseline_g) + print(c_g) print("*") - - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) - # print(baseline) - # print(c) - - torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) - - # opcheck(torch.ops._C.cutlass_scaled_mm, - # (out, a, b, scale_a, scale_b, bias)) + torch.testing.assert_close(c_g, baseline_g, rtol=1e-2, atol=5e-2) From e2b1fc05479311f3efac5daf7d34af0b9626ee3c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 17 Dec 2024 16:53:53 +0000 Subject: [PATCH 06/58] Small cleanup Signed-off-by: ElizaWszola --- CMakeLists.txt | 2 +- .../broadcast_load_epilogue_array_c3x.hpp | 7 --- ...grouped_gemm_test.cu => grouped_mm_c3x.cu} | 45 ++----------------- tests/kernels/test_cutlass.py | 11 ++--- 4 files changed, 10 insertions(+), 55 deletions(-) rename csrc/quantization/cutlass_w8a8/{grouped_gemm_test.cu => grouped_mm_c3x.cu} (90%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d6185e75633..c19812ab5491 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -264,7 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu") + "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index e652179718c9..5c1d6e3f46be 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -241,10 +241,6 @@ struct Sm90RowOrScalarBroadcastArray { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - if (threadIdx.x ==128){ - printf("ROW M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - } - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem), @@ -435,9 +431,6 @@ struct Sm90ColOrScalarBroadcastArray { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - // if (threadIdx.x ==128){ - // printf("COL M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu similarity index 90% rename from csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu rename to csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index c9d299c11130..b08d67d04664 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -5,7 +5,7 @@ #include "cutlass/cutlass.h" -// TODO let's see which of these we'll need +// TODO clean up the includes we no longer need #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" @@ -26,10 +26,6 @@ #include "common.hpp" -// get rid of these? -// #include "helper.h" -// using namespace cute; - using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 @@ -129,20 +125,6 @@ cutlass::platform::unique_ptr> make_device_ptr( return cutlass::platform::unique_ptr>(data_device); } -/////////////// -template -void print(const TupType& _tup, std::index_sequence) { - std::cout << "("; - (..., (std::cout << (I == 0 ? "" : ", ") << std::get(_tup))); - std::cout << ")\n"; -} - -template -void print(const std::tuple& _tup) { - print(_tup, std::make_index_sequence()); -} -//////////// - template void cutlass_group_gemm_caller(c10::List const& out_tensors, c10::List const& a_tensors, @@ -242,6 +224,8 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, typename GemmKernel::MainloopArguments mainloop_args{ a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; + // Currently, we are only able to do broadcast on either all or none a_scales + // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), @@ -255,8 +239,6 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; - // std::cout << "gemm_op.can_implement(args): " - // << (int)gemm_op.can_implement(args) << std::endl; CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); @@ -336,27 +318,6 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, using Cutlass3xGemmDefault = typename sm90_fp8_config_default< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - // using Cutlass3xGemmM64 = - // typename sm90_fp8_config_M64::Cutlass3xGemm; - // using Cutlass3xGemmM128 = - // typename sm90_fp8_config_M128::Cutlass3xGemm; - - // // uint32_t const m = a.size(0); - // uint32_t const mp2 = - // std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - // if (mp2 <= 64) { - // // m in [1, 64] - // cutlass_group_gemm_caller(out, a, b, a_scales, - // b_scales); - // } else if (mp2 <= 128) { - // // m in (64, 128] - // cutlass_group_gemm_caller(out, a, b, a_scales, - // b_scales); - // } else { - // // m in (128, inf) cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales); } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 4c909669aa5d..445a06f57a96 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -457,10 +457,11 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) +# TODO add bias @pytest.mark.parametrize("num_groups", [8]) -@pytest.mark.parametrize("per_act_token", [True, False]) # [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) # [True, False]) -@pytest.mark.parametrize("use_bias", [False]) # [True, False]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("use_bias", [False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, @@ -479,9 +480,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, baseline_tensors = [] alignment = 16 # 128 // 8 - # For variation, each group g has dimensions + # For variation, each group has dimensions # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) - for g in range(num_groups): + for _ in range(num_groups): m_g = alignment * random.randint(1, 64) n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) From acfd3ef49ca32c9a70d22d9e592671b51c0212a3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 13 Jan 2025 15:36:33 +0000 Subject: [PATCH 07/58] Benchmark grouped cutlass against bfloat16 torch.mm Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 199 ++++++++++++++++++ benchmarks/kernels/benchmark_shapes.py | 21 ++ vllm/_custom_ops.py | 7 + 3 files changed, 227 insertions(+) create mode 100644 benchmarks/kernels/benchmark_grouped_gemm_cutlass.py diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py new file mode 100644 index 000000000000..be401cec03c6 --- /dev/null +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -0,0 +1,199 @@ +from typing import List, Tuple + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES_MOE + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] + # "nm-testing/deepseekv2-lite", + # "ibm-granite/granite-3.0-1b-a400m", + # "ibm-granite/granite-3.0-3b-a800m"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +NUM_GROUPS_OPTS = [8] +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] + +def grouped_gemm(a_g_tensors: List[torch.Tensor], + b_g_tensors: List[torch.Tensor], + out_g_tensors: List[torch.Tensor], + a_scales_tensors: List[torch.Tensor], + b_scales_tensors: List[torch.Tensor]): + ops.cutlass_grouped_mm(out_g_tensors, a_g_tensors, b_g_tensors, + a_scales_tensors, b_scales_tensors) + +def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], + b_tensors: List[torch.Tensor], + out_tensors: List[torch.Tensor]): + for g in range(num_groups): + a = a_tensors[g] + b = b_tensors[g] + out = torch.mm(a, b) + out_tensors[g] = out + +def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, + per_act_token: bool, per_out_ch: bool, + mkn: List[Tuple[int, int, int]]): + label = "Quant Matmul" + + sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_groups, per_act_token, + per_out_ch, mkn)) + + print(f"Testing: {sub_label}") + + device = "cuda" + out_dtype = torch.half + + def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + a_tensors = [] + b_tensors = [] + a_g_tensors = [] + b_g_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + out_tensors = [] + out_g_tensors = [] + baseline_tensors = [] + + for g in range(num_groups): + m_g = mkn[g][0] + k_g = mkn[g][1] + n_g = mkn[g][2] + + m_a_scales = m_g if per_act_token else 1 + n_b_scales = n_g if per_out_ch else 1 + + a = torch.randn((m_g, k_g), device=device) + b = torch.randn((n_g, k_g), device=device).t() + c = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) + + a_g = to_fp8(a) + b_g = to_fp8(b) + c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) + + scale_a = (torch.randn((m_a_scales, 1), device=device, + dtype=torch.float32)) + scale_b = (torch.randn((1, n_b_scales), device=device, + dtype=torch.float32)) + + a_tensors.append(a.to(dtype=torch.bfloat16)) + b_tensors.append(b.to(dtype=torch.bfloat16)) + out_tensors.append(c) + a_g_tensors.append(a_g) + b_g_tensors.append(b_g) + out_g_tensors.append(c_g) + baseline_tensors.append(c_g) + a_scales_tensors.append(scale_a) + b_scales_tensors.append(scale_b) + + globals = { + # Gen params + "a_tensors": a_tensors, + "b_tensors": b_tensors, + "a_g_tensors": a_g_tensors, + "b_g_tensors": b_g_tensors, + "out_g_tensors": out_g_tensors, + "out_tensors": out_tensors, + "baseline_tensors": baseline_tensors, + "a_scales_tensors": a_scales_tensors, + "b_scales_tensors": b_scales_tensors, + "num_groups": num_groups, + # Kernels + "grouped_gemm": grouped_gemm, + "baseline_gemm": baseline_gemm, + } + + min_run_time = 1 + num_warmup = 5 + + # Warmup pytorch + for _ in range(num_warmup): + grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, + b_scales_tensors) + + results.append( + benchmark.Timer( + stmt="grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, b_scales_tensors)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup pytorch + for _ in range(num_warmup): + baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors) + + results.append( + benchmark.Timer( + stmt= + "output = baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="baseline_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: List[benchmark.Measurement] = [] + + for model in args.models: + for layer in WEIGHT_SHAPES_MOE[model]: + num_groups = layer[0] + size_k = layer[1] + size_n = layer[2] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = [(size_m, size_k, size_n)] * num_groups + bench_run(results, model, num_groups, per_act_token, + per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 ... +# +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 4eeeca35a37c..9550236aa671 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -73,3 +73,24 @@ [7168, 8192], ], } + +WEIGHT_SHAPES_MOE = { + "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ + [8, 4096, 28672], + [8, 14336, 4096], + ], + "nm-testing/deepseekv2-lite": [ + [64, 2048, 352], + [64, 1408, 256], + [64, 128, 5632], + [64, 88, 4096], + ], + "ibm-granite/granite-3.0-1b-a400m": [ + [32, 1024, 2048], + [32, 1024, 1024], + ], + "ibm-granite/granite-3.0-3b-a800m": [ + [40, 1536, 2048], + [40, 1024, 1536], + ], +} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index afb350591e56..7703ec0d966e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -490,6 +490,13 @@ def cutlass_scaled_mm(a: torch.Tensor, return out +def cutlass_grouped_mm(out: List[torch.Tensor], a: List[torch.Tensor], + b: List[torch.Tensor], scale_a: List[torch.Tensor], + scale_b: List[torch.Tensor]) -> torch.Tensor: + torch.ops._C.cutlass_grouped_mm(out, a, b, scale_a, scale_b) + return out + + def cutlass_scaled_mm_azp(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, From f1a56669f59cc1bf138523c1247dde9db7f28e56 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 17 Jan 2025 16:27:58 +0000 Subject: [PATCH 08/58] Start working on fused moe cutlass implementation Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 196 +++++++++++++++++- csrc/cpu/torch_bindings.cpp | 8 + .../epilogue/broadcast_load_epilogue_c3x.hpp | 6 +- csrc/ops.h | 6 + .../cutlass_w8a8/grouped_mm_c3x.cu | 47 +++++ .../cutlass_w8a8/scaled_mm_entry.cu | 16 ++ csrc/torch_bindings.cpp | 7 + tests/kernels/test_cutlass_moe.py | 145 +++++++++++++ 8 files changed, 422 insertions(+), 9 deletions(-) create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index be401cec03c6..67923262a585 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe, fused_topk, fused_experts) DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] # "nm-testing/deepseekv2-lite", @@ -17,6 +19,11 @@ PER_ACT_TOKEN_OPTS = [False, True] PER_OUT_CH_OPTS = [False, True] +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + def grouped_gemm(a_g_tensors: List[torch.Tensor], b_g_tensors: List[torch.Tensor], out_g_tensors: List[torch.Tensor], @@ -33,6 +40,30 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], b = b_tensors[g] out = torch.mm(a, b) out_tensors[g] = out + +def cutlass_fused(a_tensors: List[torch.Tensor], + w1_tensors: List[torch.Tensor], + w2_tensors: List[torch.Tensor], + c1_tensors: List[torch.Tensor], + c2_tensors: List[torch.Tensor], + c2_tensors_fp8: List[torch.Tensor], + c3_tensors: List[torch.Tensor], + a_scales: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + c2_scales: List[torch.Tensor], + num_groups: int): + # output_dtype = c3_tensors[0].dtype + N = c2_tensors[0].shape[1] + ops.cutlass_grouped_mm(c1_tensors, a_tensors, w1_tensors, + a_scales, w1_scales) + # TODO make this work as it should + for idx in range(num_groups): + torch.ops._C.silu_and_mul(c2_tensors[idx], c1_tensors[idx].view(-1, N)) + print(c2_tensors[idx]) + c2_tensors_fp8[idx] = to_fp8(c2_tensors[idx].half()) + ops.cutlass_grouped_mm(c3_tensors, c2_tensors, w2_tensors, + c2_scales, w2_scales) def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, per_act_token: bool, per_out_ch: bool, @@ -47,11 +78,6 @@ def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, device = "cuda" out_dtype = torch.half - - def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) a_tensors = [] b_tensors = [] @@ -142,6 +168,164 @@ def to_fp8(tensor: torch.Tensor): description="baseline_gemm", ).blocked_autorange(min_run_time=min_run_time)) +def bench_run_moe(results: List[benchmark.Measurement], model: str, num_groups: int, + per_act_token: bool, per_out_ch: bool, + mkn: List[Tuple[int, int, int]]): + label = "Quant Matmul" + + sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_groups, per_act_token, + per_out_ch, mkn)) + + print(f"Testing: {sub_label}") + + device = "cuda" + out_dtype = torch.bfloat16 + + def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + m_tot = sum([elem[0] for elem in mkn]) + k_g = mkn[0][1] + n_g = mkn[0][2] + + a_tensors = [] + w1_tensors = [] + w2_tensors = [] + c1_tensors = [] + c2_tensors = [] + c2_tensors_fp8 = [] + c3_tensors = [] + a_scales = [] + w1_scales = [] + w2_scales = [] + c2_scales = [] + + a = torch.randn((m_tot, k_g), device=device, dtype=out_dtype) + w1 = torch.randn((num_groups, 2 * n_g, k_g), device=device, dtype=out_dtype) + w2 = torch.randn((num_groups, k_g, n_g), device=device, dtype=out_dtype) + scored_output = torch.randn((m_tot, num_groups), device="cuda", dtype=out_dtype) + topk = 2 + # triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + + #TODO grouped topk for deepseek + topk_weights, topk_ids = fused_topk(a, scored_output, topk, renormalize=True) + fused_experts(a, w1, w2, topk_weights, topk_ids) + topk_ids_cpu = topk_ids.cpu() + + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids_cpu.flatten(): + occurrences[id] += 1 + + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + print(expert_offsets, m_tot) + + a = torch.randn((m_tot, k_g)) + a_group[0] = a[sorted_token_ids[0]] + + # TODO + # create full input tensor m_tot x k_g x topk + # get shuffle data like sorted_token_ids etc. + # create view + + for g in range(num_groups): + m_g = occurrences[g] + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + w1_g = to_fp8(torch.randn((2 * n_g, k_g), device=device).t()) + w2_g = to_fp8(torch.randn((k_g, n_g), device=device).t()) + c1_g = torch.zeros((m_g, 2 * n_g), device=device, dtype=torch.bfloat16) + c2_g = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) + c2_g_fp8 = to_fp8(torch.zeros((m_g, n_g), device=device)) + c3_g = torch.zeros((m_g, k_g), device=device, dtype=torch.bfloat16) + # m_a_scales = m_g if per_act_token else 1 + # n_b_scales = n_g if per_out_ch else 1 + m_scales = 1 + n2_scales = 1 + k_scales = 1 + scale_a = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + scale_w1 = (torch.randn((n2_scales, 1), device=device, + dtype=torch.float32)) + scale_w2 = (torch.randn((k_scales, 1), device=device, + dtype=torch.float32)) + scale_c2 = (torch.randn((m_scales, 1), device=device, + dtype=torch.float32)) + + a_tensors.append(a_g) + w1_tensors.append(w1_g) + w2_tensors.append(w2_g) + c1_tensors.append(c1_g) + c2_tensors.append(c2_g) + c2_tensors_fp8.append(c2_g_fp8) + c3_tensors.append(c3_g) + a_scales.append(scale_a) + w1_scales.append(scale_w1) + w2_scales.append(scale_w2) + c2_scales.append(scale_c2) + + globals = { + # Gen params + "num_groups": num_groups, + # Grouped gemm params + "a_tensors": a_tensors, + "w1_tensors": w1_tensors, + "w2_tensors": w2_tensors, + "c1_tensors": c1_tensors, + "c2_tensors": c2_tensors, + "c2_tensors_fp8": c2_tensors_fp8, + "c3_tensors": c3_tensors, + "a_scales": a_scales, + "w1_scales": w1_scales, + "w2_scales": w2_scales, + "c2_scales": c2_scales, + # Triton params (fused_moe) + "a": a, + "w1": w1, + "w2": w2, + "scored_output": scored_output, + "topk": topk, + # Kernels + "fused_moe": fused_moe, + "cutlass_fused": cutlass_fused, + } + + min_run_time = 1 + num_warmup = 5 + + # Warmup triton + for _ in range(num_warmup): + fused_moe(a, w1, w2, scored_output, topk, renormalize=False) + + results.append( + benchmark.Timer( + stmt="fused_moe(a, w1, w2, scored_output, topk, renormalize=False)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup cutlass + for _ in range(num_warmup): + cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, + c2_tensors_fp8, c3_tensors, a_scales, w1_scales, + w2_scales, c2_scales, num_groups) + + results.append( + benchmark.Timer( + stmt= + "cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, c2_tensors_fp8, c3_tensors, a_scales, w1_scales, w2_scales, c2_scales, num_groups)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="baseline_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): @@ -165,7 +349,7 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = [(size_m, size_k, size_n)] * num_groups - bench_run(results, model, num_groups, per_act_token, + bench_run_moe(results, model, num_groups, per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index a720348dee3e..96ddcab7cea2 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -125,6 +125,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor out_offsets, Tensor a_offsets, " " Tensor b_offsets) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. ops.def( diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 9f049efd07b4..ad33eec9ef8f 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -424,9 +424,9 @@ struct Sm90ColOrScalarBroadcast { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - if (threadIdx.x ==128){ - printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - } + // if (threadIdx.x ==128){ + // printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); + // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/ops.h b/csrc/ops.h index 736a40091f03..d7ec0e0f9128 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -161,6 +161,12 @@ void cutlass_grouped_mm(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index caa7edb888a3..835a144aed4b 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -321,3 +321,50 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales); } + +__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, + cutlass::float_e4m3_t* base_a_ptr, + const int* __restrict__ topk_ids, + int64_t* expert_offsets, + int topk_length) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int64_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + trg_a_ptrs[i] = base_a_ptr + tot_offset; + tot_offset += expert_offsets[i + 1]; + expert_offsets[i + 1] = tot_offset; + } + } +} + +// For a given "a" of size [M,K] performs a permutation of the M rows based +// on the given "perm" indices. +__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, + int const* __restrict__ perm_int_ptr, + cutlass::float_e4m3_t* __restrict__ out_ptr, + int size_m, int size_k, int block_rows) { + // TODO +} + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(), + (cutlass::float_e4m3_t*)a.data_ptr(), + (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), + topk_ids.numel()); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index e60b64d7797b..d9d2a91d0659 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -36,6 +36,13 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, c10::List const& a_scales, c10::List const& b_scales); + +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); + #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -159,6 +166,15 @@ void cutlass_grouped_mm(c10::List const& out_tensors, b_scales); } +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, + torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts) { + compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets, + num_experts); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b862144aa16f..65d48c7f1465 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -329,6 +329,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor[] b_scales) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + ops.def( + "compute_expert_offsets(Tensor! trg_a_ptrs," + " Tensor! a, Tensor topk_ids," + " Tensor! expert_offsets, SymInt num_experts) -> ()"); + ops.impl("compute_expert_offsets", torch::kCUDA, + &compute_expert_offsets); + // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability ops.def( diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 000000000000..083de75e1d34 --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,145 @@ +import pytest +import torch +from transformers import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from typing import List + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) +from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +NUM_EXPERTS = [8, 64] +TOP_KS = [2, 6] + +# TODO move to a better file later +# TODO handle scores +def cutlass_moe(a: torch.Tensor, + a_q: torch.Tensor, + a_scale: torch.Tensor, + w1_qs: List[torch.Tensor], + w2_qs: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +): + # TODO look at the code in benchmark_grouped_gemm_cutlass.py + # and get the relevant parts + # (also the fused_moe function) + + num_groups = len(w1_qs) + topk = topk_ids.shape[1] + num_tokens = topk_ids.shape[0] + + # TODO make this GPU only + occurrences = [0] * num_groups + expert_offsets = [0] * (num_groups + 1) + for id in topk_ids.cpu().flatten(): + occurrences[id] += 1 + for e in range(num_groups): + expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + + # TODO duplicate A rows topk times + # compute sorted_token_ids (argsort?) + # shuffle A according to this so each group input is contiguous + + # print(topk_ids) + # print(expert_offsets) + a_map = topk_ids.flatten().argsort() + rep_a_q = a_q.repeat_interleave(topk, dim=0) + + print(a_map) + print(rep_a_q) + + a_q_s = [] + for e in range(num_groups): + a_q_s.append(rep_a_q[a_map[expert_offsets[e]:expert_offsets[e+1]]]) + print(a_q_s) + return + # get a_map and expert_indices on device + + # TODO shuffle rep_a_q according to a_map + # get a_ptrs = a + expert_indices[:-1] + + a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") + expert_offsets = torch.empty((num_groups + 1), dtype=torch.int64, device="cuda") + # TODO might need to call it from inside cutlass code? + # help(ops) + + # print(a_ptrs) + # print(rep_a_q) + print(topk_ids) + # print(expert_offsets) + # print(num_groups) + torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), + expert_offsets, num_groups) + print(a_ptrs) + print(expert_offsets) + +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) +# @pytest.mark.parametrize("n", [128, 2048]) +# @pytest.mark.parametrize("k", [128, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [10]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +def test_cutlass_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + current_platform.seed_everything(7) + + dtype = torch.bfloat16 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_qs = [] + w2_qs = [] + w1_scales = [] + w2_scales = [] + + for expert in range(e): + w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) + w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) + w1_qs.append(w1_q) + w2_qs.append(w2_q) + w1_scales.append(w1_scale) + w2_scales.append(w2_scale) + + # (assume score is a vector of ones for now) + score = torch.ones((m, e), device="cuda", dtype=dtype) + + e_range = torch.full((m, e), 1.0 / e) + topk_ids = torch.multinomial(e_range, topk).int().sort()[0] + topk_weights = torch.rand((m, topk)) + + torch_output = torch_moe(a, w1, w2, score, topk) + cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, + w2_scales, topk_weights, topk_ids) + + # torch.testing.assert_close(torch_output, + # cutlass_output, + # atol=2e-2, + # rtol=0) From 6414e317bae3fb36a0871f5c68c6db517e974008 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 20 Jan 2025 23:49:35 +0000 Subject: [PATCH 09/58] Working halfway Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 73 +++++++++++++++---- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 835a144aed4b..4abb84e3e0bb 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -163,6 +163,8 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, b_scales_ptrs_host[g] = reinterpret_cast(b_scales[g].data_ptr()); + // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], + // c_ptrs_host[g], d_ptrs_host[g],) int64_t m = a_tensors[g].size(0); int64_t k = a_tensors[g].size(1); @@ -348,23 +350,68 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, } } -// For a given "a" of size [M,K] performs a permutation of the M rows based -// on the given "perm" indices. -__global__ void permute_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, - int const* __restrict__ perm_int_ptr, - cutlass::float_e4m3_t* __restrict__ out_ptr, - int size_m, int size_k, int block_rows) { - // TODO -} +// // For a given "a" of size [M,K] performs a permutation of the M rows based +// // on the given "perm" indices. +// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, +// int const* __restrict__ perm_int_ptr, +// cutlass::float_e4m3_t* __restrict__ out_ptr, +// int size_m, int size_k, int block_rows) { +// int start_row = block_rows * blockIdx.x; +// int finish_row = start_row + block_rows; +// if (finish_row > size_m) { +// finish_row = size_m; +// } +// int cur_block_rows = finish_row - start_row; + +// int row_stride = size_k * sizeof(cutlass::float_e4m3_t) / 16; + +// auto permute_row = [&](int row) { +// int iters = size_k / blockDim.x; +// int rest = size_k % blockDim.x; + +// int a_offset = perm_int_ptr[row] * row_stride; +// int out_offset = row * row_stride; + +// cutlass::float_e4m3_t const* a_row_fp8 = a_ptr + a_offset; +// cutlass::float_e4m3_t* out_fp8 = out_ptr + out_offset; + +// int base_k = 0; + +// for (int i = 0; i < iters; i++) { +// int cur_k = base_k + threadIdx.x; +// out_fp8[cur_k] = a_row_fp8[cur_k]; +// base_k += blockDim.x; +// } + +// if (rest) { +// if (threadIdx.x < rest) { +// int cur_k = base_k + threadIdx.x; +// out_fp8[cur_k] = a_row_fp8[cur_k]; +// } +// } +// }; +// } void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const int64_t num_experts) { - get_a_expert_offsets<<<1, num_experts>>>((float_e4m3_t**)trg_a_ptrs.data_ptr(), - (cutlass::float_e4m3_t*)a.data_ptr(), - (const int*)topk_ids.data_ptr(), - (int64_t*)expert_offsets.data_ptr(), - topk_ids.numel()); + get_a_expert_offsets<<<1, num_experts>>>( + (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(), + (cutlass::float_e4m3_t*)a.data_ptr(), + (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), + topk_ids.numel()); } + +// void permute_fp8_rows(torch::Tensor& a_ptr, +// torch::Tensor& perm_ptr, +// torch::Tensor& out_ptr, +// int size_m, int size_k, int topk, int block_rows) { +// permute_fp8_rows_kernel<<>>( +// (cutlass::float_e4m3_t const*)a_ptr.data_ptr(), +// (const int*)perm_ptr.data_ptr(), +// (cutlass::float_e4m3_t const*)out_ptr.data_ptr(), size_m * topk, +// size_k, block_rows); +// } From 67e2dd4494c659011374067cf5d1a16a99840388 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 23 Jan 2025 15:48:05 +0000 Subject: [PATCH 10/58] working mul test but the topk_weights are not yet included in kernel Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 181 ++++++++++++++++++++++-------- 1 file changed, 137 insertions(+), 44 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 083de75e1d34..95b975c6a70e 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -34,43 +34,34 @@ def cutlass_moe(a: torch.Tensor, w2_scales: List[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, + m: int, n: int, k: int, ): # TODO look at the code in benchmark_grouped_gemm_cutlass.py # and get the relevant parts # (also the fused_moe function) + # print(a.shape, a_scale.shape) + # print(w1_qs[0].shape, w1_scales[0].shape) + # print(w2_qs[0].shape, w2_scales[0].shape) + num_groups = len(w1_qs) topk = topk_ids.shape[1] num_tokens = topk_ids.shape[0] + # print("tk_cut:", topk_ids) # TODO make this GPU only - occurrences = [0] * num_groups - expert_offsets = [0] * (num_groups + 1) - for id in topk_ids.cpu().flatten(): - occurrences[id] += 1 - for e in range(num_groups): - expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] + # occurrences = [0] * num_groups + # expert_offsets = [0] * (num_groups + 1) + # for id in topk_ids.cpu().flatten(): + # occurrences[id] += 1 + # for e in range(num_groups): + # expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] # TODO duplicate A rows topk times # compute sorted_token_ids (argsort?) # shuffle A according to this so each group input is contiguous - # print(topk_ids) - # print(expert_offsets) - a_map = topk_ids.flatten().argsort() - rep_a_q = a_q.repeat_interleave(topk, dim=0) - - print(a_map) - print(rep_a_q) - - a_q_s = [] - for e in range(num_groups): - a_q_s.append(rep_a_q[a_map[expert_offsets[e]:expert_offsets[e+1]]]) - print(a_q_s) - return - # get a_map and expert_indices on device - - # TODO shuffle rep_a_q according to a_map + # TODO # get a_ptrs = a + expert_indices[:-1] a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") @@ -80,22 +71,103 @@ def cutlass_moe(a: torch.Tensor, # print(a_ptrs) # print(rep_a_q) - print(topk_ids) + # print(topk_ids) # print(expert_offsets) # print(num_groups) + + # print(topk_ids) + a_map = topk_ids.flatten().argsort() + rep_a_q = a_q.repeat_interleave(topk, dim=0) + torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), expert_offsets, num_groups) - print(a_ptrs) - print(expert_offsets) + # print(expert_offsets) + # print(a_ptrs) + # print(expert_offsets) + + # print("a_map:", a_map) + # print("rep_a_q:", rep_a_q) + + a_q_s = [] + a_scales_s = [] + c_s1 = [] + c_s2 = [] + for e in range(num_groups): + expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + cut_out = rep_a_q.view(dtype=torch.uint8)[expert_map].view( + dtype=a_q.dtype) + a_q_s.append(cut_out.clone()) + # print("CU:", expert_map, cut_out) + #TODO if we have 1 scale per token, we need to do a_scale[expert_map] + a_scales_s.append(a_scale.clone()) + c_s1.append(torch.zeros((cut_out.shape[0], n * 2), device="cuda", + dtype=torch.half)) + c_s2.append(torch.zeros((cut_out.shape[0], k), device="cuda", + dtype=torch.half)) + # print("a_q_s:", a_q_s[0].shape) + # print("a_scales_s:", a_scales_s[0].shape) + # print("cs:", c_s[0].shape) + # print("w1_qs:", w1_qs[0].shape) + # print("w1_scales", w1_scales[0].shape) + + # print("a_q_s:", a_q_s) + # print("a_scales_s:", a_scales_s) + # print(w1_qs) + # print(w1_scales) + torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, + a_scales_s, w1_scales) + # c_s1 = [c.reshape((-1, n)) for c in c_s1] + # print([w.stride() for w in w1_qs]) + + # print(c_s1) + + ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL + # intermediate1 = torch.empty((m * topk, n * 2), device="cuda", dtype=torch.half) + # for e in range(num_groups): + # expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + # intermediate1[expert_map] = c_s1[e] + # return intermediate1.reshape(m, topk, n * 2).sum(dim=1) + ### + + # # print(out) + # intermediate2 = torch.empty((m * topk, n), device="cuda", dtype=torch.half) + # torch.ops._C.silu_and_mul(intermediate2, intermediate1) + + intermediate2 = [] + intermediate2_scales = [] + for e in range(num_groups): + inter2 = torch.empty((c_s1[e].shape[0], n), device="cuda", dtype=torch.half) + torch.ops._C.silu_and_mul(inter2, c_s1[e]) + inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) + # print("cutlass:", inter2) + intermediate2.append(inter2_v) + intermediate2_scales.append(inter2_s.reshape((1, 1))) + + # print(m, k, n, a_q_s[0].shape, w2_qs[0].shape, "->", intermediate2[0].shape) + # print(m, k, n, intermediate2[0].shape, w2_qs[0].shape, intermediate2_scales[0].shape, w2_scales[0].shape) + torch.ops._C.cutlass_grouped_mm(c_s2, intermediate2, w2_qs, + intermediate2_scales, w2_scales) + # print("cutlass:", c_s2) + intermediate3 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) + for e in range(num_groups): + expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + intermediate3[expert_map] = c_s2[e] + + # print("cutlass:", intermediate3.view(m, topk, k)) + # print("cutlass:", topk_weights.view(m, topk, 1).half()) + out = (intermediate3.reshape(m, topk, k) * + topk_weights.view(m, topk, 1).half()).sum(dim=1) + # return intermediate3.reshape(m, topk, k).sum(dim=1) + return out # @pytest.mark.parametrize("m", [1, 33, 64, 222]) # @pytest.mark.parametrize("n", [128, 2048]) # @pytest.mark.parametrize("k", [128, 1024]) # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("m", [10]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("m", [16]) +@pytest.mark.parametrize("n", [16]) +@pytest.mark.parametrize("k", [16]) @pytest.mark.parametrize("e", [8]) @pytest.mark.parametrize("topk", [2]) def test_cutlass_moe( @@ -107,7 +179,7 @@ def test_cutlass_moe( ): current_platform.seed_everything(7) - dtype = torch.bfloat16 + dtype = torch.half a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -115,6 +187,10 @@ def test_cutlass_moe( a_q, a_scale = ops.scaled_fp8_quant(a) + # print(a) + # print(a_q) + # print(a_scale) + w1_qs = [] w2_qs = [] w1_scales = [] @@ -123,23 +199,40 @@ def test_cutlass_moe( for expert in range(e): w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) - w1_qs.append(w1_q) - w2_qs.append(w2_q) - w1_scales.append(w1_scale) - w2_scales.append(w2_scale) + w1_qs.append(w1_q.t()) + w2_qs.append(w2_q.t()) + w1_scales.append(w1_scale.reshape((1, 1))) + w2_scales.append(w2_scale.reshape((1, 1))) # (assume score is a vector of ones for now) - score = torch.ones((m, e), device="cuda", dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) - e_range = torch.full((m, e), 1.0 / e) - topk_ids = torch.multinomial(e_range, topk).int().sort()[0] - topk_weights = torch.rand((m, topk)) + # e_range = torch.full((m, e), 1.0 / e) + # topk_ids = torch.multinomial(e_range, topk).int().sort()[0] + # topk_weights = torch.rand((m, topk)) - torch_output = torch_moe(a, w1, w2, score, topk) - cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, - w2_scales, topk_weights, topk_ids) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - # torch.testing.assert_close(torch_output, - # cutlass_output, - # atol=2e-2, - # rtol=0) + # torch_output = torch_moe(a, w1, w2, score, topk) + a_d = (a_q.float() * a_scale).half() + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_qs[expert].t().float() * w1_scales[expert]).half() + w2_d[expert] = (w2_qs[expert].t().float() * w2_scales[expert]).half() + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) + # torch_output = torch_moe_single(a_d, w1_d, score, topk) + cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, + w2_scales, topk_weights, topk_ids, + m, n, k) + + # print(torch_output.shape) + # print(cutlass_output.shape) + print(torch_output) + print(cutlass_output) + print(torch_output / cutlass_output) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) From 65235297b86a7ef90ccbbdd4a77b74b994a35b9a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 23 Jan 2025 18:28:11 +0000 Subject: [PATCH 11/58] cleaned up cutlass moe test, fixes Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 292 ++++++++++++------------------ 1 file changed, 112 insertions(+), 180 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 95b975c6a70e..94fd4eba4456 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,175 +1,115 @@ import pytest import torch -from transformers import MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from typing import List -import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) +from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) -from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] + # TODO move to a better file later # TODO handle scores -def cutlass_moe(a: torch.Tensor, - a_q: torch.Tensor, - a_scale: torch.Tensor, - w1_qs: List[torch.Tensor], - w2_qs: List[torch.Tensor], - w1_scales: List[torch.Tensor], - w2_scales: List[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - m: int, n: int, k: int, +def cutlass_moe( + a_q: torch.Tensor, + a_scale: torch.Tensor, + w1_qs: List[torch.Tensor], + w2_qs: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, ): - # TODO look at the code in benchmark_grouped_gemm_cutlass.py - # and get the relevant parts - # (also the fused_moe function) - - # print(a.shape, a_scale.shape) - # print(w1_qs[0].shape, w1_scales[0].shape) - # print(w2_qs[0].shape, w2_scales[0].shape) - num_groups = len(w1_qs) topk = topk_ids.shape[1] - num_tokens = topk_ids.shape[0] - # print("tk_cut:", topk_ids) - - # TODO make this GPU only - # occurrences = [0] * num_groups - # expert_offsets = [0] * (num_groups + 1) - # for id in topk_ids.cpu().flatten(): - # occurrences[id] += 1 - # for e in range(num_groups): - # expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] - - # TODO duplicate A rows topk times - # compute sorted_token_ids (argsort?) - # shuffle A according to this so each group input is contiguous - - # TODO - # get a_ptrs = a + expert_indices[:-1] a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") - expert_offsets = torch.empty((num_groups + 1), dtype=torch.int64, device="cuda") - # TODO might need to call it from inside cutlass code? - # help(ops) + expert_offsets = torch.empty((num_groups + 1), + dtype=torch.int64, + device="cuda") - # print(a_ptrs) - # print(rep_a_q) - # print(topk_ids) - # print(expert_offsets) - # print(num_groups) - - # print(topk_ids) a_map = topk_ids.flatten().argsort() rep_a_q = a_q.repeat_interleave(topk, dim=0) torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), expert_offsets, num_groups) - # print(expert_offsets) - # print(a_ptrs) - # print(expert_offsets) - - # print("a_map:", a_map) - # print("rep_a_q:", rep_a_q) a_q_s = [] a_scales_s = [] c_s1 = [] c_s2 = [] for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] cut_out = rep_a_q.view(dtype=torch.uint8)[expert_map].view( dtype=a_q.dtype) a_q_s.append(cut_out.clone()) - # print("CU:", expert_map, cut_out) - #TODO if we have 1 scale per token, we need to do a_scale[expert_map] a_scales_s.append(a_scale.clone()) - c_s1.append(torch.zeros((cut_out.shape[0], n * 2), device="cuda", - dtype=torch.half)) - c_s2.append(torch.zeros((cut_out.shape[0], k), device="cuda", - dtype=torch.half)) - # print("a_q_s:", a_q_s[0].shape) - # print("a_scales_s:", a_scales_s[0].shape) - # print("cs:", c_s[0].shape) - # print("w1_qs:", w1_qs[0].shape) - # print("w1_scales", w1_scales[0].shape) - - # print("a_q_s:", a_q_s) - # print("a_scales_s:", a_scales_s) - # print(w1_qs) - # print(w1_scales) - torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, - a_scales_s, w1_scales) - # c_s1 = [c.reshape((-1, n)) for c in c_s1] - # print([w.stride() for w in w1_qs]) - - # print(c_s1) - - ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL - # intermediate1 = torch.empty((m * topk, n * 2), device="cuda", dtype=torch.half) + c_s1.append( + torch.zeros((cut_out.shape[0], n * 2), + device="cuda", + dtype=torch.half)) + c_s2.append( + torch.zeros((cut_out.shape[0], k), device="cuda", + dtype=torch.half)) + + torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, a_scales_s, w1_scales) + + # ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL + # intermediate1 = torch.empty((m * topk, n * 2), + # device="cuda", + # dtype=torch.half) # for e in range(num_groups): # expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] # intermediate1[expert_map] = c_s1[e] # return intermediate1.reshape(m, topk, n * 2).sum(dim=1) - ### + # ### - # # print(out) - # intermediate2 = torch.empty((m * topk, n), device="cuda", dtype=torch.half) - # torch.ops._C.silu_and_mul(intermediate2, intermediate1) + full_groups = [] intermediate2 = [] intermediate2_scales = [] for e in range(num_groups): - inter2 = torch.empty((c_s1[e].shape[0], n), device="cuda", dtype=torch.half) - torch.ops._C.silu_and_mul(inter2, c_s1[e]) - inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) - # print("cutlass:", inter2) - intermediate2.append(inter2_v) - intermediate2_scales.append(inter2_s.reshape((1, 1))) - - # print(m, k, n, a_q_s[0].shape, w2_qs[0].shape, "->", intermediate2[0].shape) - # print(m, k, n, intermediate2[0].shape, w2_qs[0].shape, intermediate2_scales[0].shape, w2_scales[0].shape) - torch.ops._C.cutlass_grouped_mm(c_s2, intermediate2, w2_qs, - intermediate2_scales, w2_scales) - # print("cutlass:", c_s2) + if c_s1[e].shape[0] != 0: + full_groups.append(e) + inter2 = torch.empty((c_s1[e].shape[0], n), + device="cuda", + dtype=torch.half) + torch.ops._C.silu_and_mul(inter2, c_s1[e]) + inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) + intermediate2.append(inter2_v) + intermediate2_scales.append(inter2_s.reshape((1, 1))) + + def filter_list(items: List, idxs: List): + return [items[idx] for idx in idxs] + + torch.ops._C.cutlass_grouped_mm(filter_list(c_s2, + full_groups), intermediate2, + filter_list(w2_qs, full_groups), + intermediate2_scales, + filter_list(w2_scales, full_groups)) intermediate3 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] intermediate3[expert_map] = c_s2[e] - - # print("cutlass:", intermediate3.view(m, topk, k)) - # print("cutlass:", topk_weights.view(m, topk, 1).half()) + out = (intermediate3.reshape(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) - # return intermediate3.reshape(m, topk, k).sum(dim=1) return out -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) -# @pytest.mark.parametrize("n", [128, 2048]) -# @pytest.mark.parametrize("k", [128, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("m", [16]) -@pytest.mark.parametrize("n", [16]) -@pytest.mark.parametrize("k", [16]) -@pytest.mark.parametrize("e", [8]) -@pytest.mark.parametrize("topk", [2]) + +@pytest.mark.parametrize("m", [16, 32, 64, 224]) +@pytest.mark.parametrize("n", [128, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) def test_cutlass_moe( m: int, n: int, @@ -178,61 +118,53 @@ def test_cutlass_moe( topk: int, ): current_platform.seed_everything(7) - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - a_q, a_scale = ops.scaled_fp8_quant(a) - - # print(a) - # print(a_q) - # print(a_scale) - - w1_qs = [] - w2_qs = [] - w1_scales = [] - w2_scales = [] - - for expert in range(e): - w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) - w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) - w1_qs.append(w1_q.t()) - w2_qs.append(w2_q.t()) - w1_scales.append(w1_scale.reshape((1, 1))) - w2_scales.append(w2_scale.reshape((1, 1))) - - # (assume score is a vector of ones for now) - score = torch.randn((m, e), device="cuda", dtype=dtype) - - # e_range = torch.full((m, e), 1.0 / e) - # topk_ids = torch.multinomial(e_range, topk).int().sort()[0] - # topk_weights = torch.rand((m, topk)) - - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - # torch_output = torch_moe(a, w1, w2, score, topk) - a_d = (a_q.float() * a_scale).half() - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_qs[expert].t().float() * w1_scales[expert]).half() - w2_d[expert] = (w2_qs[expert].t().float() * w2_scales[expert]).half() - torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) - # torch_output = torch_moe_single(a_d, w1_d, score, topk) - cutlass_output = cutlass_moe(a, a_q, a_scale, w1_qs, w2_qs, w1_scales, - w2_scales, topk_weights, topk_ids, - m, n, k) - - # print(torch_output.shape) - # print(cutlass_output.shape) - print(torch_output) - print(cutlass_output) - print(torch_output / cutlass_output) - - torch.testing.assert_close(torch_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_qs = [] + w2_qs = [] + w1_scales = [] + w2_scales = [] + + for expert in range(e): + w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) + w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) + w1_qs.append(w1_q.t()) + w2_qs.append(w2_q.t()) + w1_scales.append(w1_scale.reshape((1, 1))) + w2_scales.append(w2_scale.reshape((1, 1))) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + a_d = (a_q.float() * a_scale).half() + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_qs[expert].t().float() * + w1_scales[expert]).half() + w2_d[expert] = (w2_qs[expert].t().float() * + w2_scales[expert]).half() + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) + cutlass_output = cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, + w2_scales, topk_weights, topk_ids, m, n, + k) + + # print(torch_output) + # print(cutlass_output) + # print(torch_output / cutlass_output) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) From b302d986c52d59df01f51cc5973f17f176654990 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 23 Jan 2025 22:11:58 +0000 Subject: [PATCH 12/58] benchmark fused Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 333 ++++-------------- benchmarks/kernels/benchmark_shapes.py | 7 +- tests/kernels/test_cutlass_moe.py | 100 +----- .../layers/fused_moe/fused_moe.py | 94 ++++- 4 files changed, 172 insertions(+), 362 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 67923262a585..f99ca4ab09ae 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,24 +6,28 @@ from vllm import _custom_ops as ops from vllm.utils import FlexibleArgumentParser -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, fused_topk, fused_experts) +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, + cutlass_moe, + fused_experts) -DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1"] - # "nm-testing/deepseekv2-lite", - # "ibm-granite/granite-3.0-1b-a400m", - # "ibm-granite/granite-3.0-3b-a800m"] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1", + "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", + "ibm-granite/granite-3.0-3b-a800m"] +DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512] + +NUM_GROUPS_OPTS = [8] #[8, 64] +PER_ACT_TOKEN_OPTS = [False] #[False, True] +PER_OUT_CH_OPTS = [False] #[False, True] +TOPKS = [2, 6] -NUM_GROUPS_OPTS = [8] -PER_ACT_TOKEN_OPTS = [False, True] -PER_OUT_CH_OPTS = [False, True] def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) return torch.round(tensor.clamp( min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + def grouped_gemm(a_g_tensors: List[torch.Tensor], b_g_tensors: List[torch.Tensor], out_g_tensors: List[torch.Tensor], @@ -32,6 +36,7 @@ def grouped_gemm(a_g_tensors: List[torch.Tensor], ops.cutlass_grouped_mm(out_g_tensors, a_g_tensors, b_g_tensors, a_scales_tensors, b_scales_tensors) + def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor], out_tensors: List[torch.Tensor]): @@ -41,289 +46,99 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], out = torch.mm(a, b) out_tensors[g] = out -def cutlass_fused(a_tensors: List[torch.Tensor], - w1_tensors: List[torch.Tensor], - w2_tensors: List[torch.Tensor], - c1_tensors: List[torch.Tensor], - c2_tensors: List[torch.Tensor], - c2_tensors_fp8: List[torch.Tensor], - c3_tensors: List[torch.Tensor], - a_scales: List[torch.Tensor], - w1_scales: List[torch.Tensor], - w2_scales: List[torch.Tensor], - c2_scales: List[torch.Tensor], - num_groups: int): - # output_dtype = c3_tensors[0].dtype - N = c2_tensors[0].shape[1] - ops.cutlass_grouped_mm(c1_tensors, a_tensors, w1_tensors, - a_scales, w1_scales) - # TODO make this work as it should - for idx in range(num_groups): - torch.ops._C.silu_and_mul(c2_tensors[idx], c1_tensors[idx].view(-1, N)) - print(c2_tensors[idx]) - c2_tensors_fp8[idx] = to_fp8(c2_tensors[idx].half()) - ops.cutlass_grouped_mm(c3_tensors, c2_tensors, w2_tensors, - c2_scales, w2_scales) - -def bench_run(results: List[benchmark.Measurement], model: str, num_groups: int, - per_act_token: bool, per_out_ch: bool, - mkn: List[Tuple[int, int, int]]): +# TODO marlin baseline +def bench_run(results: List[benchmark.Measurement], model: str, + num_experts: int, topk: int, per_act_token: bool, + per_out_ch: bool, mkn: Tuple[int, int, int]): label = "Quant Matmul" - sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_groups, per_act_token, - per_out_ch, mkn)) + sub_label = ("{}, num_experts={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, per_act_token, + per_out_ch, mkn)) print(f"Testing: {sub_label}") - device = "cuda" - out_dtype = torch.half - - a_tensors = [] - b_tensors = [] - a_g_tensors = [] - b_g_tensors = [] - a_scales_tensors = [] - b_scales_tensors = [] - out_tensors = [] - out_g_tensors = [] - baseline_tensors = [] - - for g in range(num_groups): - m_g = mkn[g][0] - k_g = mkn[g][1] - n_g = mkn[g][2] - - m_a_scales = m_g if per_act_token else 1 - n_b_scales = n_g if per_out_ch else 1 - - a = torch.randn((m_g, k_g), device=device) - b = torch.randn((n_g, k_g), device=device).t() - c = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) - - a_g = to_fp8(a) - b_g = to_fp8(b) - c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) - - scale_a = (torch.randn((m_a_scales, 1), device=device, - dtype=torch.float32)) - scale_b = (torch.randn((1, n_b_scales), device=device, - dtype=torch.float32)) - - a_tensors.append(a.to(dtype=torch.bfloat16)) - b_tensors.append(b.to(dtype=torch.bfloat16)) - out_tensors.append(c) - a_g_tensors.append(a_g) - b_g_tensors.append(b_g) - out_g_tensors.append(c_g) - baseline_tensors.append(c_g) - a_scales_tensors.append(scale_a) - b_scales_tensors.append(scale_b) - - globals = { - # Gen params - "a_tensors": a_tensors, - "b_tensors": b_tensors, - "a_g_tensors": a_g_tensors, - "b_g_tensors": b_g_tensors, - "out_g_tensors": out_g_tensors, - "out_tensors": out_tensors, - "baseline_tensors": baseline_tensors, - "a_scales_tensors": a_scales_tensors, - "b_scales_tensors": b_scales_tensors, - "num_groups": num_groups, - # Kernels - "grouped_gemm": grouped_gemm, - "baseline_gemm": baseline_gemm, - } - - min_run_time = 1 - num_warmup = 5 - - # Warmup pytorch - for _ in range(num_warmup): - grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, - b_scales_tensors) - - results.append( - benchmark.Timer( - stmt="grouped_gemm(a_g_tensors, b_g_tensors, out_g_tensors, a_scales_tensors, b_scales_tensors)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="grouped_gemm", - ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup pytorch - for _ in range(num_warmup): - baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors) - - results.append( - benchmark.Timer( - stmt= - "output = baseline_gemm(num_groups, a_tensors, b_tensors, out_tensors)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="baseline_gemm", - ).blocked_autorange(min_run_time=min_run_time)) - -def bench_run_moe(results: List[benchmark.Measurement], model: str, num_groups: int, - per_act_token: bool, per_out_ch: bool, - mkn: List[Tuple[int, int, int]]): - label = "Quant Matmul" - - sub_label = ("{}, num_groups={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_groups, per_act_token, - per_out_ch, mkn)) + (m, k, n) = mkn - print(f"Testing: {sub_label}") + dtype = torch.half - device = "cuda" - out_dtype = torch.bfloat16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 - def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + a_q, a_scale = ops.scaled_fp8_quant(a) - m_tot = sum([elem[0] for elem in mkn]) - k_g = mkn[0][1] - n_g = mkn[0][2] - - a_tensors = [] - w1_tensors = [] - w2_tensors = [] - c1_tensors = [] - c2_tensors = [] - c2_tensors_fp8 = [] - c3_tensors = [] - a_scales = [] + w1_qs = [] + w2_qs = [] w1_scales = [] w2_scales = [] - c2_scales = [] - - a = torch.randn((m_tot, k_g), device=device, dtype=out_dtype) - w1 = torch.randn((num_groups, 2 * n_g, k_g), device=device, dtype=out_dtype) - w2 = torch.randn((num_groups, k_g, n_g), device=device, dtype=out_dtype) - scored_output = torch.randn((m_tot, num_groups), device="cuda", dtype=out_dtype) - topk = 2 - # triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) - - #TODO grouped topk for deepseek - topk_weights, topk_ids = fused_topk(a, scored_output, topk, renormalize=True) - fused_experts(a, w1, w2, topk_weights, topk_ids) - topk_ids_cpu = topk_ids.cpu() - - occurrences = [0] * num_groups - expert_offsets = [0] * (num_groups + 1) - for id in topk_ids_cpu.flatten(): - occurrences[id] += 1 - - for e in range(num_groups): - expert_offsets[e + 1] = expert_offsets[e] + occurrences[e] - - print(expert_offsets, m_tot) - a = torch.randn((m_tot, k_g)) - a_group[0] = a[sorted_token_ids[0]] + for expert in range(num_experts): + w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) + w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) + w1_qs.append(w1_q.t()) + w2_qs.append(w2_q.t()) + w1_scales.append(w1_scale.reshape((1, 1))) + w2_scales.append(w2_scale.reshape((1, 1))) - # TODO - # create full input tensor m_tot x k_g x topk - # get shuffle data like sorted_token_ids etc. - # create view + score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - for g in range(num_groups): - m_g = occurrences[g] - a_g = to_fp8(torch.randn((m_g, k_g), device=device)) - w1_g = to_fp8(torch.randn((2 * n_g, k_g), device=device).t()) - w2_g = to_fp8(torch.randn((k_g, n_g), device=device).t()) - c1_g = torch.zeros((m_g, 2 * n_g), device=device, dtype=torch.bfloat16) - c2_g = torch.zeros((m_g, n_g), device=device, dtype=torch.bfloat16) - c2_g_fp8 = to_fp8(torch.zeros((m_g, n_g), device=device)) - c3_g = torch.zeros((m_g, k_g), device=device, dtype=torch.bfloat16) - # m_a_scales = m_g if per_act_token else 1 - # n_b_scales = n_g if per_out_ch else 1 - m_scales = 1 - n2_scales = 1 - k_scales = 1 - scale_a = (torch.randn((m_scales, 1), device=device, - dtype=torch.float32)) - scale_w1 = (torch.randn((n2_scales, 1), device=device, - dtype=torch.float32)) - scale_w2 = (torch.randn((k_scales, 1), device=device, - dtype=torch.float32)) - scale_c2 = (torch.randn((m_scales, 1), device=device, - dtype=torch.float32)) - - a_tensors.append(a_g) - w1_tensors.append(w1_g) - w2_tensors.append(w2_g) - c1_tensors.append(c1_g) - c2_tensors.append(c2_g) - c2_tensors_fp8.append(c2_g_fp8) - c3_tensors.append(c3_g) - a_scales.append(scale_a) - w1_scales.append(scale_w1) - w2_scales.append(scale_w2) - c2_scales.append(scale_c2) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) globals = { - # Gen params - "num_groups": num_groups, - # Grouped gemm params - "a_tensors": a_tensors, - "w1_tensors": w1_tensors, - "w2_tensors": w2_tensors, - "c1_tensors": c1_tensors, - "c2_tensors": c2_tensors, - "c2_tensors_fp8": c2_tensors_fp8, - "c3_tensors": c3_tensors, - "a_scales": a_scales, - "w1_scales": w1_scales, - "w2_scales": w2_scales, - "c2_scales": c2_scales, - # Triton params (fused_moe) + # Baseline params "a": a, "w1": w1, "w2": w2, - "scored_output": scored_output, + "score": score, "topk": topk, + # Cutlass params + "a_q": a_q, + "a_scale": a_scale, + "w1_qs": w1_qs, + "w2_qs": w2_qs, + "w1_scales": w1_scales, + "w2_scales": w2_scales, + "m": m, + "n": n, + "k": k, + # Gen params + "topk_weights": topk_weights, + "topk_ids": topk_ids, # Kernels - "fused_moe": fused_moe, - "cutlass_fused": cutlass_fused, + "fused_experts": fused_experts, + "cutlass_moe": cutlass_moe, } min_run_time = 1 num_warmup = 5 - # Warmup triton + # Warmup pytorch for _ in range(num_warmup): - fused_moe(a, w1, w2, scored_output, topk, renormalize=False) + fused_experts(a, w1, w2, topk_weights, topk_ids) results.append( benchmark.Timer( - stmt="fused_moe(a, w1, w2, scored_output, topk, renormalize=False)", + stmt="fused_experts(a, w1, w2, topk_weights, topk_ids)", globals=globals, label=label, sub_label=sub_label, - description="grouped_gemm", + description="baseline_gemm", ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup cutlass + + # Warmup pytorch for _ in range(num_warmup): - cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, - c2_tensors_fp8, c3_tensors, a_scales, w1_scales, - w2_scales, c2_scales, num_groups) + cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, + topk_weights, topk_ids, m, n, k) results.append( benchmark.Timer( stmt= - "cutlass_fused(a_tensors, w1_tensors, w2_tensors, c1_tensors, c2_tensors, c2_tensors_fp8, c3_tensors, a_scales, w1_scales, w2_scales, c2_scales, num_groups)", # noqa: E501 + "cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, topk_weights, topk_ids, m, n, k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="baseline_gemm", + description="grouped_gemm", ).blocked_autorange(min_run_time=min_run_time)) def main(args): @@ -335,7 +150,7 @@ def main(args): for model in args.models: for layer in WEIGHT_SHAPES_MOE[model]: - num_groups = layer[0] + num_experts = layer[0] size_k = layer[1] size_n = layer[2] @@ -347,10 +162,11 @@ def main(args): for per_act_token in PER_ACT_TOKEN_OPTS: for per_out_ch in PER_OUT_CH_OPTS: - for size_m in DEFAULT_BATCH_SIZES: - mkn = [(size_m, size_k, size_n)] * num_groups - bench_run_moe(results, model, num_groups, per_act_token, - per_out_ch, mkn) + for topk in TOPKS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) compare.print() @@ -376,7 +192,10 @@ def main(args): parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) - parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", + nargs="+", + type=int, + default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 9550236aa671..ee21c90378f7 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -80,17 +80,12 @@ [8, 14336, 4096], ], "nm-testing/deepseekv2-lite": [ - [64, 2048, 352], - [64, 1408, 256], - [64, 128, 5632], - [64, 88, 4096], + [64, 2048, 1408], ], "ibm-granite/granite-3.0-1b-a400m": [ - [32, 1024, 2048], [32, 1024, 1024], ], "ibm-granite/granite-3.0-3b-a800m": [ - [40, 1536, 2048], [40, 1024, 1536], ], } diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 94fd4eba4456..df0be62d369a 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,110 +1,15 @@ import pytest import torch -from typing import List - from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk, cutlass_moe from vllm.platforms import current_platform from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] - -# TODO move to a better file later -# TODO handle scores -def cutlass_moe( - a_q: torch.Tensor, - a_scale: torch.Tensor, - w1_qs: List[torch.Tensor], - w2_qs: List[torch.Tensor], - w1_scales: List[torch.Tensor], - w2_scales: List[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - m: int, - n: int, - k: int, -): - num_groups = len(w1_qs) - topk = topk_ids.shape[1] - - a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") - expert_offsets = torch.empty((num_groups + 1), - dtype=torch.int64, - device="cuda") - - a_map = topk_ids.flatten().argsort() - rep_a_q = a_q.repeat_interleave(topk, dim=0) - - torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), - expert_offsets, num_groups) - - a_q_s = [] - a_scales_s = [] - c_s1 = [] - c_s2 = [] - for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] - cut_out = rep_a_q.view(dtype=torch.uint8)[expert_map].view( - dtype=a_q.dtype) - a_q_s.append(cut_out.clone()) - a_scales_s.append(a_scale.clone()) - c_s1.append( - torch.zeros((cut_out.shape[0], n * 2), - device="cuda", - dtype=torch.half)) - c_s2.append( - torch.zeros((cut_out.shape[0], k), device="cuda", - dtype=torch.half)) - - torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, a_scales_s, w1_scales) - - # ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL - # intermediate1 = torch.empty((m * topk, n * 2), - # device="cuda", - # dtype=torch.half) - # for e in range(num_groups): - # expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] - # intermediate1[expert_map] = c_s1[e] - # return intermediate1.reshape(m, topk, n * 2).sum(dim=1) - # ### - - full_groups = [] - - intermediate2 = [] - intermediate2_scales = [] - for e in range(num_groups): - if c_s1[e].shape[0] != 0: - full_groups.append(e) - inter2 = torch.empty((c_s1[e].shape[0], n), - device="cuda", - dtype=torch.half) - torch.ops._C.silu_and_mul(inter2, c_s1[e]) - inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) - intermediate2.append(inter2_v) - intermediate2_scales.append(inter2_s.reshape((1, 1))) - - def filter_list(items: List, idxs: List): - return [items[idx] for idx in idxs] - - torch.ops._C.cutlass_grouped_mm(filter_list(c_s2, - full_groups), intermediate2, - filter_list(w2_qs, full_groups), - intermediate2_scales, - filter_list(w2_scales, full_groups)) - intermediate3 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) - for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] - intermediate3[expert_map] = c_s2[e] - - out = (intermediate3.reshape(m, topk, k) * - topk_weights.view(m, topk, 1).half()).sum(dim=1) - return out - - @pytest.mark.parametrize("m", [16, 32, 64, 224]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @@ -157,8 +62,7 @@ def test_cutlass_moe( w2_scales[expert]).half() torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) cutlass_output = cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, - w2_scales, topk_weights, topk_ids, m, n, - k) + w2_scales, topk_weights, topk_ids, m, n, k) # print(torch_output) # print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3ea6217d7c0e..185e6f082ea5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -671,7 +671,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ] num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape + E, N, K = w1.shape # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -869,3 +869,95 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + +# TODO handle scores +def cutlass_moe( + a_q: torch.Tensor, + a_scale: torch.Tensor, + w1_qs: List[torch.Tensor], + w2_qs: List[torch.Tensor], + w1_scales: List[torch.Tensor], + w2_scales: List[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, +): + num_groups = len(w1_qs) + topk = topk_ids.shape[1] + + a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") + expert_offsets = torch.empty((num_groups + 1), + dtype=torch.int64, + device="cuda") + + a_map = topk_ids.flatten().argsort() + rep_a_q = a_q.repeat_interleave(topk, dim=0) + + torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), + expert_offsets, num_groups) + + a_q_s = [] + a_scales_s = [] + c_s1 = [] + c_s2 = [] + for e in range(num_groups): + expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] + cut_out = rep_a_q.view(dtype=torch.uint8)[expert_map].view( + dtype=a_q.dtype) + a_q_s.append(cut_out.clone()) + a_scales_s.append(a_scale.clone()) + c_s1.append( + torch.zeros((cut_out.shape[0], n * 2), + device="cuda", + dtype=torch.half)) + c_s2.append( + torch.zeros((cut_out.shape[0], k), device="cuda", + dtype=torch.half)) + + torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, a_scales_s, w1_scales) + + # ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL + # intermediate1 = torch.empty((m * topk, n * 2), + # device="cuda", + # dtype=torch.half) + # for e in range(num_groups): + # expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] + # intermediate1[expert_map] = c_s1[e] + # return intermediate1.reshape(m, topk, n * 2).sum(dim=1) + # ### + + full_groups = [] + + intermediate2 = [] + intermediate2_scales = [] + for e in range(num_groups): + if c_s1[e].shape[0] != 0: + full_groups.append(e) + inter2 = torch.empty((c_s1[e].shape[0], n), + device="cuda", + dtype=torch.half) + torch.ops._C.silu_and_mul(inter2, c_s1[e]) + inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) + intermediate2.append(inter2_v) + intermediate2_scales.append(inter2_s.reshape((1, 1))) + + def filter_list(items: List, idxs: List): + return [items[idx] for idx in idxs] + + torch.ops._C.cutlass_grouped_mm(filter_list(c_s2, + full_groups), intermediate2, + filter_list(w2_qs, full_groups), + intermediate2_scales, + filter_list(w2_scales, full_groups)) + intermediate3 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) + for e in range(num_groups): + expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] + intermediate3[expert_map] = c_s2[e] + + intermediate3.reshape(m, topk, k).sum(dim=1) + out = (intermediate3.reshape(m, topk, k) * + topk_weights.view(m, topk, 1).half()).sum(dim=1) + return out + From 342d1a41a407d0808d87f94b2231b33f88196a55 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 24 Jan 2025 22:40:44 +0000 Subject: [PATCH 13/58] pass input as one tensor with an array of offsets rather than a list of tensors Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 24 ++++++----- csrc/ops.h | 5 ++- .../cutlass_w8a8/grouped_mm_c3x.cu | 42 +++++++++++-------- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +++--- csrc/torch_bindings.cpp | 4 +- tests/kernels/test_cutlass.py | 24 ++++++++--- tests/kernels/test_cutlass_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 2 +- 8 files changed, 73 insertions(+), 44 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index f99ca4ab09ae..7d53b6e1352f 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -10,15 +10,15 @@ cutlass_moe, fused_experts) -DEFAULT_MODELS = ["nm-testing/Mixtral-8x7B-Instruct-v0.1", - "nm-testing/deepseekv2-lite", - "ibm-granite/granite-3.0-1b-a400m", - "ibm-granite/granite-3.0-3b-a800m"] +DEFAULT_MODELS = [ + "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" +] DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512] -NUM_GROUPS_OPTS = [8] #[8, 64] -PER_ACT_TOKEN_OPTS = [False] #[False, True] -PER_OUT_CH_OPTS = [False] #[False, True] +NUM_GROUPS_OPTS = [8] #[8, 64] +PER_ACT_TOKEN_OPTS = [False] #[False, True] +PER_OUT_CH_OPTS = [False] #[False, True] TOPKS = [2, 6] @@ -46,6 +46,7 @@ def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], out = torch.mm(a, b) out_tensors[g] = out + # TODO marlin baseline def bench_run(results: List[benchmark.Measurement], model: str, num_experts: int, topk: int, per_act_token: bool, @@ -125,11 +126,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, sub_label=sub_label, description="baseline_gemm", ).blocked_autorange(min_run_time=min_run_time)) - + # Warmup pytorch for _ in range(num_warmup): - cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, - topk_weights, topk_ids, m, n, k) + cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, + topk_weights, topk_ids, m, n, k) results.append( benchmark.Timer( @@ -141,6 +142,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, description="grouped_gemm", ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): @@ -166,7 +168,7 @@ def main(args): for size_m in DEFAULT_BATCH_SIZES: mkn = (size_m, size_k, size_n) bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) compare.print() diff --git a/csrc/ops.h b/csrc/ops.h index d7ec0e0f9128..dc9c559d6c63 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -156,10 +156,11 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, std::optional const& bias); void cutlass_grouped_mm(c10::List const& out_tensors, - c10::List const& a_tensors, + torch::Tensor const& a_tensors, c10::List const& b_tensors, c10::List const& a_scales, - c10::List const& b_scales); + c10::List const& b_scales, + torch::Tensor const& expert_offsets); void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 4abb84e3e0bb..65513d6df3ff 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -127,14 +127,15 @@ cutlass::platform::unique_ptr> make_device_ptr( template void cutlass_group_gemm_caller(c10::List const& out_tensors, - c10::List const& a_tensors, + torch::Tensor const& a_tensors, c10::List const& b_tensors, c10::List const& a_scales, - c10::List const& b_scales) { + c10::List const& b_scales, + torch::Tensor const& expert_offsets) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; - int groups = (int)a_tensors.size(); + int groups = (int)expert_offsets.size(0) - 1; TORCH_CHECK((int)b_tensors.size() == groups, "Number of B tensors must match number of groups."); TORCH_CHECK((int)out_tensors.size() == groups, @@ -150,9 +151,14 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, std::vector problem_sizes_host; problem_sizes_host.reserve(groups); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + torch::Tensor a_ptrs_base = torch::full({groups + 1}, + (int64_t)a_tensors.data_ptr(), + options_int); + torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1)); + for (int g = 0; g < groups; ++g) { - a_ptrs_host[g] = - reinterpret_cast(a_tensors[g].data_ptr()); b_ptrs_host[g] = reinterpret_cast(b_tensors[g].data_ptr()); c_ptrs_host[g] = @@ -165,8 +171,8 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], // c_ptrs_host[g], d_ptrs_host[g],) - int64_t m = a_tensors[g].size(0); - int64_t k = a_tensors[g].size(1); + int64_t m = out_tensors[g].size(0); + int64_t k = a_tensors.size(1); int64_t k_b = b_tensors[g].size(0); int64_t n = b_tensors[g].size(1); @@ -192,7 +198,7 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, std::vector c_stride_host(groups); for (int32_t g = 0; g < groups; ++g) { - int64_t lda = a_tensors[g].stride(0); // row-major (m x k) + int64_t lda = a_tensors.stride(0); // row-major (m x k) int64_t ldb = b_tensors[g].stride(1); // column-major (k x n) int64_t ldc = out_tensors[g].stride(0); // row-major (m x n) @@ -211,7 +217,7 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; - auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); + // auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); @@ -224,7 +230,7 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, auto c_stride_ptr = make_device_ptr(c_stride_host); typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), + (const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales @@ -245,10 +251,10 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors[0].device()); + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors[0].device().index()); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } @@ -303,16 +309,18 @@ struct sm90_fp8_config_M64 { } // namespace +// TODO void cutlass_grouped_mm_sm90(c10::List const& out_tensors, - c10::List const& a_tensors, + torch::Tensor const& a_tensors, c10::List const& b_tensors, c10::List const& a_scales, - c10::List const& b_scales) { - TORCH_CHECK(a_tensors.size() > 0, "No input A tensors provided."); + c10::List const& b_scales, + torch::Tensor const& expert_offsets) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(b_tensors.size() > 0, "No input B tensors provided."); TORCH_CHECK(out_tensors.size() > 0, "No output tensors provided."); - TORCH_CHECK(a_tensors[0].dtype() == torch::kFloat8_e4m3fn, + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, "A tensors must be of type float8_e4m3fn."); TORCH_CHECK(b_tensors[0].dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); @@ -321,7 +329,7 @@ void cutlass_grouped_mm_sm90(c10::List const& out_tensors, ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales); + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets); } __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index d9d2a91d0659..38201f080f64 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -31,10 +31,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); void cutlass_grouped_mm_sm90(c10::List const& out_tensors, - c10::List const& a_tensors, + torch::Tensor const& a_tensors, c10::List const& b_tensors, c10::List const& a_scales, - c10::List const& b_scales); + c10::List const& b_scales, + torch::Tensor const& expert_offsets); void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, @@ -158,12 +159,13 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, } void cutlass_grouped_mm(c10::List const& out_tensors, - c10::List const& a_tensors, + torch::Tensor const& a_tensors, c10::List const& b_tensors, c10::List const& a_scales, - c10::List const& b_scales) { + c10::List const& b_scales, + torch::Tensor const& expert_offsets) { cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, - b_scales); + b_scales, expert_offsets); } void compute_expert_offsets(torch::Tensor& trg_a_ptrs, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 65d48c7f1465..e46166920745 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -324,9 +324,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( "cutlass_grouped_mm(Tensor![] out_tensors," - " Tensor[] a_tensors," + " Tensor a_tensors," " Tensor[] b_tensors, Tensor[] a_scales, " - " Tensor[] b_scales) -> ()"); + " Tensor[] b_scales, Tensor expert_offsets) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 445a06f57a96..ca1adbafc106 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -479,13 +479,19 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, out_tensors = [] baseline_tensors = [] + expert_offsets = torch.zeros((num_groups + 1), + device=device, + dtype=torch.int32) + alignment = 16 # 128 // 8 # For variation, each group has dimensions # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) - for _ in range(num_groups): + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + for g in range(num_groups): m_g = alignment * random.randint(1, 64) - n_g = alignment * random.randint(1, 64) - k_g = alignment * random.randint(1, 64) + + expert_offsets[g + 1] = expert_offsets[g] + m_g m_a_scales = m_g if per_act_token else 1 n_b_scales = n_g if per_out_ch else 1 @@ -515,8 +521,16 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, None) baseline_tensors.append(baseline_g) - torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors, b_tensors, - a_scales_tensors, b_scales_tensors) + a_tensors_stacked = torch.empty((expert_offsets[num_groups], k_g), + device=device, + dtype=torch.float8_e4m3fn) + for g in range(num_groups): + a_tensors_stacked[expert_offsets[g]:expert_offsets[g + + 1]] = a_tensors[g] + + torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors_stacked, b_tensors, + a_scales_tensors, b_scales_tensors, + expert_offsets) # Validate each group's result against the baseline for c_g, baseline_g in zip(out_tensors, baseline_tensors): diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index df0be62d369a..efb60df3c2c5 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -10,6 +10,7 @@ NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] + @pytest.mark.parametrize("m", [16, 32, 64, 224]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @@ -62,7 +63,8 @@ def test_cutlass_moe( w2_scales[expert]).half() torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) cutlass_output = cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, - w2_scales, topk_weights, topk_ids, m, n, k) + w2_scales, topk_weights, topk_ids, m, n, + k) # print(torch_output) # print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 185e6f082ea5..ff15363b31a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -870,6 +870,7 @@ def fused_moe( a2_scale=a2_scale, block_shape=block_shape) + # TODO handle scores def cutlass_moe( a_q: torch.Tensor, @@ -960,4 +961,3 @@ def filter_list(items: List, idxs: List): out = (intermediate3.reshape(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) return out - From 7549e3df5305b85a32290cf70d74a7674fb544fd Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 28 Jan 2025 04:09:38 +0000 Subject: [PATCH 14/58] Using tensors rather than tensor lists works with test_cutlass test Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 12 +- csrc/ops.h | 11 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 205 +++++++++++------- .../cutlass_w8a8/scaled_mm_entry.cu | 24 +- csrc/torch_bindings.cpp | 7 +- tests/kernels/test_cutlass.py | 61 +++++- 6 files changed, 213 insertions(+), 107 deletions(-) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 96ddcab7cea2..694428b82c59 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -119,12 +119,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); // CUTLASS w8a8 grouped GEMM // TODO complete this - ops.def( - "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " - " Tensor b_scales, Tensor problem_sizes, " - " Tensor out_offsets, Tensor a_offsets, " - " Tensor b_offsets) -> ()"); - ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); +// ops.def( +// "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " +// " Tensor b_scales, Tensor problem_sizes, " +// " Tensor out_offsets, Tensor a_offsets, " +// " Tensor b_offsets) -> ()"); +// ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( "compute_expert_offsets(Tensor! trg_a_ptrs," diff --git a/csrc/ops.h b/csrc/ops.h index dc9c559d6c63..90ab8d1bf93b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -155,12 +155,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm(c10::List const& out_tensors, +void cutlass_grouped_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - c10::List const& b_tensors, - c10::List const& a_scales, - c10::List const& b_scales, - torch::Tensor const& expert_offsets); + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes); void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 65513d6df3ff..957ef0cf0677 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -126,67 +126,105 @@ cutlass::platform::unique_ptr> make_device_ptr( } template -void cutlass_group_gemm_caller(c10::List const& out_tensors, +void cutlass_group_gemm_caller(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - c10::List const& b_tensors, - c10::List const& a_scales, - c10::List const& b_scales, - torch::Tensor const& expert_offsets) { + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; int groups = (int)expert_offsets.size(0) - 1; - TORCH_CHECK((int)b_tensors.size() == groups, - "Number of B tensors must match number of groups."); - TORCH_CHECK((int)out_tensors.size() == groups, - "Number of output tensors must match number of groups."); + int k_size = a_tensors.size(1); + int n_size = out_tensors.size(1); - std::vector a_ptrs_host(groups); - std::vector b_ptrs_host(groups); - std::vector c_ptrs_host(groups); - std::vector d_ptrs_host(groups); - std::vector a_scales_ptrs_host(groups); - std::vector b_scales_ptrs_host(groups); + bool per_act_token = a_scales.numel() != groups; + bool per_out_ch = b_scales.numel() != groups; - std::vector problem_sizes_host; - problem_sizes_host.reserve(groups); + // TORCH_CHECK((int)b_tensors.size() == groups, + // "Number of B tensors must match number of groups."); + // TORCH_CHECK((int)out_tensors.size() == groups, + // "Number of output tensors must match number of groups."); + + // std::vector a_ptrs_host(groups); + // std::vector b_ptrs_host(groups); + // std::vector c_ptrs_host(groups); + // std::vector d_ptrs_host(groups); + // std::vector a_scales_ptrs_host(groups); + // std::vector b_scales_ptrs_host(groups); + + // std::vector problem_sizes_host; + // problem_sizes_host.reserve(groups); + + int b_single_size = k_size * n_size; + int b_scale_single_size = per_out_ch ? out_tensors.size(1) : 1; auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); torch::Tensor a_ptrs_base = torch::full({groups + 1}, (int64_t)a_tensors.data_ptr(), options_int); - torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1)); - - for (int g = 0; g < groups; ++g) { - b_ptrs_host[g] = - reinterpret_cast(b_tensors[g].data_ptr()); - c_ptrs_host[g] = - reinterpret_cast(out_tensors[g].data_ptr()); - d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); - a_scales_ptrs_host[g] = - reinterpret_cast(a_scales[g].data_ptr()); - b_scales_ptrs_host[g] = - reinterpret_cast(b_scales[g].data_ptr()); - - // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], - // c_ptrs_host[g], d_ptrs_host[g],) - int64_t m = out_tensors[g].size(0); - int64_t k = a_tensors.size(1); - - int64_t k_b = b_tensors[g].size(0); - int64_t n = b_tensors[g].size(1); - - TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, - " while B has k=", k_b); - - // Optionally, verify output shape matches (m,n) - TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, - "Output tensor shape does not match m,n from A,B: ", "Got ", - out_tensors[g].sizes(), " expected (", m, ", ", n, ")"); - - problem_sizes_host.push_back({(int)m, (int)n, (int)k}); - } + torch::Tensor out_ptrs_base = torch::full({groups + 1}, + (int64_t)out_tensors.data_ptr(), + options_int); + torch::Tensor b_ptrs_base = torch::full({groups + 1}, + (int64_t)b_tensors.data_ptr(), + options_int); + torch::Tensor a_scales_base = torch::full({groups + 1}, + (int64_t)a_scales.data_ptr(), + options_int); + torch::Tensor b_scales_base = torch::full({groups + 1}, + (int64_t)b_scales.data_ptr(), + options_int); + + torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1), + b_single_size, options_int); + torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int); + torch::Tensor b_scales_offsets = torch::arange(0, b_scale_single_size * + (groups + 1), b_scale_single_size, + options_int); + + // multiply by offset of k 8-bit elements + torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1)); + // multiply by offset of n 16-bit elements + torch::Tensor out_ptrs = out_ptrs_base.add(expert_offsets, 2 * out_tensors.size(1)); + // multiply by offset of n 8-bit elements + torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets); + + torch::Tensor a_scales_ptrs = a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, 4); + torch::Tensor b_scales_ptrs = b_scales_base.add(b_scales_offsets, 4); + + // for (int g = 0; g < groups; ++g) { + // // b_ptrs_host[g] = + // // reinterpret_cast(b_list[g].data_ptr()); + // // c_ptrs_host[g] = + // // reinterpret_cast(out_tensors[g].data_ptr()); + // // d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); + // // a_scales_ptrs_host[g] = + // // reinterpret_cast(a_scales[g].data_ptr()); + // // b_scales_ptrs_host[g] = + // // reinterpret_cast(b_scales[g].data_ptr()); + + // // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], + // // c_ptrs_host[g], d_ptrs_host[g],) + // // int64_t m = out_tensors[g].size(0); + // // int64_t k = a_tensors.size(1); + + // // int64_t k_b = b_tensors[g].size(0); + // // int64_t n = b_tensors[g].size(1); + + // // TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, + // // " while B has k=", k_b); + + // // // Optionally, verify output shape matches (m,n) + // // TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, + // // "Output tensor shape does not match m,n from A,B: ", "Got ", + // // out_tensors[g].sizes(), " expected (", m, ", ", n, ")"); + + // // problem_sizes_host.push_back({(int)m, (int)n, (int)k}); + // } using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -199,8 +237,9 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, for (int32_t g = 0; g < groups; ++g) { int64_t lda = a_tensors.stride(0); // row-major (m x k) - int64_t ldb = b_tensors[g].stride(1); // column-major (k x n) - int64_t ldc = out_tensors[g].stride(0); // row-major (m x n) + int64_t ldb = a_tensors.stride(0); // column-major (k x n) + int64_t ldc = out_tensors.stride(0); // row-major (m x n) + printf("strides: %ld %ld %ld\n", lda, ldb, ldc); a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; @@ -213,33 +252,49 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); - ProblemShape prob_shape{groups, problem_sizes_ptr.get(), - problem_sizes_host.data()}; + // auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); + // ProblemShape prob_shape{groups, problem_sizes_ptr.get(), + // problem_sizes_host.data()}; + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + reinterpret_cast( + problem_sizes.data_ptr()); + ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; // auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); - auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); - auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); + // auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); + // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); - auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); + // auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); + // auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); auto a_stride_ptr = make_device_ptr(a_stride_host); auto b_stride_ptr = make_device_ptr(b_stride_host); auto c_stride_ptr = make_device_ptr(c_stride_host); + // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); + // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); + typename GemmKernel::MainloopArguments mainloop_args{ - (const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(), b_ptrs_ptr.get(), - b_stride_ptr.get()}; + (const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(), + (const ElementAB_Type**)b_ptrs.data_ptr(), b_stride_ptr.get()}; // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - a_scales_ptrs_ptr.get(), b_scales_ptrs_ptr.get(), - a_scales[0].numel() != 1, b_scales[0].numel() != 1), - c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), - c_stride_ptr.get()}; + (const ElementAccumulator**)a_scales_ptrs.data_ptr(), + (const ElementAccumulator**)b_scales_ptrs.data_ptr(), + per_act_token, per_out_ch), + (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(), + (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()}; + + // typename GemmKernel::EpilogueArguments epilogue_args{ + // Gemm::Epilogue::prepare_args( + // (const ElementAccumulator**)a_scales_ptrs.data_ptr(), + // b_scales_ptrs_ptr.get(), + // per_act_token, per_out_ch), + // (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(), + // (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, @@ -254,9 +309,11 @@ void cutlass_group_gemm_caller(c10::List const& out_tensors, torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); auto workspace = torch::empty(workspace_size, workspace_options); + // printf("before: %d\n", out_tensors[0][0]); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); + // printf("after: %d\n", out_tensors[0][0]); } template const& out_tensors, +void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - c10::List const& b_tensors, - c10::List const& a_scales, - c10::List const& b_scales, - torch::Tensor const& expert_offsets) { + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes) { TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); - TORCH_CHECK(b_tensors.size() > 0, "No input B tensors provided."); - TORCH_CHECK(out_tensors.size() > 0, "No output tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, "A tensors must be of type float8_e4m3fn."); - TORCH_CHECK(b_tensors[0].dtype() == torch::kFloat8_e4m3fn, + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); using Cutlass3xGemmDefault = typename sm90_fp8_config_default< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets); + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes); } __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 38201f080f64..8fad49b45b4a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,12 +30,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm_sm90(c10::List const& out_tensors, +void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - c10::List const& b_tensors, - c10::List const& a_scales, - c10::List const& b_scales, - torch::Tensor const& expert_offsets); + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes); void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, @@ -158,14 +159,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm(c10::List const& out_tensors, +void cutlass_grouped_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - c10::List const& b_tensors, - c10::List const& a_scales, - c10::List const& b_scales, - torch::Tensor const& expert_offsets) { + torch::Tensor const& b_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes) { cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, expert_offsets); + b_scales, expert_offsets, problem_sizes); } void compute_expert_offsets(torch::Tensor& trg_a_ptrs, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e46166920745..f31dc4f20df5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -323,10 +323,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( - "cutlass_grouped_mm(Tensor![] out_tensors," + "cutlass_grouped_mm(Tensor! out_tensors," " Tensor a_tensors," - " Tensor[] b_tensors, Tensor[] a_scales, " - " Tensor[] b_scales, Tensor expert_offsets) -> ()"); + " Tensor b_tensors, Tensor a_scales, " + " Tensor b_scales, Tensor expert_offsets, " + " Tensor problem_sizes) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index ca1adbafc106..f20f71f0fc9d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -8,7 +8,7 @@ import pytest import torch -from tests.kernels.utils import opcheck +from tests.kernels.utils import opcheck, stack_and_dev from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -458,7 +458,7 @@ def test_cutlass_support_opcheck(): # TODO add bias -@pytest.mark.parametrize("num_groups", [8]) +@pytest.mark.parametrize("num_groups", [8, 64]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [False]) @@ -483,15 +483,23 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_groups, 3), + device=device, + dtype=torch.int32) + alignment = 16 # 128 // 8 # For variation, each group has dimensions # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) + # one_b = to_fp8(torch.randn((n_g, k_g), device=device)) for g in range(num_groups): m_g = alignment * random.randint(1, 64) expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][0] = m_g + problem_sizes[g][1] = n_g + problem_sizes[g][2] = k_g m_a_scales = m_g if per_act_token else 1 n_b_scales = n_g if per_out_ch else 1 @@ -501,6 +509,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, # Create group-specific A and B (FP8) and output (FP16/FP32) a_g = to_fp8(torch.randn((m_g, k_g), device=device)) b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + # b_g = one_b.clone().t() c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) # Set up A/B scales scale_a = torch.randn((m_a_scales, 1), @@ -516,6 +525,8 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, a_scales_tensors.append(scale_a) b_scales_tensors.append(scale_b) + print(b_g.stride()) + # Compute baseline result for this group baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) @@ -524,17 +535,49 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, a_tensors_stacked = torch.empty((expert_offsets[num_groups], k_g), device=device, dtype=torch.float8_e4m3fn) + b_tensors_stacked = torch.empty((n_g * num_groups, k_g), + device=device, + dtype=torch.float8_e4m3fn) for g in range(num_groups): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] + b_tensors_stacked[g*n_g:(g+1)*n_g, :] = b_tensors[g].t() + b_tensors_stacked = b_tensors_stacked.t() + + a_scales_tensors_stacked = torch.empty( + (expert_offsets[num_groups] if per_act_token else num_groups, 1), + device=device, + dtype=torch.float32) + if per_act_token: + for g in range(num_groups): + a_scales_tensors_stacked[expert_offsets[g]:expert_offsets[g + + 1]] = a_scales_tensors[g] + else: + for g in range(num_groups): + a_scales_tensors_stacked[g] = a_scales_tensors[g] + + b_scales_tensors_stacked = torch.empty( + (num_groups, n_b_scales), + device=device, + dtype=torch.float32) + for g in range(num_groups): + b_scales_tensors_stacked[g] = b_scales_tensors[g] - torch.ops._C.cutlass_grouped_mm(out_tensors, a_tensors_stacked, b_tensors, - a_scales_tensors, b_scales_tensors, - expert_offsets) + out_tensors_stacked = torch.zeros((expert_offsets[num_groups], n_g), + device=device, + dtype=out_dtype) + + torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, + a_scales_tensors_stacked, + b_scales_tensors_stacked, + expert_offsets, problem_sizes) # Validate each group's result against the baseline - for c_g, baseline_g in zip(out_tensors, baseline_tensors): - print(baseline_g) - print(c_g) + for g in range(num_groups): + baseline = baseline_tensors[g] + c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + print(baseline) + print(c) print("*") - torch.testing.assert_close(c_g, baseline_g, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) From 1ea7874a62b726c9a149b9b5ba6e0e8ddc701f27 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 28 Jan 2025 04:42:59 +0000 Subject: [PATCH 15/58] cleanup, add import Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 15 +- csrc/ops.h | 9 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 178 ++++++------------ csrc/torch_bindings.cpp | 3 +- tests/kernels/test_cutlass.py | 36 ++-- 5 files changed, 81 insertions(+), 160 deletions(-) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 2de5e595c64b..bbe6d2e8652d 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -119,19 +119,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); // CUTLASS w8a8 grouped GEMM // TODO complete this -// ops.def( -// "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " -// " Tensor b_scales, Tensor problem_sizes, " -// " Tensor out_offsets, Tensor a_offsets, " -// " Tensor b_offsets) -> ()"); -// ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + // ops.def( + // "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, + // " " Tensor b_scales, Tensor problem_sizes, " " + // Tensor out_offsets, Tensor a_offsets, " " Tensor + // b_offsets) -> ()"); + // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( "compute_expert_offsets(Tensor! trg_a_ptrs," " Tensor! a, Tensor topk_ids," " Tensor! expert_offsets, SymInt num_experts) -> ()"); - ops.impl("compute_expert_offsets", torch::kCUDA, - &compute_expert_offsets); + ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. diff --git a/csrc/ops.h b/csrc/ops.h index 394fb172a476..a67d7d757f3a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -167,11 +167,10 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes); -void compute_expert_offsets(torch::Tensor& trg_a_ptrs, - torch::Tensor& a, - const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - const int64_t num_experts); +void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a, + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const int64_t num_experts); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 957ef0cf0677..8ab15fd7be00 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -80,8 +80,6 @@ struct cutlass_3x_group_gemm { const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; - // the orig hat cutlass::epilogue::fusion::LinearCombination using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -143,88 +141,44 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, bool per_act_token = a_scales.numel() != groups; bool per_out_ch = b_scales.numel() != groups; - // TORCH_CHECK((int)b_tensors.size() == groups, - // "Number of B tensors must match number of groups."); - // TORCH_CHECK((int)out_tensors.size() == groups, - // "Number of output tensors must match number of groups."); - - // std::vector a_ptrs_host(groups); - // std::vector b_ptrs_host(groups); - // std::vector c_ptrs_host(groups); - // std::vector d_ptrs_host(groups); - // std::vector a_scales_ptrs_host(groups); - // std::vector b_scales_ptrs_host(groups); - - // std::vector problem_sizes_host; - // problem_sizes_host.reserve(groups); - int b_single_size = k_size * n_size; int b_scale_single_size = per_out_ch ? out_tensors.size(1) : 1; auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); - torch::Tensor a_ptrs_base = torch::full({groups + 1}, - (int64_t)a_tensors.data_ptr(), - options_int); - torch::Tensor out_ptrs_base = torch::full({groups + 1}, - (int64_t)out_tensors.data_ptr(), - options_int); - torch::Tensor b_ptrs_base = torch::full({groups + 1}, - (int64_t)b_tensors.data_ptr(), - options_int); - torch::Tensor a_scales_base = torch::full({groups + 1}, - (int64_t)a_scales.data_ptr(), - options_int); - torch::Tensor b_scales_base = torch::full({groups + 1}, - (int64_t)b_scales.data_ptr(), - options_int); + torch::Tensor a_ptrs_base = + torch::full({groups + 1}, reinterpret_cast(a_tensors.data_ptr()), + options_int); + torch::Tensor out_ptrs_base = torch::full( + {groups + 1}, reinterpret_cast(out_tensors.data_ptr()), + options_int); + torch::Tensor b_ptrs_base = + torch::full({groups + 1}, reinterpret_cast(b_tensors.data_ptr()), + options_int); + torch::Tensor a_scales_base = + torch::full({groups + 1}, reinterpret_cast(a_scales.data_ptr()), + options_int); + torch::Tensor b_scales_base = + torch::full({groups + 1}, reinterpret_cast(b_scales.data_ptr()), + options_int); torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1), - b_single_size, options_int); + b_single_size, options_int); torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int); - torch::Tensor b_scales_offsets = torch::arange(0, b_scale_single_size * - (groups + 1), b_scale_single_size, - options_int); - - // multiply by offset of k 8-bit elements - torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1)); - // multiply by offset of n 16-bit elements - torch::Tensor out_ptrs = out_ptrs_base.add(expert_offsets, 2 * out_tensors.size(1)); - // multiply by offset of n 8-bit elements - torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets); - - torch::Tensor a_scales_ptrs = a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, 4); - torch::Tensor b_scales_ptrs = b_scales_base.add(b_scales_offsets, 4); - - // for (int g = 0; g < groups; ++g) { - // // b_ptrs_host[g] = - // // reinterpret_cast(b_list[g].data_ptr()); - // // c_ptrs_host[g] = - // // reinterpret_cast(out_tensors[g].data_ptr()); - // // d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr()); - // // a_scales_ptrs_host[g] = - // // reinterpret_cast(a_scales[g].data_ptr()); - // // b_scales_ptrs_host[g] = - // // reinterpret_cast(b_scales[g].data_ptr()); - - // // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g], - // // c_ptrs_host[g], d_ptrs_host[g],) - // // int64_t m = out_tensors[g].size(0); - // // int64_t k = a_tensors.size(1); - - // // int64_t k_b = b_tensors[g].size(0); - // // int64_t n = b_tensors[g].size(1); - - // // TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k, - // // " while B has k=", k_b); - - // // // Optionally, verify output shape matches (m,n) - // // TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n, - // // "Output tensor shape does not match m,n from A,B: ", "Got ", - // // out_tensors[g].sizes(), " expected (", m, ", ", n, ")"); - - // // problem_sizes_host.push_back({(int)m, (int)n, (int)k}); - // } + torch::Tensor b_scales_offsets = torch::arange( + 0, b_scale_single_size * (groups + 1), b_scale_single_size, options_int); + + torch::Tensor a_ptrs = a_ptrs_base.add( + expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); + torch::Tensor out_ptrs = out_ptrs_base.add( + expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1)); + torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets, sizeof(ElementAB_Type)); + + torch::Tensor a_scales_ptrs = + a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, + sizeof(ElementAccumulator)); + torch::Tensor b_scales_ptrs = + b_scales_base.add(b_scales_offsets, sizeof(ElementAccumulator)); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -239,7 +193,6 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, int64_t lda = a_tensors.stride(0); // row-major (m x k) int64_t ldb = a_tensors.stride(0); // column-major (k x n) int64_t ldc = out_tensors.stride(0); // row-major (m x n) - printf("strides: %ld %ld %ld\n", lda, ldb, ldc); a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; @@ -252,49 +205,33 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( hw_info.device_id); - // auto problem_sizes_ptr = make_device_ptr(problem_sizes_host); - // ProblemShape prob_shape{groups, problem_sizes_ptr.get(), - // problem_sizes_host.data()}; ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = reinterpret_cast( - problem_sizes.data_ptr()); + problem_sizes.data_ptr()); ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; - // auto a_ptrs_ptr = make_device_ptr(a_ptrs_host); - // auto b_ptrs_ptr = make_device_ptr(b_ptrs_host); - // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - - // auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host); - // auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host); - auto a_stride_ptr = make_device_ptr(a_stride_host); auto b_stride_ptr = make_device_ptr(b_stride_host); auto c_stride_ptr = make_device_ptr(c_stride_host); - // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host); - // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host); - typename GemmKernel::MainloopArguments mainloop_args{ - (const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(), - (const ElementAB_Type**)b_ptrs.data_ptr(), b_stride_ptr.get()}; + reinterpret_cast(a_ptrs.data_ptr()), + a_stride_ptr.get(), + reinterpret_cast(b_ptrs.data_ptr()), + b_stride_ptr.get()}; + // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - (const ElementAccumulator**)a_scales_ptrs.data_ptr(), - (const ElementAccumulator**)b_scales_ptrs.data_ptr(), - per_act_token, per_out_ch), - (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(), - (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()}; - - // typename GemmKernel::EpilogueArguments epilogue_args{ - // Gemm::Epilogue::prepare_args( - // (const ElementAccumulator**)a_scales_ptrs.data_ptr(), - // b_scales_ptrs_ptr.get(), - // per_act_token, per_out_ch), - // (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(), - // (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()}; + Gemm::Epilogue::prepare_args(reinterpret_cast( + a_scales_ptrs.data_ptr()), + reinterpret_cast( + b_scales_ptrs.data_ptr()), + per_act_token, per_out_ch), + reinterpret_cast(out_ptrs.data_ptr()), + c_stride_ptr.get(), + reinterpret_cast(out_ptrs.data_ptr()), + c_stride_ptr.get()}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, @@ -309,11 +246,9 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); auto workspace = torch::empty(workspace_size, workspace_options); - // printf("before: %d\n", out_tensors[0][0]); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); - // printf("after: %d\n", out_tensors[0][0]); } template typename Epilogue> struct sm90_fp8_config_M64 { - // M in [1, 64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; @@ -394,8 +328,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, cutlass::float_e4m3_t* base_a_ptr, const int* __restrict__ topk_ids, - int64_t* expert_offsets, - int topk_length) { + int64_t* expert_offsets, int topk_length) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -419,10 +352,12 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, // // For a given "a" of size [M,K] performs a permutation of the M rows based // // on the given "perm" indices. -// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr, +// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* +// __restrict__ a_ptr, // int const* __restrict__ perm_int_ptr, -// cutlass::float_e4m3_t* __restrict__ out_ptr, -// int size_m, int size_k, int block_rows) { +// cutlass::float_e4m3_t* __restrict__ +// out_ptr, int size_m, int size_k, int +// block_rows) { // int start_row = block_rows * blockIdx.x; // int finish_row = start_row + block_rows; // if (finish_row > size_m) { @@ -459,17 +394,14 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, // }; // } -void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, - torch::Tensor& a, +void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const int64_t num_experts) { - get_a_expert_offsets<<<1, num_experts>>>( + get_a_expert_offsets<<<1, num_experts>>>( (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(), - (cutlass::float_e4m3_t*)a.data_ptr(), - (const int*)topk_ids.data_ptr(), - (int64_t*)expert_offsets.data_ptr(), - topk_ids.numel()); + (cutlass::float_e4m3_t*)a.data_ptr(), (const int*)topk_ids.data_ptr(), + (int64_t*)expert_offsets.data_ptr(), topk_ids.numel()); } // void permute_fp8_rows(torch::Tensor& a_ptr, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2fdb13307424..81a97c3887bd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -337,8 +337,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "compute_expert_offsets(Tensor! trg_a_ptrs," " Tensor! a, Tensor topk_ids," " Tensor! expert_offsets, SymInt num_experts) -> ()"); - ops.impl("compute_expert_offsets", torch::kCUDA, - &compute_expert_offsets); + ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets); // Check if cutlass sparse scaled_mm is supported for CUDA devices of the // given capability diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 0b71c5a1e6a4..a8c8630a0448 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,12 +2,13 @@ Run `pytest tests/kernels/test_cutlass.py`. """ +import random from typing import Type import pytest import torch -from tests.kernels.utils import opcheck, stack_and_dev +from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -452,7 +453,6 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, b_tensors = [] a_scales_tensors = [] b_scales_tensors = [] - out_tensors = [] baseline_tensors = [] expert_offsets = torch.zeros((num_groups + 1), @@ -460,15 +460,13 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, dtype=torch.int32) problem_sizes = torch.zeros((num_groups, 3), - device=device, - dtype=torch.int32) + device=device, + dtype=torch.int32) alignment = 16 # 128 // 8 # For variation, each group has dimensions - # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) - # one_b = to_fp8(torch.randn((n_g, k_g), device=device)) for g in range(num_groups): m_g = alignment * random.randint(1, 64) @@ -485,8 +483,6 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, # Create group-specific A and B (FP8) and output (FP16/FP32) a_g = to_fp8(torch.randn((m_g, k_g), device=device)) b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) - # b_g = one_b.clone().t() - c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) # Set up A/B scales scale_a = torch.randn((m_a_scales, 1), device=device, @@ -497,12 +493,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, a_tensors.append(a_g) b_tensors.append(b_g) - out_tensors.append(c_g) a_scales_tensors.append(scale_a) b_scales_tensors.append(scale_b) - print(b_g.stride()) - # Compute baseline result for this group baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) @@ -517,7 +510,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, for g in range(num_groups): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] - b_tensors_stacked[g*n_g:(g+1)*n_g, :] = b_tensors[g].t() + b_tensors_stacked[g * n_g:(g + 1) * n_g, :] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.t() a_scales_tensors_stacked = torch.empty( @@ -526,28 +519,27 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, dtype=torch.float32) if per_act_token: for g in range(num_groups): - a_scales_tensors_stacked[expert_offsets[g]:expert_offsets[g + - 1]] = a_scales_tensors[g] + a_scales_tensors_stacked[ + expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] else: for g in range(num_groups): a_scales_tensors_stacked[g] = a_scales_tensors[g] - b_scales_tensors_stacked = torch.empty( - (num_groups, n_b_scales), - device=device, - dtype=torch.float32) + b_scales_tensors_stacked = torch.empty((num_groups, n_b_scales), + device=device, + dtype=torch.float32) for g in range(num_groups): b_scales_tensors_stacked[g] = b_scales_tensors[g] out_tensors_stacked = torch.zeros((expert_offsets[num_groups], n_g), - device=device, - dtype=out_dtype) + device=device, + dtype=out_dtype) torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, - expert_offsets, problem_sizes) + b_scales_tensors_stacked, expert_offsets, + problem_sizes) # Validate each group's result against the baseline for g in range(num_groups): From d608164b0c7464fa473604347f28e9a95a5fe232 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 29 Jan 2025 07:09:27 +0000 Subject: [PATCH 16/58] working fused op Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 12 +- csrc/ops.h | 8 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 71 +++++------ .../cutlass_w8a8/scaled_mm_entry.cu | 29 ++--- csrc/torch_bindings.cpp | 6 +- tests/kernels/test_cutlass.py | 10 +- tests/kernels/test_cutlass_moe.py | 57 ++++----- .../layers/fused_moe/fused_moe.py | 114 ++++++------------ 8 files changed, 140 insertions(+), 167 deletions(-) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index bbe6d2e8652d..fa3ab7afb946 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -126,11 +126,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // b_offsets) -> ()"); // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); - ops.def( - "compute_expert_offsets(Tensor! trg_a_ptrs," - " Tensor! a, Tensor topk_ids," - " Tensor! expert_offsets, SymInt num_experts) -> ()"); - ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets); + // ops.def( + // "compute_expert_offsets(Tensor! trg_a_ptrs," + // " Tensor! a, Tensor topk_ids," + // " Tensor! expert_offsets, SymInt num_experts) -> + // ()"); + // ops.impl("compute_expert_offsets", torch::kCUDA, + // &compute_expert_offsets); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. diff --git a/csrc/ops.h b/csrc/ops.h index a67d7d757f3a..ecb9a02ad54a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -167,10 +167,12 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes); -void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a, - const torch::Tensor& topk_ids, +void compute_expert_offsets(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - const int64_t num_experts); + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const int64_t num_experts, const int64_t n, + const int64_t k); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 8ab15fd7be00..ae89d72db41a 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -134,7 +134,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; - int groups = (int)expert_offsets.size(0) - 1; + int groups = (int)expert_offsets.size(0); int k_size = a_tensors.size(1); int n_size = out_tensors.size(1); @@ -146,27 +146,22 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); - torch::Tensor a_ptrs_base = - torch::full({groups + 1}, reinterpret_cast(a_tensors.data_ptr()), - options_int); + torch::Tensor a_ptrs_base = torch::full( + groups, reinterpret_cast(a_tensors.data_ptr()), options_int); torch::Tensor out_ptrs_base = torch::full( - {groups + 1}, reinterpret_cast(out_tensors.data_ptr()), - options_int); - torch::Tensor b_ptrs_base = - torch::full({groups + 1}, reinterpret_cast(b_tensors.data_ptr()), - options_int); - torch::Tensor a_scales_base = - torch::full({groups + 1}, reinterpret_cast(a_scales.data_ptr()), - options_int); - torch::Tensor b_scales_base = - torch::full({groups + 1}, reinterpret_cast(b_scales.data_ptr()), - options_int); - - torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1), - b_single_size, options_int); - torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int); + groups, reinterpret_cast(out_tensors.data_ptr()), options_int); + torch::Tensor b_ptrs_base = torch::full( + groups, reinterpret_cast(b_tensors.data_ptr()), options_int); + torch::Tensor a_scales_base = torch::full( + groups, reinterpret_cast(a_scales.data_ptr()), options_int); + torch::Tensor b_scales_base = torch::full( + groups, reinterpret_cast(b_scales.data_ptr()), options_int); + + torch::Tensor b_offsets = + torch::arange(0, b_single_size * groups, b_single_size, options_int); + torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int); torch::Tensor b_scales_offsets = torch::arange( - 0, b_scale_single_size * (groups + 1), b_scale_single_size, options_int); + 0, b_scale_single_size * groups, b_scale_single_size, options_int); torch::Tensor a_ptrs = a_ptrs_base.add( expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); @@ -189,6 +184,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, std::vector b_stride_host(groups); std::vector c_stride_host(groups); + // TODO pass strides? for (int32_t g = 0; g < groups; ++g) { int64_t lda = a_tensors.stride(0); // row-major (m x k) int64_t ldb = a_tensors.stride(0); // column-major (k x n) @@ -325,10 +321,11 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, problem_sizes); } -__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, - cutlass::float_e4m3_t* base_a_ptr, - const int* __restrict__ topk_ids, - int64_t* expert_offsets, int topk_length) { +__global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, + int32_t* expert_offsets, + int32_t* problem_sizes1, + int32_t* problem_sizes2, int topk_length, + int n, int k) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -336,15 +333,19 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, for (int i = 0; i < topk_length; ++i) { occurrences += (topk_ids[i] == expert_id); } - expert_offsets[expert_id + 1] = occurrences; + problem_sizes1[expert_id * 3] = occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; __syncthreads(); if (threadIdx.x == 0) { - int64_t tot_offset = 0; + int32_t tot_offset = 0; expert_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { - trg_a_ptrs[i] = base_a_ptr + tot_offset; - tot_offset += expert_offsets[i + 1]; + tot_offset += problem_sizes1[i * 3]; expert_offsets[i + 1] = tot_offset; } } @@ -394,14 +395,16 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs, // }; // } -void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a, - const torch::Tensor& topk_ids, +void compute_expert_offsets_caller(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - const int64_t num_experts) { + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const int64_t num_experts, const int64_t n, + const int64_t k) { get_a_expert_offsets<<<1, num_experts>>>( - (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(), - (cutlass::float_e4m3_t*)a.data_ptr(), (const int*)topk_ids.data_ptr(), - (int64_t*)expert_offsets.data_ptr(), topk_ids.numel()); + (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), + (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), + topk_ids.numel(), n, k); } // void permute_fp8_rows(torch::Tensor& a_ptr, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 8fad49b45b4a..2bca67073105 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -38,12 +38,12 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes); - -void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, - torch::Tensor& a, - const torch::Tensor& topk_ids, +void compute_expert_offsets_caller(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - const int64_t num_experts); + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const int64_t num_experts, const int64_t n, + const int64_t k); #endif @@ -166,17 +166,18 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes) { - cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, expert_offsets, problem_sizes); + cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes); } -void compute_expert_offsets(torch::Tensor& trg_a_ptrs, - torch::Tensor& a, - const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - const int64_t num_experts) { - compute_expert_offsets_caller(trg_a_ptrs, a, topk_ids, expert_offsets, - num_experts); +void compute_expert_offsets(const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const int64_t num_experts, const int64_t n, + const int64_t k) { + compute_expert_offsets_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, num_experts, n, k); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 81a97c3887bd..d73485449273 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -334,9 +334,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( - "compute_expert_offsets(Tensor! trg_a_ptrs," - " Tensor! a, Tensor topk_ids," - " Tensor! expert_offsets, SymInt num_experts) -> ()"); + "compute_expert_offsets(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " SymInt num_experts, SymInt n, SymInt k) -> ()"); ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets); // Check if cutlass sparse scaled_mm is supported for CUDA devices of the diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index a8c8630a0448..d37551d4bc11 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -504,14 +504,14 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, a_tensors_stacked = torch.empty((expert_offsets[num_groups], k_g), device=device, dtype=torch.float8_e4m3fn) - b_tensors_stacked = torch.empty((n_g * num_groups, k_g), + b_tensors_stacked = torch.empty((num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn) for g in range(num_groups): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] - b_tensors_stacked[g * n_g:(g + 1) * n_g, :] = b_tensors[g].t() - b_tensors_stacked = b_tensors_stacked.t() + b_tensors_stacked[g] = b_tensors[g].t() + b_tensors_stacked = b_tensors_stacked.transpose(1, 2) a_scales_tensors_stacked = torch.empty( (expert_offsets[num_groups] if per_act_token else num_groups, 1), @@ -538,8 +538,8 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, expert_offsets, - problem_sizes) + b_scales_tensors_stacked, + expert_offsets[:-1], problem_sizes) # Validate each group's result against the baseline for g in range(num_groups): diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index efb60df3c2c5..0e223e6cb832 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -3,15 +3,16 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk, cutlass_moe -from vllm.platforms import current_platform from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, + fused_topk) +from vllm.platforms import current_platform NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] -@pytest.mark.parametrize("m", [16, 32, 64, 224]) +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -36,39 +37,39 @@ def test_cutlass_moe( a_q, a_scale = ops.scaled_fp8_quant(a) - w1_qs = [] - w2_qs = [] - w1_scales = [] - w2_scales = [] + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) for expert in range(e): - w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) - w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) - w1_qs.append(w1_q.t()) - w2_qs.append(w2_q.t()) - w1_scales.append(w1_scale.reshape((1, 1))) - w2_scales.append(w2_scale.reshape((1, 1))) - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) a_d = (a_q.float() * a_scale).half() + w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() + w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) for expert in range(e): - w1_d[expert] = (w1_qs[expert].t().float() * - w1_scales[expert]).half() - w2_d[expert] = (w2_qs[expert].t().float() * - w2_scales[expert]).half() + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) - cutlass_output = cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, - w2_scales, topk_weights, topk_ids, m, n, - k) + cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, + w2_scale, topk_weights, topk_ids, m, n, k, + e) - # print(torch_output) - # print(cutlass_output) - # print(torch_output / cutlass_output) + print(torch_output) + print(cutlass_output) + print("*") torch.testing.assert_close(torch_output, cutlass_output, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f38d3ed35ccf..11807f4d373c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -880,93 +880,57 @@ def fused_moe( block_shape=block_shape) -# TODO handle scores def cutlass_moe( a_q: torch.Tensor, a_scale: torch.Tensor, - w1_qs: List[torch.Tensor], - w2_qs: List[torch.Tensor], - w1_scales: List[torch.Tensor], - w2_scales: List[torch.Tensor], + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, k: int, + num_groups: int, ): - num_groups = len(w1_qs) topk = topk_ids.shape[1] - a_ptrs = torch.empty((num_groups), dtype=torch.int64, device="cuda") expert_offsets = torch.empty((num_groups + 1), - dtype=torch.int64, + dtype=torch.int32, + device="cuda") + problem_sizes1 = torch.empty((num_groups, 3), + dtype=torch.int32, + device="cuda") + problem_sizes2 = torch.empty((num_groups, 3), + dtype=torch.int32, device="cuda") a_map = topk_ids.flatten().argsort() - rep_a_q = a_q.repeat_interleave(topk, dim=0) - - torch.ops._C.compute_expert_offsets(a_ptrs, rep_a_q, topk_ids.cuda(), - expert_offsets, num_groups) - - a_q_s = [] - a_scales_s = [] - c_s1 = [] - c_s2 = [] - for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] - cut_out = rep_a_q.view(dtype=torch.uint8)[expert_map].view( - dtype=a_q.dtype) - a_q_s.append(cut_out.clone()) - a_scales_s.append(a_scale.clone()) - c_s1.append( - torch.zeros((cut_out.shape[0], n * 2), - device="cuda", - dtype=torch.half)) - c_s2.append( - torch.zeros((cut_out.shape[0], k), device="cuda", - dtype=torch.half)) - - torch.ops._C.cutlass_grouped_mm(c_s1, a_q_s, w1_qs, a_scales_s, w1_scales) - - # ### UNCOMMENT THIS TO DO ONLY A SINGLE MUL - # intermediate1 = torch.empty((m * topk, n * 2), - # device="cuda", - # dtype=torch.half) - # for e in range(num_groups): - # expert_map = a_map[expert_offsets[e]:expert_offsets[e+1]] - # intermediate1[expert_map] = c_s1[e] - # return intermediate1.reshape(m, topk, n * 2).sum(dim=1) - # ### - - full_groups = [] - - intermediate2 = [] - intermediate2_scales = [] - for e in range(num_groups): - if c_s1[e].shape[0] != 0: - full_groups.append(e) - inter2 = torch.empty((c_s1[e].shape[0], n), - device="cuda", - dtype=torch.half) - torch.ops._C.silu_and_mul(inter2, c_s1[e]) - inter2_v, inter2_s = ops.scaled_fp8_quant(inter2) - intermediate2.append(inter2_v) - intermediate2_scales.append(inter2_s.reshape((1, 1))) - - def filter_list(items: List, idxs: List): - return [items[idx] for idx in idxs] - - torch.ops._C.cutlass_grouped_mm(filter_list(c_s2, - full_groups), intermediate2, - filter_list(w2_qs, full_groups), - intermediate2_scales, - filter_list(w2_scales, full_groups)) - intermediate3 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) - for e in range(num_groups): - expert_map = a_map[expert_offsets[e]:expert_offsets[e + 1]] - intermediate3[expert_map] = c_s2[e] - - intermediate3.reshape(m, topk, k).sum(dim=1) - out = (intermediate3.reshape(m, topk, k) * - topk_weights.view(m, topk, 1).half()).sum(dim=1) - return out + rep_a_q = a_q.repeat_interleave( + topk, dim=0).view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a_scales = a_scale.repeat((num_groups, 1)) + + torch.ops._C.compute_expert_offsets(topk_ids.cuda(), expert_offsets, + problem_sizes1, problem_sizes2, + num_groups, n, k) + + c_s1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) + c_s2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) + + torch.ops._C.cutlass_grouped_mm(c_s1, rep_a_q, w1_q, rep_a_scales, + w1_scale, expert_offsets[:-1], + problem_sizes1) + + intermediate = torch.empty((m * topk, n), device="cuda", dtype=torch.half) + torch.ops._C.silu_and_mul(intermediate, c_s1) + + intemediate_q, intermediate_scales = ops.scaled_fp8_quant(intermediate) + rep_intermediate_scales = intermediate_scales.repeat((num_groups, 1)) + + torch.ops._C.cutlass_grouped_mm(c_s2, intemediate_q, w2_q, + rep_intermediate_scales, w2_scale, + expert_offsets[:-1], problem_sizes2) + + return (c_s2[a_map.argsort()].view(m, topk, k) * + topk_weights.view(m, topk, 1).half()).sum(dim=1) From 286f6c845db359be32e3f9220282f93fd34b0649 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 29 Jan 2025 13:45:19 +0000 Subject: [PATCH 17/58] benchmark, create strides directly on device, small name refactor Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 48 +++++++++---------- .../cutlass_w8a8/grouped_mm_c3x.cu | 44 ++++++++++------- .../layers/fused_moe/fused_moe.py | 15 +++--- 3 files changed, 56 insertions(+), 51 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 7d53b6e1352f..0867c04b9bd8 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -5,10 +5,10 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, + fused_experts, + fused_topk) from vllm.utils import FlexibleArgumentParser -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, - cutlass_moe, - fused_experts) DEFAULT_MODELS = [ "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", @@ -69,18 +69,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, a_q, a_scale = ops.scaled_fp8_quant(a) - w1_qs = [] - w2_qs = [] - w1_scales = [] - w2_scales = [] - - for expert in range(num_experts): - w1_q, w1_scale = ops.scaled_fp8_quant(w1[expert]) - w2_q, w2_scale = ops.scaled_fp8_quant(w2[expert]) - w1_qs.append(w1_q.t()) - w2_qs.append(w2_q.t()) - w1_scales.append(w1_scale.reshape((1, 1))) - w2_scales.append(w2_scale.reshape((1, 1))) + w1_q = torch.empty((num_experts, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((num_experts, k, n), + device="cuda", + dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), + device="cuda", + dtype=torch.float32) score = torch.randn((m, num_experts), device="cuda", dtype=dtype) @@ -96,13 +96,14 @@ def bench_run(results: List[benchmark.Measurement], model: str, # Cutlass params "a_q": a_q, "a_scale": a_scale, - "w1_qs": w1_qs, - "w2_qs": w2_qs, - "w1_scales": w1_scales, - "w2_scales": w2_scales, + "w1_q": w1_q, + "w2_q": w2_q, + "w1_scale": w1_scale, + "w2_scale": w2_scale, "m": m, "n": n, "k": k, + "num_experts": num_experts, # Gen params "topk_weights": topk_weights, "topk_ids": topk_ids, @@ -129,13 +130,13 @@ def bench_run(results: List[benchmark.Measurement], model: str, # Warmup pytorch for _ in range(num_warmup): - cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, - topk_weights, topk_ids, m, n, k) + cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, + topk_ids, m, n, k, num_experts) results.append( benchmark.Timer( stmt= - "cutlass_moe(a_q, a_scale, w1_qs, w2_qs, w1_scales, w2_scales, topk_weights, topk_ids, m, n, k)", # noqa: E501 + "cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -174,9 +175,6 @@ def main(args): compare.print() -# For quick benchmarking use: -# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 ... -# if __name__ == "__main__": parser = FlexibleArgumentParser( description="Benchmark Marlin across specified models/shapes/batches") diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index ae89d72db41a..2ab4ee3d1e12 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -32,6 +32,16 @@ using namespace cute; #define ENABLE_SM90_KERNEL_LEVEL 1 #endif +template +__global__ void make_mm_strides(StrideA* stride_a, StrideB* stride_b, + StrideC* stride_c, int64_t lda, int64_t ldb, + int64_t ldc) { + int expert_id = threadIdx.x; + stride_a[expert_id] = StrideA{lda, Int<1>{}, Int<0>{}}; + stride_b[expert_id] = StrideB{ldb, Int<1>{}, Int<0>{}}; + stride_c[expert_id] = StrideC{ldc, Int<1>{}, Int<0>{}}; +} + namespace { template @@ -123,6 +133,14 @@ cutlass::platform::unique_ptr> make_device_ptr( return cutlass::platform::unique_ptr>(data_device); } +template +cutlass::platform::unique_ptr> allocate_device_ptr( + int count) { + T* data_device; + cudaMalloc(&data_device, count * sizeof(T)); + return cutlass::platform::unique_ptr>(data_device); +} + template void cutlass_group_gemm_caller(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, @@ -180,20 +198,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, using StrideB = Stride, Int<0>>; using StrideC = typename GemmKernel::InternalStrideC; - std::vector a_stride_host(groups); - std::vector b_stride_host(groups); - std::vector c_stride_host(groups); - - // TODO pass strides? - for (int32_t g = 0; g < groups; ++g) { - int64_t lda = a_tensors.stride(0); // row-major (m x k) - int64_t ldb = a_tensors.stride(0); // column-major (k x n) - int64_t ldc = out_tensors.stride(0); // row-major (m x n) - - a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}}; - b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}}; - c_stride_host[g] = StrideC{ldc, Int<1>{}, Int<0>{}}; - } + int64_t lda = a_tensors.stride(0); // row-major (m x k) + int64_t ldb = a_tensors.stride(0); // column-major (k x n) + int64_t ldc = out_tensors.stride(0); // row-major (m x n) + auto a_stride_ptr = allocate_device_ptr(groups); + auto b_stride_ptr = allocate_device_ptr(groups); + auto c_stride_ptr = allocate_device_ptr(groups); + make_mm_strides<<<1, groups>>>(a_stride_ptr.get(), b_stride_ptr.get(), + c_stride_ptr.get(), lda, ldb, ldc); cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; @@ -206,10 +218,6 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, problem_sizes.data_ptr()); ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; - auto a_stride_ptr = make_device_ptr(a_stride_host); - auto b_stride_ptr = make_device_ptr(b_stride_host); - auto c_stride_ptr = make_device_ptr(c_stride_host); - typename GemmKernel::MainloopArguments mainloop_args{ reinterpret_cast(a_ptrs.data_ptr()), a_stride_ptr.get(), diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 11807f4d373c..7cb8d65c12cd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -915,22 +915,21 @@ def cutlass_moe( problem_sizes1, problem_sizes2, num_groups, n, k) - c_s1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) - c_s2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) + c1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) + c2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) - torch.ops._C.cutlass_grouped_mm(c_s1, rep_a_q, w1_q, rep_a_scales, - w1_scale, expert_offsets[:-1], - problem_sizes1) + torch.ops._C.cutlass_grouped_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, + expert_offsets[:-1], problem_sizes1) intermediate = torch.empty((m * topk, n), device="cuda", dtype=torch.half) - torch.ops._C.silu_and_mul(intermediate, c_s1) + torch.ops._C.silu_and_mul(intermediate, c1) intemediate_q, intermediate_scales = ops.scaled_fp8_quant(intermediate) rep_intermediate_scales = intermediate_scales.repeat((num_groups, 1)) - torch.ops._C.cutlass_grouped_mm(c_s2, intemediate_q, w2_q, + torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, rep_intermediate_scales, w2_scale, expert_offsets[:-1], problem_sizes2) - return (c_s2[a_map.argsort()].view(m, topk, k) * + return (c2[a_map.argsort()].view(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) From b6867bb2225e7fd7b9b89c737d2ffb21e82f7867 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 31 Jan 2025 06:23:46 +0000 Subject: [PATCH 18/58] works with cuda graphs Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 93 +++++++++++++------ .../epilogue/scaled_mm_epilogues_c3x.hpp | 24 ++--- .../cutlass_w8a8/grouped_mm_c3x.cu | 69 +++++++------- tests/kernels/test_cutlass.py | 2 +- tests/kernels/test_cutlass_moe.py | 85 ++++++++++++++++- .../layers/fused_moe/fused_moe.py | 2 +- 6 files changed, 198 insertions(+), 77 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 0867c04b9bd8..fc9ed3fa6023 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -5,6 +5,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, fused_experts, fused_topk) @@ -22,32 +23,24 @@ TOPKS = [2, 6] +def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor, + w1_q: torch.Tensor, w2_q: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, + n: int, k: int, e: int): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, e) + + def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) return torch.round(tensor.clamp( min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) -def grouped_gemm(a_g_tensors: List[torch.Tensor], - b_g_tensors: List[torch.Tensor], - out_g_tensors: List[torch.Tensor], - a_scales_tensors: List[torch.Tensor], - b_scales_tensors: List[torch.Tensor]): - ops.cutlass_grouped_mm(out_g_tensors, a_g_tensors, b_g_tensors, - a_scales_tensors, b_scales_tensors) - - -def baseline_gemm(num_groups: int, a_tensors: List[torch.Tensor], - b_tensors: List[torch.Tensor], - out_tensors: List[torch.Tensor]): - for g in range(num_groups): - a = a_tensors[g] - b = b_tensors[g] - out = torch.mm(a, b) - out_tensors[g] = out - - -# TODO marlin baseline def bench_run(results: List[benchmark.Measurement], model: str, num_experts: int, topk: int, per_act_token: bool, per_out_ch: bool, mkn: Tuple[int, int, int]): @@ -82,10 +75,29 @@ def bench_run(results: List[benchmark.Measurement], model: str, device="cuda", dtype=torch.float32) + for expert in range(num_experts): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q_notransp = w1_q.clone() + w2_q_notransp = w2_q.clone() + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + score = torch.randn((m, num_experts), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + def replay_graph(graph): + graph.replay() + torch.cuda.synchronize() + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + run_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, num_experts) + torch.cuda.synchronize() + globals = { # Baseline params "a": a, @@ -93,6 +105,8 @@ def bench_run(results: List[benchmark.Measurement], model: str, "w2": w2, "score": score, "topk": topk, + "w1_q_notransp": w1_q_notransp, + "w2_q_notransp": w2_q_notransp, # Cutlass params "a_q": a_q, "a_scale": a_scale, @@ -104,31 +118,45 @@ def bench_run(results: List[benchmark.Measurement], model: str, "n": n, "k": k, "num_experts": num_experts, + # Cutlass cuda graph params + "graph": graph, # Gen params "topk_weights": topk_weights, "topk_ids": topk_ids, # Kernels "fused_experts": fused_experts, "cutlass_moe": cutlass_moe, + "replay_graph": replay_graph, } min_run_time = 1 num_warmup = 5 - # Warmup pytorch + # Warmup for _ in range(num_warmup): - fused_experts(a, w1, w2, topk_weights, topk_ids) + # fused_experts(a, w1, w2, topk_weights, topk_ids) + fused_experts(a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) results.append( benchmark.Timer( - stmt="fused_experts(a, w1, w2, topk_weights, topk_ids)", + # stmt="fused_experts(a, w1, w2, topk_weights, topk_ids)", + stmt= + "fused_experts(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a_scale)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="baseline_gemm", + description="triton_moe", ).blocked_autorange(min_run_time=min_run_time)) - # Warmup pytorch + # Warmup for _ in range(num_warmup): cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts) @@ -140,7 +168,20 @@ def bench_run(results: List[benchmark.Measurement], model: str, globals=globals, label=label, sub_label=sub_label, - description="grouped_gemm", + description="grouped_gemm_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + for _ in range(num_warmup): + replay_graph(graph) + + results.append( + benchmark.Timer( + stmt="replay_graph(graph)", + globals=globals, + label=label, + sub_label=sub_label, + description="grouped_gemm_moe_cuda_graphs", ).blocked_autorange(min_run_time=min_run_time)) diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 79236ccf608a..d3a8a79a6cb3 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -49,14 +49,16 @@ struct ScaledEpilogueBase { Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; template - using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; template - using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -93,7 +95,6 @@ struct ScaledEpilogueBase { std::is_same_v>); return Arguments{data_ptr, do_broadcast}; } - }; /* @@ -368,10 +369,11 @@ struct ScaledEpilogueArray static ArgumentType prepare_args(const float* const* a_scales_ptr, const float* const* b_scales_ptr, - bool a_col_broadcast, - bool b_row_broadcast) { - auto a_args = SUPER::template args_from_tensor(a_scales_ptr, a_col_broadcast); - auto b_args = SUPER::template args_from_tensor(b_scales_ptr, b_row_broadcast); + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); typename EVTCompute0::Arguments evt0_args{b_args}; return ArgumentType{a_args, evt0_args}; diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 2ab4ee3d1e12..44dd49a4475f 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -32,14 +32,14 @@ using namespace cute; #define ENABLE_SM90_KERNEL_LEVEL 1 #endif -template -__global__ void make_mm_strides(StrideA* stride_a, StrideB* stride_b, - StrideC* stride_c, int64_t lda, int64_t ldb, - int64_t ldc) { - int expert_id = threadIdx.x; - stride_a[expert_id] = StrideA{lda, Int<1>{}, Int<0>{}}; - stride_b[expert_id] = StrideB{ldb, Int<1>{}, Int<0>{}}; - stride_c[expert_id] = StrideC{ldc, Int<1>{}, Int<0>{}}; +// for debugging +__global__ void print_elements(int64_t* tensor, int64_t elements) { + if (threadIdx.x == 0) { + for (int64_t i = 0; i < elements; ++i) { + printf("%ld/%ld ", i, tensor[i]); + } + printf("\n---\n"); + } } namespace { @@ -159,39 +159,40 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, bool per_act_token = a_scales.numel() != groups; bool per_out_ch = b_scales.numel() != groups; - int b_single_size = k_size * n_size; - int b_scale_single_size = per_out_ch ? out_tensors.size(1) : 1; + int b_single_size = k_size * n_size * sizeof(ElementAB_Type); + int b_scale_single_size = + (per_out_ch ? out_tensors.size(1) : 1) * sizeof(ElementAccumulator); + // TODO b and b scales pointers can be computed outside this function auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); torch::Tensor a_ptrs_base = torch::full( groups, reinterpret_cast(a_tensors.data_ptr()), options_int); torch::Tensor out_ptrs_base = torch::full( groups, reinterpret_cast(out_tensors.data_ptr()), options_int); - torch::Tensor b_ptrs_base = torch::full( - groups, reinterpret_cast(b_tensors.data_ptr()), options_int); torch::Tensor a_scales_base = torch::full( groups, reinterpret_cast(a_scales.data_ptr()), options_int); - torch::Tensor b_scales_base = torch::full( - groups, reinterpret_cast(b_scales.data_ptr()), options_int); - torch::Tensor b_offsets = - torch::arange(0, b_single_size * groups, b_single_size, options_int); torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int); - torch::Tensor b_scales_offsets = torch::arange( - 0, b_scale_single_size * groups, b_scale_single_size, options_int); torch::Tensor a_ptrs = a_ptrs_base.add( expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); torch::Tensor out_ptrs = out_ptrs_base.add( expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1)); - torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets, sizeof(ElementAB_Type)); torch::Tensor a_scales_ptrs = a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, sizeof(ElementAccumulator)); - torch::Tensor b_scales_ptrs = - b_scales_base.add(b_scales_offsets, sizeof(ElementAccumulator)); + + int64_t b_tensor_base_addr = reinterpret_cast(b_tensors.data_ptr()); + int64_t b_scales_base_addr = reinterpret_cast(b_scales.data_ptr()); + + torch::Tensor b_ptrs = torch::arange( + b_tensor_base_addr, b_tensor_base_addr + b_single_size * groups, + b_single_size, options_int); + torch::Tensor b_scales_ptrs = torch::arange( + b_scales_base_addr, b_scales_base_addr + b_scale_single_size * groups, + b_scale_single_size, options_int); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -201,17 +202,11 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, int64_t lda = a_tensors.stride(0); // row-major (m x k) int64_t ldb = a_tensors.stride(0); // column-major (k x n) int64_t ldc = out_tensors.stride(0); // row-major (m x n) - auto a_stride_ptr = allocate_device_ptr(groups); - auto b_stride_ptr = allocate_device_ptr(groups); - auto c_stride_ptr = allocate_device_ptr(groups); - make_mm_strides<<<1, groups>>>(a_stride_ptr.get(), b_stride_ptr.get(), - c_stride_ptr.get(), lda, ldb, ldc); - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + + // TODO move creation of these outside this function + torch::Tensor a_strides = torch::full({groups}, lda, options_int); + torch::Tensor b_strides = torch::full({groups}, ldb, options_int); + torch::Tensor c_strides = torch::full({groups}, ldc, options_int); ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = reinterpret_cast( @@ -220,9 +215,9 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, typename GemmKernel::MainloopArguments mainloop_args{ reinterpret_cast(a_ptrs.data_ptr()), - a_stride_ptr.get(), + reinterpret_cast(a_strides.data_ptr()), reinterpret_cast(b_ptrs.data_ptr()), - b_stride_ptr.get()}; + reinterpret_cast(b_strides.data_ptr())}; // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales @@ -233,13 +228,13 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, b_scales_ptrs.data_ptr()), per_act_token, per_out_ch), reinterpret_cast(out_ptrs.data_ptr()), - c_stride_ptr.get(), + reinterpret_cast(c_strides.data_ptr()), reinterpret_cast(out_ptrs.data_ptr()), - c_stride_ptr.get()}; + reinterpret_cast(c_strides.data_ptr())}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, - epilogue_args, hw_info}; + epilogue_args}; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index d37551d4bc11..914447884b16 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -478,7 +478,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, m_a_scales = m_g if per_act_token else 1 n_b_scales = n_g if per_out_ch else 1 - print(m_g, n_g, k_g) + print("shape:", m_g, n_g, k_g) # Create group-specific A and B (FP8) and output (FP16/FP32) a_g = to_fp8(torch.randn((m_g, k_g), device=device)) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 0e223e6cb832..b3ab7c486bd0 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -12,12 +12,23 @@ TOP_KS = [2, 6] +def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, + k: int, e: int): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, e) + + @pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -def test_cutlass_moe( +def test_cutlass_moe_no_graph( m: int, n: int, k: int, @@ -75,3 +86,75 @@ def test_cutlass_moe( cutlass_output, atol=5e-2, rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) +@pytest.mark.parametrize("n", [128, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_cutlass_moe_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + a_d = (a_q.float() * a_scale).half() + w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() + w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, e) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + print(torch_output) + print(cutlass_output) + # print((cutlass_output - torch_output) / torch_output) + print("*") + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7cb8d65c12cd..6760ef17dae6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -680,7 +680,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ] num_tokens, _ = hidden_states.shape - E, N, K = w1.shape + E, N, _ = w1.shape # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE From df04bc0bc87c278e62b3aa29ec2548341889a276 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 31 Jan 2025 07:07:25 +0000 Subject: [PATCH 19/58] move stride tensor creation outside c++ code, cleanup Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 2 - csrc/cpu/torch_bindings.cpp | 15 --- csrc/ops.h | 13 ++- .../cutlass_w8a8/grouped_mm_c3x.cu | 92 +++---------------- .../cutlass_w8a8/scaled_mm_entry.cu | 29 +++--- csrc/torch_bindings.cpp | 3 +- tests/kernels/test_cutlass.py | 12 ++- vllm/_custom_ops.py | 7 -- .../layers/fused_moe/fused_moe.py | 22 ++++- 9 files changed, 66 insertions(+), 129 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index fc9ed3fa6023..49990f75183a 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -134,7 +134,6 @@ def replay_graph(graph): # Warmup for _ in range(num_warmup): - # fused_experts(a, w1, w2, topk_weights, topk_ids) fused_experts(a, w1_q_notransp, w2_q_notransp, @@ -147,7 +146,6 @@ def replay_graph(graph): results.append( benchmark.Timer( - # stmt="fused_experts(a, w1, w2, topk_weights, topk_ids)", stmt= "fused_experts(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a_scale)", # noqa: E501 globals=globals, diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index fa3ab7afb946..6e8a32549864 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,21 +118,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // CUTLASS w8a8 grouped GEMM // TODO complete this - // ops.def( - // "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, - // " " Tensor b_scales, Tensor problem_sizes, " " - // Tensor out_offsets, Tensor a_offsets, " " Tensor - // b_offsets) -> ()"); - // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); - - // ops.def( - // "compute_expert_offsets(Tensor! trg_a_ptrs," - // " Tensor! a, Tensor topk_ids," - // " Tensor! expert_offsets, SymInt num_experts) -> - // ()"); - // ops.impl("compute_expert_offsets", torch::kCUDA, - // &compute_expert_offsets); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. diff --git a/csrc/ops.h b/csrc/ops.h index ecb9a02ad54a..111cbb2a159b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -159,13 +159,12 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm(torch::Tensor& out_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes); +void cutlass_grouped_mm( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); void compute_expert_offsets(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 44dd49a4475f..c3884dbf29be 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -142,13 +142,12 @@ cutlass::platform::unique_ptr> allocate_device_ptr( } template -void cutlass_group_gemm_caller(torch::Tensor& out_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes) { +void cutlass_group_gemm_caller( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; @@ -199,15 +198,6 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors, using StrideB = Stride, Int<0>>; using StrideC = typename GemmKernel::InternalStrideC; - int64_t lda = a_tensors.stride(0); // row-major (m x k) - int64_t ldb = a_tensors.stride(0); // column-major (k x n) - int64_t ldc = out_tensors.stride(0); // row-major (m x n) - - // TODO move creation of these outside this function - torch::Tensor a_strides = torch::full({groups}, lda, options_int); - torch::Tensor b_strides = torch::full({groups}, ldb, options_int); - torch::Tensor c_strides = torch::full({groups}, ldc, options_int); - ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = reinterpret_cast( problem_sizes.data_ptr()); @@ -300,13 +290,12 @@ struct sm90_fp8_config_M64 { } // namespace // TODO -void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes) { +void cutlass_grouped_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); @@ -321,7 +310,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes); + problem_sizes, a_strides, b_strides, c_strides); } __global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, @@ -354,50 +343,6 @@ __global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, } } -// // For a given "a" of size [M,K] performs a permutation of the M rows based -// // on the given "perm" indices. -// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* -// __restrict__ a_ptr, -// int const* __restrict__ perm_int_ptr, -// cutlass::float_e4m3_t* __restrict__ -// out_ptr, int size_m, int size_k, int -// block_rows) { -// int start_row = block_rows * blockIdx.x; -// int finish_row = start_row + block_rows; -// if (finish_row > size_m) { -// finish_row = size_m; -// } -// int cur_block_rows = finish_row - start_row; - -// int row_stride = size_k * sizeof(cutlass::float_e4m3_t) / 16; - -// auto permute_row = [&](int row) { -// int iters = size_k / blockDim.x; -// int rest = size_k % blockDim.x; - -// int a_offset = perm_int_ptr[row] * row_stride; -// int out_offset = row * row_stride; - -// cutlass::float_e4m3_t const* a_row_fp8 = a_ptr + a_offset; -// cutlass::float_e4m3_t* out_fp8 = out_ptr + out_offset; - -// int base_k = 0; - -// for (int i = 0; i < iters; i++) { -// int cur_k = base_k + threadIdx.x; -// out_fp8[cur_k] = a_row_fp8[cur_k]; -// base_k += blockDim.x; -// } - -// if (rest) { -// if (threadIdx.x < rest) { -// int cur_k = base_k + threadIdx.x; -// out_fp8[cur_k] = a_row_fp8[cur_k]; -// } -// } -// }; -// } - void compute_expert_offsets_caller(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -409,14 +354,3 @@ void compute_expert_offsets_caller(const torch::Tensor& topk_ids, (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), topk_ids.numel(), n, k); } - -// void permute_fp8_rows(torch::Tensor& a_ptr, -// torch::Tensor& perm_ptr, -// torch::Tensor& out_ptr, -// int size_m, int size_k, int topk, int block_rows) { -// permute_fp8_rows_kernel<<>>( -// (cutlass::float_e4m3_t const*)a_ptr.data_ptr(), -// (const int*)perm_ptr.data_ptr(), -// (cutlass::float_e4m3_t const*)out_ptr.data_ptr(), size_m * topk, -// size_k, block_rows); -// } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 2bca67073105..cf6a43e207df 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,13 +30,12 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes); +void cutlass_grouped_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); void compute_expert_offsets_caller(const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, @@ -159,15 +158,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm(torch::Tensor& out_tensors, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes) { +void cutlass_grouped_mm( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes); + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides); } void compute_expert_offsets(const torch::Tensor& topk_ids, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d73485449273..4faee8f0c1b9 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -330,7 +330,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor a_tensors," " Tensor b_tensors, Tensor a_scales, " " Tensor b_scales, Tensor expert_offsets, " - " Tensor problem_sizes) -> ()"); + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor c_strides) -> ()"); ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 914447884b16..9cac9f977f2f 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -535,11 +535,21 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, device=device, dtype=out_dtype) + ab_strides = torch.full((num_groups), + a_tensors_stacked.stride(0), + device="cuda", + dtype=torch.int64) + c_strides = torch.full((num_groups), + out_tensors_stacked.stride(0), + device="cuda", + dtype=torch.int64) + torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked, b_scales_tensors_stacked, - expert_offsets[:-1], problem_sizes) + expert_offsets[:-1], problem_sizes, + ab_strides, ab_strides, c_strides) # Validate each group's result against the baseline for g in range(num_groups): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6c1f5749ec5b..440bc52012ab 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -463,13 +463,6 @@ def cutlass_scaled_mm(a: torch.Tensor, return out -def cutlass_grouped_mm(out: List[torch.Tensor], a: List[torch.Tensor], - b: List[torch.Tensor], scale_a: List[torch.Tensor], - scale_b: List[torch.Tensor]) -> torch.Tensor: - torch.ops._C.cutlass_grouped_mm(out, a, b, scale_a, scale_b) - return out - - def cutlass_scaled_mm_azp(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6760ef17dae6..13e271df3fc3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -917,9 +917,18 @@ def cutlass_moe( c1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) c2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) + ab_strides1 = torch.full((num_groups, ), + a_q.stride(0), + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_groups, ), + c1.stride(0), + device="cuda", + dtype=torch.int64) torch.ops._C.cutlass_grouped_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, - expert_offsets[:-1], problem_sizes1) + expert_offsets[:-1], problem_sizes1, + ab_strides1, ab_strides1, c_strides1) intermediate = torch.empty((m * topk, n), device="cuda", dtype=torch.half) torch.ops._C.silu_and_mul(intermediate, c1) @@ -927,9 +936,18 @@ def cutlass_moe( intemediate_q, intermediate_scales = ops.scaled_fp8_quant(intermediate) rep_intermediate_scales = intermediate_scales.repeat((num_groups, 1)) + ab_strides2 = torch.full((num_groups, ), + intemediate_q.stride(0), + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_groups, ), + c2.stride(0), + device="cuda", + dtype=torch.int64) torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, rep_intermediate_scales, w2_scale, - expert_offsets[:-1], problem_sizes2) + expert_offsets[:-1], problem_sizes2, + ab_strides2, ab_strides2, c_strides2) return (c2[a_map.argsort()].view(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) From 88c713409e0cf008804f89295319bc01c982746f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 31 Jan 2025 16:03:43 +0000 Subject: [PATCH 20/58] cleanup benchmark Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 90 +++++++++++-------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 49990f75183a..7b5bf178ae68 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -17,24 +17,11 @@ ] DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512] -NUM_GROUPS_OPTS = [8] #[8, 64] PER_ACT_TOKEN_OPTS = [False] #[False, True] PER_OUT_CH_OPTS = [False] #[False, True] TOPKS = [2, 6] -def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor, - w1_q: torch.Tensor, w2_q: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, - n: int, k: int, e: int): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e) - - def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) return torch.round(tensor.clamp( @@ -87,8 +74,45 @@ def bench_run(results: List[benchmark.Measurement], model: str, topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - def replay_graph(graph): - graph.replay() + def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + a_scale: torch.Tensor, num_repeats: int): + for _ in range(num_repeats): + fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) + + def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + m: int, n: int, k: int, num_experts: int, + num_repeats: int): + for _ in range(num_repeats): + cutlass_moe(a, a_scale, w1, w2, w1_scale, w2_scale, topk_weights, + topk_ids, m, n, k, num_experts) + + def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor, + w1_q: torch.Tensor, w2_q: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + m: int, n: int, k: int, e: int): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, e) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() torch.cuda.synchronize() stream = torch.cuda.Stream() @@ -98,6 +122,9 @@ def replay_graph(graph): topk_weights, topk_ids, m, n, k, num_experts) torch.cuda.synchronize() + min_run_time = 5 + num_warmup = 5 + globals = { # Baseline params "a": a, @@ -124,30 +151,19 @@ def replay_graph(graph): "topk_weights": topk_weights, "topk_ids": topk_ids, # Kernels - "fused_experts": fused_experts, - "cutlass_moe": cutlass_moe, + "run_triton_moe": run_triton_moe, + "run_cutlass_moe": run_cutlass_moe, "replay_graph": replay_graph, } - min_run_time = 1 - num_warmup = 5 - # Warmup - for _ in range(num_warmup): - fused_experts(a, - w1_q_notransp, - w2_q_notransp, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) + run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, + w1_scale, w2_scale, a_scale, num_warmup) results.append( benchmark.Timer( stmt= - "fused_experts(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a_scale)", # noqa: E501 + "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, 1)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -155,14 +171,13 @@ def replay_graph(graph): ).blocked_autorange(min_run_time=min_run_time)) # Warmup - for _ in range(num_warmup): - cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, num_experts) + run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, + topk_ids, m, n, k, num_experts, num_warmup) results.append( benchmark.Timer( stmt= - "cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts)", # noqa: E501 + "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, 1)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -170,12 +185,11 @@ def replay_graph(graph): ).blocked_autorange(min_run_time=min_run_time)) # Warmup - for _ in range(num_warmup): - replay_graph(graph) + replay_graph(graph, num_warmup) results.append( benchmark.Timer( - stmt="replay_graph(graph)", + stmt="replay_graph(graph, 1)", globals=globals, label=label, sub_label=sub_label, From 02e1d4e31897dfe115435c102999d342f6ba6d03 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 4 Feb 2025 06:45:21 +0000 Subject: [PATCH 21/58] profile Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 97 ++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index b3ab7c486bd0..a5dbbbd7f0ca 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -5,6 +5,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, + fused_experts, fused_topk) from vllm.platforms import current_platform @@ -151,10 +152,104 @@ def test_cutlass_moe_cuda_graph( print(torch_output) print(cutlass_output) - # print((cutlass_output - torch_output) / torch_output) print("*") torch.testing.assert_close(torch_output, cutlass_output, atol=5e-2, rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) +@pytest.mark.parametrize("n", [128, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +def test_cutlass_moe_profile( + m: int, + n: int, + k: int, + e: int, + topk: int, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + a_q, a_scale = ops.scaled_fp8_quant(a) + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q_notransp = w1_q.clone() + w2_q_notransp = w2_q.clone() + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() + w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + # ruff: noqa: SIM117 + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof_cutlass: + with torch.profiler.record_function("cutlass_output"): + cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, + w1_scale, w2_scale, topk_weights, + topk_ids, m, n, k, e) + print("profile cutlass:") + print( + prof_cutlass.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total", row_limit=50)) + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof_triton: + with torch.profiler.record_function("triton_output"): + triton_output = fused_experts(a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) + + print("profile triton:") + print( + prof_triton.key_averages(group_by_input_shape=True).table( + sort_by="cuda_time_total", row_limit=50)) + + # Uncomment to produce trace files + # cutlass_trace_name = f"trace_cutlass-{m}x{n}x{k}-{e}x{topk}.json" + # triton_trace_name = f"trace_triton-{m}x{n}x{k}-{e}x{topk}.json" + # prof_cutlass.export_chrome_trace(cutlass_trace_name) + # prof_triton.export_chrome_trace(triton_trace_name) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) From 1d9c429206534c6920261195b8d4038d2b8c973b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 14 Feb 2025 07:25:43 +0000 Subject: [PATCH 22/58] tuned shapes, fix Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 278 ++---------------- .../cutlass_w8a8/grouped_mm_c3x.cuh | 238 +++++++++++++++ tests/kernels/test_cutlass.py | 4 +- 3 files changed, 271 insertions(+), 249 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index c3884dbf29be..1ce24cff5795 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -5,26 +5,7 @@ #include "cutlass/cutlass.h" -// TODO clean up the includes we no longer need - -#include "cute/tensor.hpp" -#include "cutlass/tensor_ref.h" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" - -#include "cutlass_extensions/common.hpp" +#include "grouped_mm_c3x.cuh" using namespace cute; @@ -33,239 +14,29 @@ using namespace cute; #endif // for debugging -__global__ void print_elements(int64_t* tensor, int64_t elements) { - if (threadIdx.x == 0) { - for (int64_t i = 0; i < elements; ++i) { - printf("%ld/%ld ", i, tensor[i]); - } - printf("\n---\n"); - } -} +// __global__ void print_elements(int64_t* tensor, int64_t elements) { +// if (threadIdx.x == 0) { +// for (int64_t i = 0; i < elements; ++i) { +// printf("%ld/%ld ", i, tensor[i]); +// } +// printf("\n---\n"); +// } +// } namespace { -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); -#endif - } -}; - -using ProblemShape = - cutlass::gemm::GroupProblemShape>; -using ElementAB_Type = cutlass::float_e4m3_t; -using ElementC_Type = cutlass::half_t; - -using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; -using OperatorClass = cutlass::arch::OpClassTensorOp; - -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; - using ElementC = ElementC_; - using ElementAccumulator = float; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, - ElementC, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, - 16, ElementAccumulator, TileShape, ClusterShape, Stages, - KernelSchedule>::CollectiveOp; - - using KernelType = enable_sm90_or_later>; - - struct GemmKernel : public KernelType {}; -}; - -template -struct ItemDeleter { - void operator()(T* ptr) { - cudaFree(ptr); // noexcept - } -}; - -template -cutlass::platform::unique_ptr> make_device_ptr( - std::vector& data_host) { - T* data_device; - int count = data_host.size(); - cudaMalloc(&data_device, count * sizeof(T)); - cudaMemcpy(data_device, data_host.data(), count * sizeof(T), - cudaMemcpyHostToDevice); - return cutlass::platform::unique_ptr>(data_device); -} - -template -cutlass::platform::unique_ptr> allocate_device_ptr( - int count) { - T* data_device; - cudaMalloc(&data_device, count * sizeof(T)); - return cutlass::platform::unique_ptr>(data_device); -} - -template -void cutlass_group_gemm_caller( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - using ElementAB = typename Gemm::ElementAB; - using ElementC = typename Gemm::ElementC; - - int groups = (int)expert_offsets.size(0); - int k_size = a_tensors.size(1); - int n_size = out_tensors.size(1); - - bool per_act_token = a_scales.numel() != groups; - bool per_out_ch = b_scales.numel() != groups; - - int b_single_size = k_size * n_size * sizeof(ElementAB_Type); - int b_scale_single_size = - (per_out_ch ? out_tensors.size(1) : 1) * sizeof(ElementAccumulator); - - // TODO b and b scales pointers can be computed outside this function - auto options_int = - torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); - torch::Tensor a_ptrs_base = torch::full( - groups, reinterpret_cast(a_tensors.data_ptr()), options_int); - torch::Tensor out_ptrs_base = torch::full( - groups, reinterpret_cast(out_tensors.data_ptr()), options_int); - torch::Tensor a_scales_base = torch::full( - groups, reinterpret_cast(a_scales.data_ptr()), options_int); - - torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int); - - torch::Tensor a_ptrs = a_ptrs_base.add( - expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); - torch::Tensor out_ptrs = out_ptrs_base.add( - expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1)); - - torch::Tensor a_scales_ptrs = - a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, - sizeof(ElementAccumulator)); - - int64_t b_tensor_base_addr = reinterpret_cast(b_tensors.data_ptr()); - int64_t b_scales_base_addr = reinterpret_cast(b_scales.data_ptr()); - - torch::Tensor b_ptrs = torch::arange( - b_tensor_base_addr, b_tensor_base_addr + b_single_size * groups, - b_single_size, options_int); - torch::Tensor b_scales_ptrs = torch::arange( - b_scales_base_addr, b_scales_base_addr + b_scale_single_size * groups, - b_scale_single_size, options_int); - - using GemmKernel = typename Gemm::GemmKernel; - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; - using StrideC = typename GemmKernel::InternalStrideC; - - ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = - reinterpret_cast( - problem_sizes.data_ptr()); - ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; - - typename GemmKernel::MainloopArguments mainloop_args{ - reinterpret_cast(a_ptrs.data_ptr()), - reinterpret_cast(a_strides.data_ptr()), - reinterpret_cast(b_ptrs.data_ptr()), - reinterpret_cast(b_strides.data_ptr())}; - - // Currently, we are only able to do broadcast on either all or none a_scales - // and on either all or none b_scales - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args(reinterpret_cast( - a_scales_ptrs.data_ptr()), - reinterpret_cast( - b_scales_ptrs.data_ptr()), - per_act_token, per_out_ch), - reinterpret_cast(out_ptrs.data_ptr()), - reinterpret_cast(c_strides.data_ptr()), - reinterpret_cast(out_ptrs.data_ptr()), - reinterpret_cast(c_strides.data_ptr())}; - - typename GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, - epilogue_args}; - - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - template typename Epilogue> struct sm90_fp8_config_default { - static_assert(std::is_same_v); - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] + // M in (16, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using Cutlass3xGemm = cutlass_3x_group_gemm; @@ -273,14 +44,15 @@ struct sm90_fp8_config_M128 { template typename Epilogue> -struct sm90_fp8_config_M64 { +struct sm90_fp8_config_M16 { + // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; - using ClusterShape = cute::Shape; + using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm::Cutlass3xGemm; using Cutlass3xGemmDefault = typename sm90_fp8_config_default< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + + uint32_t const m = a_tensors.size(0); + + if (m <= 16) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } } __global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh new file mode 100644 index 000000000000..8d7ee64d04f0 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -0,0 +1,238 @@ +#include "cutlass/cutlass.h" + +// TODO clean up the includes we no longer need + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +#include "cutlass_extensions/common.hpp" + +using namespace cute; + +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + #define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// for debugging +// __global__ void print_elements(int64_t* tensor, int64_t elements) { +// if (threadIdx.x == 0) { +// for (int64_t i = 0; i < elements; ++i) { +// printf("%ld/%ld ", i, tensor[i]); +// } +// printf("\n---\n"); +// } +// } + +namespace { + +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +using ProblemShape = + cutlass::gemm::GroupProblemShape>; +using ElementAB_Type = cutlass::float_e4m3_t; +using ElementC_Type = cutlass::half_t; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_group_gemm { + using ElementAB = ElementAB_; + using ElementC = ElementC_; + using ElementAccumulator = float; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, + ElementC, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; + + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, + 16, ElementAccumulator, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_or_later>; + + struct GemmKernel : public KernelType {}; +}; + +template +struct ItemDeleter { + void operator()(T* ptr) { + cudaFree(ptr); // noexcept + } +}; + +template +cutlass::platform::unique_ptr> make_device_ptr( + std::vector& data_host) { + T* data_device; + int count = data_host.size(); + cudaMalloc(&data_device, count * sizeof(T)); + cudaMemcpy(data_device, data_host.data(), count * sizeof(T), + cudaMemcpyHostToDevice); + return cutlass::platform::unique_ptr>(data_device); +} + +template +cutlass::platform::unique_ptr> allocate_device_ptr( + int count) { + T* data_device; + cudaMalloc(&data_device, count * sizeof(T)); + return cutlass::platform::unique_ptr>(data_device); +} + +template +void cutlass_group_gemm_caller( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + using ElementAB = typename Gemm::ElementAB; + using ElementC = typename Gemm::ElementC; + + int groups = (int)expert_offsets.size(0); + int k_size = a_tensors.size(1); + int n_size = out_tensors.size(1); + + bool per_act_token = a_scales.numel() != groups; + bool per_out_ch = b_scales.numel() != groups; + + int b_single_size = k_size * n_size * sizeof(ElementAB_Type); + int b_scale_single_size = + (per_out_ch ? out_tensors.size(1) : 1) * sizeof(ElementAccumulator); + + // TODO b and b scales pointers can be computed outside this function + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + torch::Tensor a_ptrs_base = torch::full( + groups, reinterpret_cast(a_tensors.data_ptr()), options_int); + torch::Tensor out_ptrs_base = torch::full( + groups, reinterpret_cast(out_tensors.data_ptr()), options_int); + torch::Tensor a_scales_base = torch::full( + groups, reinterpret_cast(a_scales.data_ptr()), options_int); + + torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int); + + torch::Tensor a_ptrs = a_ptrs_base.add( + expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); + torch::Tensor out_ptrs = out_ptrs_base.add( + expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1)); + + torch::Tensor a_scales_ptrs = + a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, + sizeof(ElementAccumulator)); + + int64_t b_tensor_base_addr = reinterpret_cast(b_tensors.data_ptr()); + int64_t b_scales_base_addr = reinterpret_cast(b_scales.data_ptr()); + + torch::Tensor b_ptrs = torch::arange( + b_tensor_base_addr, b_tensor_base_addr + b_single_size * groups, + b_single_size, options_int); + torch::Tensor b_scales_ptrs = torch::arange( + b_scales_base_addr, b_scales_base_addr + b_scale_single_size * groups, + b_scale_single_size, options_int); + + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + reinterpret_cast( + problem_sizes.data_ptr()); + ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; + + typename GemmKernel::MainloopArguments mainloop_args{ + reinterpret_cast(a_ptrs.data_ptr()), + reinterpret_cast(a_strides.data_ptr()), + reinterpret_cast(b_ptrs.data_ptr()), + reinterpret_cast(b_strides.data_ptr())}; + + // Currently, we are only able to do broadcast on either all or none a_scales + // and on either all or none b_scales + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args(reinterpret_cast( + a_scales_ptrs.data_ptr()), + reinterpret_cast( + b_scales_ptrs.data_ptr()), + per_act_token, per_out_ch), + reinterpret_cast(out_ptrs.data_ptr()), + reinterpret_cast(c_strides.data_ptr()), + reinterpret_cast(out_ptrs.data_ptr()), + reinterpret_cast(c_strides.data_ptr())}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args}; + + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +} // namespace diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 9cac9f977f2f..e9c0f21d2787 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -535,11 +535,11 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, device=device, dtype=out_dtype) - ab_strides = torch.full((num_groups), + ab_strides = torch.full((num_groups, ), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64) - c_strides = torch.full((num_groups), + c_strides = torch.full((num_groups, ), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64) From ae90eee9fcc826e9f6d50104413efb8140141618 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Feb 2025 13:53:56 +0000 Subject: [PATCH 23/58] Performance, add channelwise scales everywhere Signed-off-by: ElizaWszola --- csrc/ops.h | 13 +-- .../cutlass_w8a8/grouped_mm_c3x.cu | 70 ++++++++++---- .../cutlass_w8a8/grouped_mm_c3x.cuh | 93 ++++++++++--------- .../cutlass_w8a8/scaled_mm_entry.cu | 29 +++--- csrc/torch_bindings.cpp | 9 +- tests/kernels/test_cutlass.py | 32 ++++--- tests/kernels/test_cutlass_moe.py | 50 ++++++++-- .../layers/fused_moe/fused_moe.py | 26 ++++-- 8 files changed, 202 insertions(+), 120 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 8d153a70f160..e1d872c033a4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -167,12 +167,13 @@ void cutlass_grouped_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void compute_expert_offsets(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const int64_t num_experts, const int64_t n, - const int64_t k); +void get_grouped_mm_data(const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, + torch::Tensor& arg_sort_prim, + const int64_t num_experts, const int64_t n, + const int64_t k); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 1ce24cff5795..2e948d7d7a88 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -4,6 +4,7 @@ #include #include "cutlass/cutlass.h" +#include "grouped_mm_c3x.cuh" #include "grouped_mm_c3x.cuh" @@ -13,16 +14,6 @@ using namespace cute; #define ENABLE_SM90_KERNEL_LEVEL 1 #endif -// for debugging -// __global__ void print_elements(int64_t* tensor, int64_t elements) { -// if (threadIdx.x == 0) { -// for (int64_t i = 0; i < elements; ++i) { -// printf("%ld/%ld ", i, tensor[i]); -// } -// printf("\n---\n"); -// } -// } - namespace { template ; }; +template typename Epilogue> +struct sm90_fp8_config_N8192 { + // N in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + } // namespace -// TODO void cutlass_grouped_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, @@ -77,6 +84,9 @@ void cutlass_grouped_mm_sm90( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); + using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< + ElementAB_Type, ElementC_Type, + vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; @@ -85,8 +95,14 @@ void cutlass_grouped_mm_sm90( vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + // uint32_t const k = a_tensors.size(1); - if (m <= 16) { + if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else if (m <= 16) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); @@ -100,7 +116,8 @@ void cutlass_grouped_mm_sm90( __global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, int32_t* expert_offsets, int32_t* problem_sizes1, - int32_t* problem_sizes2, int topk_length, + int32_t* problem_sizes2, int32_t* arg_sort, + int32_t* arg_sort_prim, int topk_length, int n, int k) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -125,16 +142,31 @@ __global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, expert_offsets[i + 1] = tot_offset; } } + + __syncthreads(); + + int start = expert_offsets[expert_id]; + int end = expert_offsets[expert_id + 1]; + for (int i = 0; i < topk_length; ++i) { + if (topk_ids[i] == expert_id) { + arg_sort[start] = i; + arg_sort_prim[i] = start; + ++start; + if (start == end) { + break; + } + } + } } -void compute_expert_offsets_caller(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const int64_t num_experts, const int64_t n, - const int64_t k) { +void get_grouped_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, + const int64_t num_experts, const int64_t n, const int64_t k) { get_a_expert_offsets<<<1, num_experts>>>( (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), + (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), topk_ids.numel(), n, k); } diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 8d7ee64d04f0..2a61f7aaedea 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "cutlass/cutlass.h" // TODO clean up the includes we no longer need @@ -31,23 +33,36 @@ using namespace cute; // __global__ void print_elements(int64_t* tensor, int64_t elements) { // if (threadIdx.x == 0) { // for (int64_t i = 0; i < elements; ++i) { -// printf("%ld/%ld ", i, tensor[i]); +// printf("%ld ", tensor[i]); // } // printf("\n---\n"); // } // } -namespace { +__global__ void get_group_gemm_starts( + int32_t* expert_offsets, int64_t* a_offsets, int64_t* b_offsets, + int64_t* out_offsets, int64_t* a_scales_offsets, int64_t* b_scales_offsets, + const int64_t a_base_as_int, const int64_t b_base_as_int, + const int64_t out_base_as_int, const int64_t a_scales_base_as_int, + const int64_t b_scales_base_as_int, int n, int k, bool per_act_token, + bool per_out_ch, int64_t ab_size, int64_t c_size, int64_t acc_size) { + int expert_id = threadIdx.x; + // int num_experts = blockDim.x; + + int expert_offset = expert_offsets[expert_id]; + + a_offsets[expert_id] = a_base_as_int + expert_offset * k * ab_size; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n * ab_size; + out_offsets[expert_id] = out_base_as_int + expert_offset * n * c_size; + a_scales_offsets[expert_id] = + a_scales_base_as_int + + (per_act_token ? expert_offset : /*expert_id*/ 0) * acc_size; + b_scales_offsets[expert_id] = + b_scales_base_as_int + + (per_out_ch ? n * expert_id : expert_id) * acc_size; +} -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); -#endif - } -}; +namespace { using ProblemShape = cutlass::gemm::GroupProblemShape>; @@ -150,43 +165,34 @@ void cutlass_group_gemm_caller( int k_size = a_tensors.size(1); int n_size = out_tensors.size(1); - bool per_act_token = a_scales.numel() != groups; + bool per_act_token = a_scales.numel() != 1; bool per_out_ch = b_scales.numel() != groups; - int b_single_size = k_size * n_size * sizeof(ElementAB_Type); - int b_scale_single_size = - (per_out_ch ? out_tensors.size(1) : 1) * sizeof(ElementAccumulator); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - // TODO b and b scales pointers can be computed outside this function auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); - torch::Tensor a_ptrs_base = torch::full( - groups, reinterpret_cast(a_tensors.data_ptr()), options_int); - torch::Tensor out_ptrs_base = torch::full( - groups, reinterpret_cast(out_tensors.data_ptr()), options_int); - torch::Tensor a_scales_base = torch::full( - groups, reinterpret_cast(a_scales.data_ptr()), options_int); - - torch::Tensor a_scales_offsets = torch::arange(0, groups, options_int); - - torch::Tensor a_ptrs = a_ptrs_base.add( - expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1)); - torch::Tensor out_ptrs = out_ptrs_base.add( - expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1)); - - torch::Tensor a_scales_ptrs = - a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, - sizeof(ElementAccumulator)); - - int64_t b_tensor_base_addr = reinterpret_cast(b_tensors.data_ptr()); - int64_t b_scales_base_addr = reinterpret_cast(b_scales.data_ptr()); - - torch::Tensor b_ptrs = torch::arange( - b_tensor_base_addr, b_tensor_base_addr + b_single_size * groups, - b_single_size, options_int); - torch::Tensor b_scales_ptrs = torch::arange( - b_scales_base_addr, b_scales_base_addr + b_scale_single_size * groups, - b_scale_single_size, options_int); + + torch::Tensor a_ptrs = torch::empty(groups, options_int); + torch::Tensor b_ptrs = torch::empty(groups, options_int); + torch::Tensor out_ptrs = torch::empty(groups, options_int); + torch::Tensor a_scales_ptrs = torch::empty(groups, options_int); + torch::Tensor b_scales_ptrs = torch::empty(groups, options_int); + + get_group_gemm_starts<<<1, groups, 0, stream>>>( + reinterpret_cast(expert_offsets.data_ptr()), + reinterpret_cast(a_ptrs.data_ptr()), + reinterpret_cast(b_ptrs.data_ptr()), + reinterpret_cast(out_ptrs.data_ptr()), + reinterpret_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(a_tensors.data_ptr()), + reinterpret_cast(b_tensors.data_ptr()), + reinterpret_cast(out_tensors.data_ptr()), + reinterpret_cast(a_scales.data_ptr()), + reinterpret_cast(b_scales.data_ptr()), out_tensors.size(1), + a_tensors.size(1), per_act_token, per_out_ch, sizeof(ElementAB_Type), + sizeof(ElementC_Type), sizeof(ElementAccumulator)); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -230,7 +236,6 @@ void cutlass_group_gemm_caller( torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 60b15a1cc57f..446641967630 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -37,12 +37,11 @@ void cutlass_grouped_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void compute_expert_offsets_caller(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const int64_t num_experts, const int64_t n, - const int64_t k); +void get_grouped_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, + const int64_t num_experts, const int64_t n, const int64_t k); #endif @@ -179,14 +178,16 @@ void cutlass_grouped_mm( c_strides); } -void compute_expert_offsets(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const int64_t num_experts, const int64_t n, - const int64_t k) { - compute_expert_offsets_caller(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, num_experts, n, k); +void get_grouped_mm_data(const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, + torch::Tensor& arg_sort_prim, + const int64_t num_experts, const int64_t n, + const int64_t k) { + get_grouped_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, arg_sort, arg_sort_prim, + num_experts, n, k); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2cdce4a4eb11..9158ed0c9088 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -335,10 +335,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); ops.def( - "compute_expert_offsets(Tensor topk_ids, Tensor! expert_offsets, " - " Tensor! problem_sizes1, Tensor! problem_sizes2, " - " SymInt num_experts, SymInt n, SymInt k) -> ()"); - ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets); + "get_grouped_mm_data(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " Tensor! arg_sort, Tensor! arg_sort_prim, " + " SymInt num_experts, SymInt n, SymInt k) -> ()"); + ops.impl("get_grouped_mm_data", torch::kCUDA, &get_grouped_mm_data); // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) ops.def( diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 87b6a19730b7..9653f33a488a 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -540,6 +540,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, device=device, dtype=torch.int32) + if not per_act_token: + one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32) + alignment = 16 # 128 // 8 # For variation, each group has dimensions n_g = alignment * random.randint(1, 64) @@ -560,19 +563,23 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, # Create group-specific A and B (FP8) and output (FP16/FP32) a_g = to_fp8(torch.randn((m_g, k_g), device=device)) b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + # Set up A/B scales - scale_a = torch.randn((m_a_scales, 1), - device=device, - dtype=torch.float32) scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32) - - a_tensors.append(a_g) - b_tensors.append(b_g) - a_scales_tensors.append(scale_a) b_scales_tensors.append(scale_b) + if per_act_token: + scale_a = torch.randn((m_a_scales, 1), + device=device, + dtype=torch.float32) + a_scales_tensors.append(scale_a) + else: + scale_a = one_scale_a + # Compute baseline result for this group baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) @@ -584,23 +591,22 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, b_tensors_stacked = torch.empty((num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn) + for g in range(num_groups): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] b_tensors_stacked[g] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.transpose(1, 2) - a_scales_tensors_stacked = torch.empty( - (expert_offsets[num_groups] if per_act_token else num_groups, 1), - device=device, - dtype=torch.float32) if per_act_token: + a_scales_tensors_stacked = torch.empty((expert_offsets[num_groups], 1), + device=device, + dtype=torch.float32) for g in range(num_groups): a_scales_tensors_stacked[ expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] else: - for g in range(num_groups): - a_scales_tensors_stacked[g] = a_scales_tensors[g] + a_scales_tensors_stacked = one_scale_a b_scales_tensors_stacked = torch.empty((num_groups, n_b_scales), device=device, diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index a5dbbbd7f0ca..b8b412dabd92 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import pytest import torch @@ -29,12 +30,16 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) def test_cutlass_moe_no_graph( m: int, n: int, k: int, e: int, topk: int, + per_act_token: bool, + per_out_ch: bool, ): current_platform.seed_everything(7) with set_current_vllm_config( @@ -47,18 +52,28 @@ def test_cutlass_moe_no_graph( w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - a_q, a_scale = ops.scaled_fp8_quant(a) + a_q, a_scale = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn) w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) - w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) a_d = (a_q.float() * a_scale).half() @@ -94,12 +109,16 @@ def test_cutlass_moe_no_graph( @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) def test_cutlass_moe_cuda_graph( m: int, n: int, k: int, e: int, topk: int, + per_act_token: bool, + per_out_ch: bool, ): current_platform.seed_everything(7) with set_current_vllm_config( @@ -112,18 +131,28 @@ def test_cutlass_moe_cuda_graph( w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - a_q, a_scale = ops.scaled_fp8_quant(a) + a_q, a_scale = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn) w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) - w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) a_d = (a_q.float() * a_scale).half() @@ -160,6 +189,7 @@ def test_cutlass_moe_cuda_graph( rtol=1e-2) +@pytest.mark.skip("profiling only") @pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6db72a11e967..780a97f2c267 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1379,6 +1379,7 @@ def cutlass_moe( num_groups: int, ): topk = topk_ids.shape[1] + per_act_token = a_scale.numel() != 1 expert_offsets = torch.empty((num_groups + 1), dtype=torch.int32, @@ -1390,14 +1391,19 @@ def cutlass_moe( dtype=torch.int32, device="cuda") - a_map = topk_ids.flatten().argsort() + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device="cuda") + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device="cuda") + + torch.ops._C.get_grouped_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, num_groups, + n, k) + rep_a_q = a_q.repeat_interleave( topk, dim=0).view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a_scales = a_scale.repeat((num_groups, 1)) - - torch.ops._C.compute_expert_offsets(topk_ids.cuda(), expert_offsets, - problem_sizes1, problem_sizes2, - num_groups, n, k) + if per_act_token: + rep_a_scales = a_scale.repeat_interleave(topk, dim=0)[a_map] + else: + rep_a_scales = a_scale c1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) c2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) @@ -1417,8 +1423,8 @@ def cutlass_moe( intermediate = torch.empty((m * topk, n), device="cuda", dtype=torch.half) torch.ops._C.silu_and_mul(intermediate, c1) - intemediate_q, intermediate_scales = ops.scaled_fp8_quant(intermediate) - rep_intermediate_scales = intermediate_scales.repeat((num_groups, 1)) + intemediate_q, intermediate_scales = ops.scaled_fp8_quant( + intermediate, use_per_token_if_dynamic=per_act_token) ab_strides2 = torch.full((num_groups, ), intemediate_q.stride(0), @@ -1429,9 +1435,9 @@ def cutlass_moe( device="cuda", dtype=torch.int64) torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, - rep_intermediate_scales, w2_scale, + intermediate_scales, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) - return (c2[a_map.argsort()].view(m, topk, k) * + return (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) From f191b3579264a79cbd7a195b56e37fc3aa2d933e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 20 Feb 2025 06:56:28 +0000 Subject: [PATCH 24/58] name fix Signed-off-by: ElizaWszola --- csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 2e948d7d7a88..1166022785c8 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -113,12 +113,12 @@ void cutlass_grouped_mm_sm90( } } -__global__ void get_a_expert_offsets(const int* __restrict__ topk_ids, - int32_t* expert_offsets, - int32_t* problem_sizes1, - int32_t* problem_sizes2, int32_t* arg_sort, - int32_t* arg_sort_prim, int topk_length, - int n, int k) { +__global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, + int32_t* expert_offsets, + int32_t* problem_sizes1, + int32_t* problem_sizes2, int32_t* arg_sort, + int32_t* arg_sort_prim, int topk_length, + int n, int k) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -164,7 +164,7 @@ void get_grouped_mm_data_caller( torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, const int64_t num_experts, const int64_t n, const int64_t k) { - get_a_expert_offsets<<<1, num_experts>>>( + get_grouped_mm_data<<<1, num_experts>>>( (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), From 51941ff96004a0d1ec86702acc4dd4ba73268648 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 20 Feb 2025 16:01:56 +0000 Subject: [PATCH 25/58] perf improvements in data preparation Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 6 +- .../cutlass_w8a8/grouped_mm_c3x.cuh | 17 +++--- tests/kernels/test_cutlass_moe.py | 8 ++- .../layers/fused_moe/fused_moe.py | 60 +++++++++++-------- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 1166022785c8..cf1ce726b403 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -118,7 +118,7 @@ __global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* arg_sort, int32_t* arg_sort_prim, int topk_length, - int n, int k) { + int n, int k, int topk) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -149,7 +149,7 @@ __global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, int end = expert_offsets[expert_id + 1]; for (int i = 0; i < topk_length; ++i) { if (topk_ids[i] == expert_id) { - arg_sort[start] = i; + arg_sort[start] = i / topk; arg_sort_prim[i] = start; ++start; if (start == end) { @@ -168,5 +168,5 @@ void get_grouped_mm_data_caller( (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), - topk_ids.numel(), n, k); + topk_ids.numel(), n, k, topk_ids.size(1)); } diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 2a61f7aaedea..35d069ca11b7 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -83,21 +83,24 @@ template struct cutlass_3x_group_gemm { using ElementAB = ElementAB_; - using ElementC = ElementC_; + // TODO check if this works + using ElementC = void; + using ElementD = ElementC_; + // using ElementC = ElementC_; using ElementAccumulator = float; using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, - ElementC, EpilogueSchedule>; + ElementD, EpilogueSchedule>; - using Epilogue = Epilogue_; + using Epilogue = Epilogue_; using StrideC = cute::remove_pointer_t, cute::Int<0>>>; const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; @@ -105,7 +108,7 @@ struct cutlass_3x_group_gemm { typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, + ElementAccumulator, ElementC, LayoutC*, 4, ElementD, LayoutC*, 4, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = @@ -160,6 +163,7 @@ void cutlass_group_gemm_caller( torch::Tensor const& b_strides, torch::Tensor const& c_strides) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; int groups = (int)expert_offsets.size(0); int k_size = a_tensors.size(1); @@ -218,8 +222,7 @@ void cutlass_group_gemm_caller( reinterpret_cast( b_scales_ptrs.data_ptr()), per_act_token, per_out_ch), - reinterpret_cast(out_ptrs.data_ptr()), - reinterpret_cast(c_strides.data_ptr()), + nullptr, reinterpret_cast(c_strides.data_ptr()), reinterpret_cast(out_ptrs.data_ptr()), reinterpret_cast(c_strides.data_ptr())}; diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index b8b412dabd92..f2404b4b27fb 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -69,6 +69,11 @@ def test_cutlass_moe_no_graph( device="cuda", dtype=torch.float32) + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( w1[expert], use_per_token_if_dynamic=per_out_ch) @@ -92,7 +97,8 @@ def test_cutlass_moe_no_graph( torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, - e) + e, ab_strides1, c_strides1, ab_strides2, + c_strides2) print(torch_output) print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 780a97f2c267..693bf157dee7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1268,6 +1268,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, return out_hidden_states +#TODO make the grouped gemm kernel consistent with scaled gemm kernel def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1377,6 +1378,10 @@ def cutlass_moe( n: int, k: int, num_groups: int, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, ): topk = topk_ids.shape[1] per_act_token = a_scale.numel() != 1 @@ -1398,23 +1403,27 @@ def cutlass_moe( problem_sizes2, a_map, c_map, num_groups, n, k) - rep_a_q = a_q.repeat_interleave( - topk, dim=0).view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - if per_act_token: - rep_a_scales = a_scale.repeat_interleave(topk, dim=0)[a_map] - else: - rep_a_scales = a_scale - - c1 = torch.zeros((m * topk, n * 2), device="cuda", dtype=torch.half) - c2 = torch.zeros((m * topk, k), device="cuda", dtype=torch.half) - ab_strides1 = torch.full((num_groups, ), - a_q.stride(0), - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_groups, ), - c1.stride(0), - device="cuda", - dtype=torch.int64) + # TODO use extend here, or try to create map without repeating and + # interleaving + # TODO reuse MoE align_block kernel? + # rep_a_q = a_q.repeat_interleave( + # topk, dim=0).view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a_scales = a_scale[a_map] if per_act_token else a_scale + + # TODO check if we need zeros here + c1 = torch.empty((m * topk, n * 2), device="cuda", dtype=torch.half) + c2 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) + # TODO move stride creation outside this function, they're going to be + # constant for all calls + # ab_strides1 = torch.full((num_groups, ), + # a_q.stride(0), + # device="cuda", + # dtype=torch.int64) + # c_strides1 = torch.full((num_groups, ), + # c1.stride(0), + # device="cuda", + # dtype=torch.int64) torch.ops._C.cutlass_grouped_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, expert_offsets[:-1], problem_sizes1, @@ -1426,14 +1435,15 @@ def cutlass_moe( intemediate_q, intermediate_scales = ops.scaled_fp8_quant( intermediate, use_per_token_if_dynamic=per_act_token) - ab_strides2 = torch.full((num_groups, ), - intemediate_q.stride(0), - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_groups, ), - c2.stride(0), - device="cuda", - dtype=torch.int64) + # ab_strides2 = torch.full((num_groups, ), + # intemediate_q.stride(0), + # device="cuda", + # dtype=torch.int64) + # c_strides2 = torch.full((num_groups, ), + # c2.stride(0), + # device="cuda", + # dtype=torch.int64) + torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, intermediate_scales, w2_scale, expert_offsets[:-1], problem_sizes2, From d3cf1db4cfacbd3bfff90edb395b0c5c5393e9ad Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 24 Feb 2025 15:28:34 +0000 Subject: [PATCH 26/58] Integrate with deepseek v2 Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 35 +++++--- .../layers/fused_moe/__init__.py | 5 +- .../compressed_tensors_moe.py | 76 ++++++++++++++--- .../scaled_mm/GroupedMMLinearKernel.py | 60 ++++++++++++++ .../kernels/scaled_mm/grouped_cutlass.py | 81 +++++++++++++++++++ 5 files changed, 233 insertions(+), 24 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index f2404b4b27fb..45407e59ddd3 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -17,12 +17,14 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, - k: int, e: int): + k: int, e: int, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e) + topk_weights, topk_ids, m, n, k, e, ab_strides1, + c_strides1, ab_strides2, c_strides2) @pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) @@ -82,8 +84,11 @@ def test_cutlass_moe_no_graph( w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) a_d = (a_q.float() * a_scale).half() - w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() - w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -162,8 +167,11 @@ def test_cutlass_moe_cuda_graph( w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) a_d = (a_q.float() * a_scale).half() - w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() - w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -180,7 +188,9 @@ def test_cutlass_moe_cuda_graph( graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): cutlass_output = run(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e) + topk_weights, topk_ids, m, n, k, e, + ab_strides1, c_strides1, ab_strides2, + c_strides2) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() @@ -235,8 +245,11 @@ def test_cutlass_moe_profile( w2_q_notransp = w2_q.clone() w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) - w1_d = (w1_q.transpose(1, 2).float() * w1_scale).half() - w2_d = (w2_q.transpose(1, 2).float() * w2_scale).half() + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -254,7 +267,9 @@ def test_cutlass_moe_profile( with torch.profiler.record_function("cutlass_output"): cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, e) + topk_ids, m, n, k, e, ab_strides1, + c_strides1, ab_strides2, + c_strides2) print("profile cutlass:") print( prof_cutlass.key_averages(group_by_input_shape=True).table( diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 6f933c3fa3c9..c29849ce707d 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,8 +36,8 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + cutlass_moe, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -45,4 +45,5 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_experts", "get_config_file_name", "grouped_topk", + "cutlass_moe", ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index db8e8a4b6c11..10cbeea23c6c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -139,6 +139,34 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.w13_input_scale = None layer.w2_input_scale = None + # TODO strides can be shared across multiple layers + ab_strides1 = torch.nn.Parameter(torch.full((num_experts, ), + hidden_size, + device="cuda", + dtype=torch.int64), + requires_grad=False) + c_strides1 = torch.nn.Parameter(torch.full( + (num_experts, ), + 2 * intermediate_size_per_partition, + device="cuda", + dtype=torch.int64), + requires_grad=False) + ab_strides2 = torch.nn.Parameter(torch.full( + (num_experts, ), + intermediate_size_per_partition, + device="cuda", + dtype=torch.int64), + requires_grad=False) + c_strides2 = torch.nn.Parameter(torch.full((num_experts, ), + hidden_size, + device="cuda", + dtype=torch.int64), + requires_grad=False) + layer.register_parameter("ab_strides1", ab_strides1) + layer.register_parameter("c_strides1", c_strides1) + layer.register_parameter("ab_strides2", ab_strides2) + layer.register_parameter("c_strides2", c_strides2) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. @@ -218,7 +246,7 @@ def apply( scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts + from vllm.model_executor.layers.fused_moe import cutlass_moe topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -232,17 +260,41 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + # TODO + x_q, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=False) + # print(x_q.shape, x_scale.shape, + # layer.w13_weight.shape, layer.w2_weight.shape, + # layer.w13_weight_scale.shape, layer.w2_weight_scale.shape) + return cutlass_moe( + x_q, + x_scale, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + x.shape[0], + layer.w2_weight.shape[2], + x.shape[1], + layer.w13_weight.shape[0], + layer.ab_strides1, + layer.c_strides1, + layer.ab_strides2, + layer.c_strides2, + ).bfloat16() + + # return fused_experts(x, + # layer.w13_weight, + # layer.w2_weight, + # topk_weights=topk_weights, + # topk_ids=topk_ids, + # inplace=True, + # use_fp8_w8a8=True, + # w1_scale=layer.w13_weight_scale, + # w2_scale=layer.w2_weight_scale, + # a1_scale=layer.w13_input_scale, + # a2_scale=layer.w2_input_scale) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py new file mode 100644 index 000000000000..3293e2482d6d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + + +#always symmetric for now +@dataclass +class GroupedMMLinearLayerConfig: + is_per_act_token: bool + is_per_out_ch: bool + is_static_input_scheme: bool + + +class GroupedMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement( + cls, c: GroupedMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: GroupedMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module) -> Tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py new file mode 100644 index 000000000000..dafff95377ab --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.platforms import current_platform + +from .GroupedMMLinearKernel import (GroupedMMLinearKernel, + GroupedMMLinearLayerConfig) + + +class CutlassGroupMMLinearKernel(GroupedMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: GroupedMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + # is_fused_module = len(layer.logical_widths) > 1 + # weight_scale = getattr(layer, self.w_s_name) + # if is_fused_module and not self.config.is_per_out_ch: + # weight_scale = convert_to_channelwise(weight_scale, + # layer.logical_widths) + # if is_fused_module and not self.config.is_per_act_token: + # input_scale = convert_to_channelwise(weight_scale, + # layer.logical_widths) + # replace_parameter( + # layer, self.w_s_name, + # torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, None, symmetric=True) + + return ops.cutlass_grouped_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias) From 175ecdd8ce90db14679b3f7f3654d827f04f586d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 24 Feb 2025 18:01:21 +0000 Subject: [PATCH 27/58] cudagraphs fix Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 3 +- .../layers/fused_moe/fused_moe.py | 42 ++++--------------- .../compressed_tensors_moe.py | 9 ++-- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index cf1ce726b403..187b18d9c55b 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -164,7 +164,8 @@ void get_grouped_mm_data_caller( torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, const int64_t num_experts, const int64_t n, const int64_t k) { - get_grouped_mm_data<<<1, num_experts>>>( + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + get_grouped_mm_data<<<1, num_experts, 0, stream>>>( (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 693bf157dee7..64cfbad50e66 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1385,65 +1385,41 @@ def cutlass_moe( ): topk = topk_ids.shape[1] per_act_token = a_scale.numel() != 1 + device = a_q.device expert_offsets = torch.empty((num_groups + 1), dtype=torch.int32, - device="cuda") + device=device) problem_sizes1 = torch.empty((num_groups, 3), dtype=torch.int32, - device="cuda") + device=device) problem_sizes2 = torch.empty((num_groups, 3), dtype=torch.int32, - device="cuda") + device=device) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device="cuda") - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device="cuda") + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) torch.ops._C.get_grouped_mm_data(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, num_groups, n, k) - # TODO use extend here, or try to create map without repeating and - # interleaving - # TODO reuse MoE align_block kernel? - # rep_a_q = a_q.repeat_interleave( - # topk, dim=0).view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) rep_a_scales = a_scale[a_map] if per_act_token else a_scale - # TODO check if we need zeros here - c1 = torch.empty((m * topk, n * 2), device="cuda", dtype=torch.half) - c2 = torch.empty((m * topk, k), device="cuda", dtype=torch.half) - # TODO move stride creation outside this function, they're going to be - # constant for all calls - # ab_strides1 = torch.full((num_groups, ), - # a_q.stride(0), - # device="cuda", - # dtype=torch.int64) - # c_strides1 = torch.full((num_groups, ), - # c1.stride(0), - # device="cuda", - # dtype=torch.int64) + c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) + c2 = torch.empty((m * topk, k), device=device, dtype=torch.half) torch.ops._C.cutlass_grouped_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, ab_strides1, c_strides1) - intermediate = torch.empty((m * topk, n), device="cuda", dtype=torch.half) + intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) torch.ops._C.silu_and_mul(intermediate, c1) intemediate_q, intermediate_scales = ops.scaled_fp8_quant( intermediate, use_per_token_if_dynamic=per_act_token) - # ab_strides2 = torch.full((num_groups, ), - # intemediate_q.stride(0), - # device="cuda", - # dtype=torch.int64) - # c_strides2 = torch.full((num_groups, ), - # c2.stride(0), - # device="cuda", - # dtype=torch.int64) - torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, intermediate_scales, w2_scale, expert_offsets[:-1], problem_sizes2, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 10cbeea23c6c..2a89dad4a019 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -139,27 +139,28 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.w13_input_scale = None layer.w2_input_scale = None + device = w13_weight.device # TODO strides can be shared across multiple layers ab_strides1 = torch.nn.Parameter(torch.full((num_experts, ), hidden_size, - device="cuda", + device=device, dtype=torch.int64), requires_grad=False) c_strides1 = torch.nn.Parameter(torch.full( (num_experts, ), 2 * intermediate_size_per_partition, - device="cuda", + device=device, dtype=torch.int64), requires_grad=False) ab_strides2 = torch.nn.Parameter(torch.full( (num_experts, ), intermediate_size_per_partition, - device="cuda", + device=device, dtype=torch.int64), requires_grad=False) c_strides2 = torch.nn.Parameter(torch.full((num_experts, ), hidden_size, - device="cuda", + device=device, dtype=torch.int64), requires_grad=False) layer.register_parameter("ab_strides1", ab_strides1) From ec0cb941ff273765032a890152bd840b87b9bd70 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 25 Feb 2025 14:12:03 +0000 Subject: [PATCH 28/58] larger index type to support very large batches Signed-off-by: ElizaWszola --- csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh | 11 +++++------ tests/kernels/test_cutlass_moe.py | 4 ++-- .../compressed_tensors/compressed_tensors_moe.py | 15 --------------- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 35d069ca11b7..84723bc2aba8 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -44,19 +44,18 @@ __global__ void get_group_gemm_starts( int64_t* out_offsets, int64_t* a_scales_offsets, int64_t* b_scales_offsets, const int64_t a_base_as_int, const int64_t b_base_as_int, const int64_t out_base_as_int, const int64_t a_scales_base_as_int, - const int64_t b_scales_base_as_int, int n, int k, bool per_act_token, - bool per_out_ch, int64_t ab_size, int64_t c_size, int64_t acc_size) { + const int64_t b_scales_base_as_int, int64_t n, int64_t k, + bool per_act_token, bool per_out_ch, int64_t ab_size, int64_t c_size, + int64_t acc_size) { int expert_id = threadIdx.x; - // int num_experts = blockDim.x; - int expert_offset = expert_offsets[expert_id]; + int64_t expert_offset = expert_offsets[expert_id]; a_offsets[expert_id] = a_base_as_int + expert_offset * k * ab_size; b_offsets[expert_id] = b_base_as_int + expert_id * k * n * ab_size; out_offsets[expert_id] = out_base_as_int + expert_offset * n * c_size; a_scales_offsets[expert_id] = - a_scales_base_as_int + - (per_act_token ? expert_offset : /*expert_id*/ 0) * acc_size; + a_scales_base_as_int + (per_act_token ? expert_offset : 0) * acc_size; b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id) * acc_size; diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 45407e59ddd3..a067326f78fd 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -27,7 +27,7 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, c_strides1, ab_strides2, c_strides2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 163840]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -115,7 +115,7 @@ def test_cutlass_moe_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 163840]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 2a89dad4a019..9f8e3fb7c235 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -263,9 +263,6 @@ def apply( # TODO x_q, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=False) - # print(x_q.shape, x_scale.shape, - # layer.w13_weight.shape, layer.w2_weight.shape, - # layer.w13_weight_scale.shape, layer.w2_weight_scale.shape) return cutlass_moe( x_q, x_scale, @@ -285,18 +282,6 @@ def apply( layer.c_strides2, ).bfloat16() - # return fused_experts(x, - # layer.w13_weight, - # layer.w2_weight, - # topk_weights=topk_weights, - # topk_ids=topk_ids, - # inplace=True, - # use_fp8_w8a8=True, - # w1_scale=layer.w13_weight_scale, - # w2_scale=layer.w2_weight_scale, - # a1_scale=layer.w13_input_scale, - # a2_scale=layer.w2_input_scale) - class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): From 6dd6d485340f456b449d18ba8fdd125b147d217a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 25 Feb 2025 14:57:08 +0000 Subject: [PATCH 29/58] update benchmarks Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 135 +++++++++++++----- benchmarks/kernels/benchmark_shapes.py | 13 +- 2 files changed, 111 insertions(+), 37 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 7b5bf178ae68..af0a25e41dab 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from typing import List, Tuple import torch @@ -15,11 +17,10 @@ "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" ] -DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512] +DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] PER_ACT_TOKEN_OPTS = [False] #[False, True] PER_OUT_CH_OPTS = [False] #[False, True] -TOPKS = [2, 6] def to_fp8(tensor: torch.Tensor): @@ -33,9 +34,10 @@ def bench_run(results: List[benchmark.Measurement], model: str, per_out_ch: bool, mkn: Tuple[int, int, int]): label = "Quant Matmul" - sub_label = ("{}, num_experts={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_experts, per_act_token, - per_out_ch, mkn)) + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, + mkn)) print(f"Testing: {sub_label}") @@ -62,6 +64,23 @@ def bench_run(results: List[benchmark.Measurement], model: str, device="cuda", dtype=torch.float32) + ab_strides1 = torch.full((num_experts, ), + k, + device="cuda", + dtype=torch.int64) + c_strides1 = torch.full((num_experts, ), + 2 * n, + device="cuda", + dtype=torch.int64) + ab_strides2 = torch.full((num_experts, ), + n, + device="cuda", + dtype=torch.int64) + c_strides2 = torch.full((num_experts, ), + k, + device="cuda", + dtype=torch.int64) + for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) @@ -94,36 +113,69 @@ def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, k: int, num_experts: int, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor, num_repeats: int): for _ in range(num_repeats): cutlass_moe(a, a_scale, w1, w2, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, num_experts) - - def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor, - w1_q: torch.Tensor, w2_q: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - m: int, n: int, k: int, e: int): + topk_ids, m, n, k, num_experts, ab_strides1, + c_strides1, ab_strides2, c_strides2) + + def run_cutlass_from_graph( + a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, + k: int, e: int, ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, ab_strides2: torch.Tensor, + c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e) + topk_weights, topk_ids, m, n, k, e, ab_strides1, + c_strides1, ab_strides2, c_strides2) + + def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, a_scale: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) def replay_graph(graph, num_repeats): for _ in range(num_repeats): graph.replay() torch.cuda.synchronize() - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - run_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, num_experts) + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, m, n, k, num_experts, + ab_strides1, c_strides1, ab_strides2, + c_strides2) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, + topk_ids, w1_scale, w2_scale, a_scale) torch.cuda.synchronize() min_run_time = 5 num_warmup = 5 + num_runs = 25 globals = { # Baseline params @@ -145,11 +197,17 @@ def replay_graph(graph, num_repeats): "n": n, "k": k, "num_experts": num_experts, - # Cutlass cuda graph params - "graph": graph, + "ab_strides1": ab_strides1, + "c_strides1": c_strides1, + "ab_strides2": ab_strides2, + "c_strides2": c_strides2, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, # Gen params "topk_weights": topk_weights, "topk_ids": topk_ids, + "num_runs": num_runs, # Kernels "run_triton_moe": run_triton_moe, "run_cutlass_moe": run_cutlass_moe, @@ -163,21 +221,34 @@ def replay_graph(graph, num_repeats): results.append( benchmark.Timer( stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, 1)", # noqa: E501 + "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="triton_moe", ).blocked_autorange(min_run_time=min_run_time)) + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + # Warmup run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, num_experts, num_warmup) + topk_ids, m, n, k, num_experts, ab_strides1, c_strides1, + ab_strides2, c_strides2, num_warmup) results.append( benchmark.Timer( stmt= - "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, 1)", # noqa: E501 + "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -185,11 +256,11 @@ def replay_graph(graph, num_repeats): ).blocked_autorange(min_run_time=min_run_time)) # Warmup - replay_graph(graph, num_warmup) + replay_graph(cutlass_graph, num_warmup) results.append( benchmark.Timer( - stmt="replay_graph(graph, 1)", + stmt="replay_graph(cutlass_graph, num_runs)", globals=globals, label=label, sub_label=sub_label, @@ -207,8 +278,9 @@ def main(args): for model in args.models: for layer in WEIGHT_SHAPES_MOE[model]: num_experts = layer[0] - size_k = layer[1] - size_n = layer[2] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] if len(args.limit_k) > 0 and size_k not in args.limit_k: continue @@ -218,11 +290,10 @@ def main(args): for per_act_token in PER_ACT_TOKEN_OPTS: for per_out_ch in PER_OUT_CH_OPTS: - for topk in TOPKS: - for size_m in DEFAULT_BATCH_SIZES: - mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + for size_m in DEFAULT_BATCH_SIZES: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) compare.print() diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index a95147788b3d..fc8c159342bd 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -78,16 +78,19 @@ WEIGHT_SHAPES_MOE = { "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ - [8, 4096, 28672], - [8, 14336, 4096], + [8, 2, 4096, 57344], + [8, 2, 28672, 4096], ], "nm-testing/deepseekv2-lite": [ - [64, 2048, 1408], + [64, 6, 2048, 2816], + [64, 6, 1408, 2048], ], "ibm-granite/granite-3.0-1b-a400m": [ - [32, 1024, 1024], + [32, 8, 1024, 2048], + [32, 8, 1024, 1024], ], "ibm-granite/granite-3.0-3b-a800m": [ - [40, 1024, 1536], + [40, 8, 1024, 3072], + [40, 8, 1536, 1024], ], } From 716d8c0489c11310a7b64e36f50c94e1d3ab5b81 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 27 Feb 2025 14:35:50 +0000 Subject: [PATCH 30/58] Faster data preparation kernels, bring back correct benchmark shapes Signed-off-by: ElizaWszola --- benchmarks/kernels/benchmark_shapes.py | 11 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 175 +++++++++++++++++- tests/kernels/test_cutlass_moe.py | 16 +- 3 files changed, 181 insertions(+), 21 deletions(-) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index fc8c159342bd..70190ba24d9d 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -78,19 +78,16 @@ WEIGHT_SHAPES_MOE = { "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ - [8, 2, 4096, 57344], - [8, 2, 28672, 4096], + [8, 2, 4096, 28672], + [8, 2, 14336, 4096], ], "nm-testing/deepseekv2-lite": [ - [64, 6, 2048, 2816], - [64, 6, 1408, 2048], + [64, 6, 2048, 1408], ], "ibm-granite/granite-3.0-1b-a400m": [ - [32, 8, 1024, 2048], [32, 8, 1024, 1024], ], "ibm-granite/granite-3.0-3b-a800m": [ - [40, 8, 1024, 3072], - [40, 8, 1536, 1024], + [40, 8, 1024, 1536], ], } diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 187b18d9c55b..e9b811cc0ed5 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -50,6 +50,23 @@ struct sm90_fp8_config_M16 { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm90_fp8_config_K8192 { + // K in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + template typename Epilogue> struct sm90_fp8_config_N8192 { @@ -87,6 +104,9 @@ void cutlass_grouped_mm_sm90( using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< + ElementAB_Type, ElementC_Type, + vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; @@ -96,12 +116,16 @@ void cutlass_grouped_mm_sm90( uint32_t const m = a_tensors.size(0); uint32_t const n = out_tensors.size(1); - // uint32_t const k = a_tensors.size(1); + uint32_t const k = a_tensors.size(1); if (n >= 8192) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); + } else if (k >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); } else if (m <= 16) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, @@ -113,6 +137,7 @@ void cutlass_grouped_mm_sm90( } } +// basic correctness, currently unused, run with <<<1, num_experts>>> __global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, int32_t* expert_offsets, int32_t* problem_sizes1, @@ -159,15 +184,153 @@ __global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, } } +constexpr int THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets( + const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, + int32_t* atomic_buffer, const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } +} + +__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + int32_t* arg_sort, int32_t* arg_sort_prim, + int32_t* atomic_buffer, const int topk_length, + const int topk) { + int expert_id = blockIdx.x; + + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + arg_sort[start] = i / topk; + arg_sort_prim[i] = start; + } + } +} + +constexpr int THREADS_PER_EXPERT_2 = 32; + +// 1 warp per expert +// 4 experts per block +__global__ void compute_problem_sizes_2(const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, const int n, + const int k) { + int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; + int start = threadIdx.x % THREADS_PER_EXPERT_2; + + int occurrences = 0; + for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_2) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + // we only need this if #threads/expert > warp_size + if constexpr (THREADS_PER_EXPERT_2 > 32) { + __syncthreads(); + } + + if (start == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_arg_sorts_2(const int* __restrict__ topk_ids, + int32_t* arg_sort, int32_t* arg_sort_prim, + int32_t* atomic_buffer, + const int topk_length, const int topk) { + int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; + int start = threadIdx.x % THREADS_PER_EXPERT_2; + + for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_2) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + arg_sort[start] = i / topk; + arg_sort_prim[i] = start; + } + } +} + void get_grouped_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, const int64_t num_experts, const int64_t n, const int64_t k) { auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); - get_grouped_mm_data<<<1, num_experts, 0, stream>>>( - (const int32_t*)topk_ids.data_ptr(), (int32_t*)expert_offsets.data_ptr(), - (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), - (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), - topk_ids.numel(), n, k, topk_ids.size(1)); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // TODO this is an alternative way to block kernels + constexpr bool multi_expert_blocks = false; + if constexpr (multi_expert_blocks) { + int num_blocks = (num_experts + 3) / 4; + int num_threads = THREADS_PER_EXPERT_2 * 4; + compute_problem_sizes_2<<>>( + (const int32_t*)topk_ids.data_ptr(), + (int32_t*)problem_sizes1.data_ptr(), + (int32_t*)problem_sizes2.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + (const int32_t*)problem_sizes1.data_ptr(), + (int32_t*)expert_offsets.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + num_experts); + compute_arg_sorts_2<<>>( + (const int32_t*)topk_ids.data_ptr(), (int32_t*)arg_sort.data_ptr(), + (int32_t*)arg_sort_prim.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + topk_ids.numel(), topk_ids.size(1)); + return; + } + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes<<>>( + (const int32_t*)topk_ids.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), + (int32_t*)problem_sizes2.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + (const int32_t*)problem_sizes1.data_ptr(), + (int32_t*)expert_offsets.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + num_experts); + compute_arg_sorts<<>>( + (const int32_t*)topk_ids.data_ptr(), (int32_t*)arg_sort.data_ptr(), + (int32_t*)arg_sort_prim.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), + topk_ids.numel(), topk_ids.size(1)); } diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index a067326f78fd..b60f18480b69 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -10,8 +10,8 @@ fused_topk) from vllm.platforms import current_platform -NUM_EXPERTS = [8, 64] -TOP_KS = [2, 6] +NUM_EXPERTS = [32, 40, 64] +TOP_KS = [6, 8] def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, @@ -27,9 +27,9 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, c_strides1, ab_strides2, c_strides2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 163840]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) +@pytest.mark.parametrize("n", [1024, 2048, 3072]) +@pytest.mark.parametrize("k", [1024, 1536, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -115,9 +115,9 @@ def test_cutlass_moe_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 163840]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) +@pytest.mark.parametrize("n", [1024, 2048, 3072]) +@pytest.mark.parametrize("k", [1024, 1536, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) From 975ab5f3dc038f74b51b494f87398562f4934573 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 28 Feb 2025 14:19:25 +0000 Subject: [PATCH 31/58] enable cutlass grouped gemm only on sm90 Signed-off-by: ElizaWszola --- .../cutlass_w8a8/scaled_mm_entry.cu | 18 ++++++ .../compressed_tensors_moe.py | 58 ++++++++++++------- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 446641967630..40971079e9d2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -173,9 +173,18 @@ void cutlass_grouped_mm( torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_grouped_mm for a compute capability less than " + "CUDA device capability: ", + version_num); } void get_grouped_mm_data(const torch::Tensor& topk_ids, @@ -185,9 +194,18 @@ void get_grouped_mm_data(const torch::Tensor& topk_ids, torch::Tensor& arg_sort_prim, const int64_t num_experts, const int64_t n, const int64_t k) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X get_grouped_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, arg_sort, arg_sort_prim, num_experts, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_grouped_mm for a compute capability less than " + "CUDA device capability: ", + version_num); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9f8e3fb7c235..721e442ae217 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -247,7 +247,6 @@ def apply( scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import cutlass_moe topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -261,26 +260,43 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - # TODO - x_q, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=False) - return cutlass_moe( - x_q, - x_scale, - layer.w13_weight.transpose(1, 2), - layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - x.shape[0], - layer.w2_weight.shape[2], - x.shape[1], - layer.w13_weight.shape[0], - layer.ab_strides1, - layer.c_strides1, - layer.ab_strides2, - layer.c_strides2, - ).bfloat16() + #TODO should the codepath be decided here? + dev_capability = current_platform.get_device_capability().to_int() + if dev_capability == 90: + from vllm.model_executor.layers.fused_moe import cutlass_moe + x_q, x_scale = ops.scaled_fp8_quant(x, + use_per_token_if_dynamic=False) + return cutlass_moe( + x_q, + x_scale, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + x.shape[0], + layer.w2_weight.shape[2], + x.shape[1], + layer.w13_weight.shape[0], + layer.ab_strides1, + layer.c_strides1, + layer.ab_strides2, + layer.c_strides2, + ).to(x.dtype) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): From 89f2d1cb7782713105aa55bc3a1a9007c7477918 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Mar 2025 08:16:19 +0000 Subject: [PATCH 32/58] Move arch detection to CompressedTensorsMoEMethod, cleanup, bring back split kernels Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 2 +- .../compressed_tensors/compressed_tensors.py | 5 + .../compressed_tensors_moe.py | 303 ++++++++++++++---- .../kernels/scaled_mm/grouped_cutlass.py | 81 ----- 4 files changed, 244 insertions(+), 147 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index e9b811cc0ed5..bb3b240b8eea 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -300,7 +300,7 @@ void get_grouped_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); // TODO this is an alternative way to block kernels - constexpr bool multi_expert_blocks = false; + constexpr bool multi_expert_blocks = true; if constexpr (multi_expert_blocks) { int num_blocks = (num_experts + 3) / 4; int num_threads = THREADS_PER_EXPERT_2 * 4; diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ce6c706fe3d2..4758fdb15016 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -262,6 +262,11 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False) + and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: # Confirm weights quantized. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index dfe98839896b..ec6a0281a39a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -31,6 +31,7 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsWNA16MoEMethod" ] @@ -49,6 +50,8 @@ def get_moe_method( if quant_config._is_wNa16_group_channel(weight_quant, input_quant): return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) else: @@ -139,34 +142,220 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.w13_input_scale = None layer.w2_input_scale = None - device = w13_weight.device - # TODO strides can be shared across multiple layers - ab_strides1 = torch.nn.Parameter(torch.full((num_experts, ), - hidden_size, - device=device, - dtype=torch.int64), - requires_grad=False) - c_strides1 = torch.nn.Parameter(torch.full( - (num_experts, ), + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # If rocm, normalize the weights and scales to e4m3fnuz + if current_platform.is_rocm(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: + "CompressedTensorsConfigGroupedCutlass" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, 2 * intermediate_size_per_partition, - device=device, - dtype=torch.int64), + hidden_size, + dtype=params_dtype), requires_grad=False) - ab_strides2 = torch.nn.Parameter(torch.full( - (num_experts, ), + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, intermediate_size_per_partition, - device=device, - dtype=torch.int64), - requires_grad=False) - c_strides2 = torch.nn.Parameter(torch.full((num_experts, ), - hidden_size, - device=device, - dtype=torch.int64), - requires_grad=False) - layer.register_parameter("ab_strides1", ab_strides1) - layer.register_parameter("c_strides1", c_strides1) - layer.register_parameter("ab_strides2", ab_strides2) - layer.register_parameter("c_strides2", c_strides2) + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + device = w13_weight.device + # TODO strides can be shared across multiple layers + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. @@ -251,6 +440,10 @@ def apply( activation: str = "silu", ) -> torch.Tensor: + assert global_num_experts == layer.w13_weight.shape[0] + assert expert_map is None + assert activation == "silu" + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -263,46 +456,26 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - #TODO should the codepath be decided here? - dev_capability = current_platform.get_device_capability().to_int() - if dev_capability == 90: - from vllm.model_executor.layers.fused_moe import cutlass_moe - x_q, x_scale = ops.scaled_fp8_quant(x, - use_per_token_if_dynamic=False) - return cutlass_moe( - x_q, - x_scale, - layer.w13_weight.transpose(1, 2), - layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - x.shape[0], - layer.w2_weight.shape[2], - x.shape[1], - layer.w13_weight.shape[0], - layer.ab_strides1, - layer.c_strides1, - layer.ab_strides2, - layer.c_strides2, - ).to(x.dtype) - else: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + from vllm.model_executor.layers.fused_moe import cutlass_moe + x_q, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=False) + return cutlass_moe( + x_q, + x_scale, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + x.shape[0], + layer.w2_weight.shape[2], + x.shape[1], + layer.w13_weight.shape[0], + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + ).to(x.dtype) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py deleted file mode 100644 index dafff95377ab..000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/grouped_cutlass.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Tuple - -import torch - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.platforms import current_platform - -from .GroupedMMLinearKernel import (GroupedMMLinearKernel, - GroupedMMLinearLayerConfig) - - -class CutlassGroupMMLinearKernel(GroupedMMLinearKernel): - - @classmethod - def get_min_capability(cls) -> int: - return 75 - - @classmethod - def can_implement( - cls, c: GroupedMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - - if (not current_platform.is_cuda() and not current_platform.is_cpu()): - return False, "CutlassScaledMM requires running on CUDA or CPU." - - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - # is_fused_module = len(layer.logical_widths) > 1 - # weight_scale = getattr(layer, self.w_s_name) - # if is_fused_module and not self.config.is_per_out_ch: - # weight_scale = convert_to_channelwise(weight_scale, - # layer.logical_widths) - # if is_fused_module and not self.config.is_per_act_token: - # input_scale = convert_to_channelwise(weight_scale, - # layer.logical_widths) - # replace_parameter( - # layer, self.w_s_name, - # torch.nn.Parameter(weight_scale.data, requires_grad=False)) - - # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) - - replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) - - else: - setattr(layer, self.i_s_name, None) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - w_q, w_s, i_s = self._get_weight_params(layer) - - # ops.scaled_int8_quant supports both dynamic and static quant: - # * dynamic, i_s is None and x_s computed from x. - # * static, i_s is scalar and x_s is i_s. - x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, None, symmetric=True) - - return ops.cutlass_grouped_mm(x_q, - w_q, - scale_a=x_s, - scale_b=w_s, - out_dtype=x.dtype, - bias=bias) From 8fddd4fecb5cd41e74796d979ea8c99b743f083d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Mar 2025 13:51:15 +0000 Subject: [PATCH 33/58] Fix merge, cleanup imports Signed-off-by: ElizaWszola --- .../epilogue/scaled_mm_epilogues_c3x.hpp | 6 ++--- .../cutlass_w8a8/grouped_mm_c3x.cuh | 23 +------------------ tests/kernels/test_cutlass_moe.py | 4 ++-- 3 files changed, 5 insertions(+), 28 deletions(-) diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index ff2baef1b994..598abf07022d 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -73,14 +73,12 @@ struct ScaledEpilogueBase { template using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; template using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 84723bc2aba8..27f4fe7dfded 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -2,25 +2,11 @@ #include "cutlass/cutlass.h" -// TODO clean up the includes we no longer need - -#include "cute/tensor.hpp" -#include "cutlass/tensor_ref.h" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" - #include "cutlass_extensions/common.hpp" using namespace cute; @@ -82,18 +68,11 @@ template struct cutlass_3x_group_gemm { using ElementAB = ElementAB_; - // TODO check if this works using ElementC = void; using ElementD = ElementC_; - // using ElementC = ElementC_; using ElementAccumulator = float; - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementC, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; + using Epilogue = Epilogue_; using StrideC = cute::remove_pointer_t, cute::Int<0>>>; diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index b60f18480b69..844aa1431eb1 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -99,7 +99,7 @@ def test_cutlass_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk, None) cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, e, ab_strides1, c_strides1, ab_strides2, @@ -182,7 +182,7 @@ def test_cutlass_moe_cuda_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - torch_output = torch_moe(a_d, w1_d, w2_d, score, topk) + torch_output = torch_moe(a_d, w1_d, w2_d, score, topk, None) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() From 583f749cbf755b835580a77e27daf3a33d527cef Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Mar 2025 14:02:21 +0000 Subject: [PATCH 34/58] fix benchmark precommit hooks Signed-off-by: ElizaWszola --- benchmarks/kernels/benchmark_grouped_gemm_cutlass.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index af0a25e41dab..85767fee465a 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Tuple - import torch import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES_MOE @@ -29,9 +27,9 @@ def to_fp8(tensor: torch.Tensor): min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) -def bench_run(results: List[benchmark.Measurement], model: str, +def bench_run(results: list[benchmark.Measurement], model: str, num_experts: int, topk: int, per_act_token: bool, - per_out_ch: bool, mkn: Tuple[int, int, int]): + per_out_ch: bool, mkn: tuple[int, int, int]): label = "Quant Matmul" sub_label = ( @@ -273,7 +271,7 @@ def main(args): for i, model in enumerate(args.models): print(f"[{i}] {model}") - results: List[benchmark.Measurement] = [] + results: list[benchmark.Measurement] = [] for model in args.models: for layer in WEIGHT_SHAPES_MOE[model]: From 10f5a975557df0a1ade14280ea6a66b8208ad48c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Mar 2025 15:07:28 +0000 Subject: [PATCH 35/58] Various cleanups Signed-off-by: ElizaWszola --- csrc/cpu/torch_bindings.cpp | 1 - .../epilogue/broadcast_load_epilogue_c3x.hpp | 3 - .../epilogue/scaled_mm_epilogues_c3x.hpp | 6 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 33 +++++----- .../cutlass_w8a8/grouped_mm_c3x.cuh | 36 ----------- csrc/torch_bindings.cpp | 2 +- .../compressed_tensors_moe.py | 2 +- .../scaled_mm/GroupedMMLinearKernel.py | 60 ------------------- 8 files changed, 20 insertions(+), 123 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 6e8a32549864..5d1c5f4c83d3 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,7 +118,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); - // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. ops.def( diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index ad33eec9ef8f..de03889506ec 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -424,9 +424,6 @@ struct Sm90ColOrScalarBroadcast { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - // if (threadIdx.x ==128){ - // printf("M: %d, N: %d, K: %d, L: %d, coord m: %d, n: %d, k: %d, l: %d\n", M, N, K, L, m, n, k, l); - // } Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 598abf07022d..bb2bf449125d 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -401,8 +401,10 @@ struct ScaledEpilogueBiasAzpToken }; /* -TODO document -This is an epilogue with ptr arrays to a_scales and b_scales + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. */ template struct ScaledEpilogueArray diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index bb3b240b8eea..c64188d38ee6 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -138,12 +138,10 @@ void cutlass_grouped_mm_sm90( } // basic correctness, currently unused, run with <<<1, num_experts>>> -__global__ void get_grouped_mm_data(const int* __restrict__ topk_ids, - int32_t* expert_offsets, - int32_t* problem_sizes1, - int32_t* problem_sizes2, int32_t* arg_sort, - int32_t* arg_sort_prim, int topk_length, - int n, int k, int topk) { +__global__ void get_grouped_mm_data_kernel( + const int* __restrict__ topk_ids, int32_t* expert_offsets, + int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* arg_sort, + int32_t* arg_sort_prim, int topk_length, int n, int k, int topk) { int expert_id = threadIdx.x; int num_experts = blockDim.x; @@ -243,12 +241,10 @@ constexpr int THREADS_PER_EXPERT_2 = 32; // 1 warp per expert // 4 experts per block -__global__ void compute_problem_sizes_2(const int* __restrict__ topk_ids, - int32_t* problem_sizes1, - int32_t* problem_sizes2, - int32_t* atomic_buffer, - const int topk_length, const int n, - const int k) { +__global__ void compute_problem_sizes_multi_expert( + const int* __restrict__ topk_ids, int32_t* problem_sizes1, + int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length, + const int n, const int k) { int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; int start = threadIdx.x % THREADS_PER_EXPERT_2; @@ -273,10 +269,9 @@ __global__ void compute_problem_sizes_2(const int* __restrict__ topk_ids, } } -__global__ void compute_arg_sorts_2(const int* __restrict__ topk_ids, - int32_t* arg_sort, int32_t* arg_sort_prim, - int32_t* atomic_buffer, - const int topk_length, const int topk) { +__global__ void compute_arg_sorts_multi_expert( + const int* __restrict__ topk_ids, int32_t* arg_sort, int32_t* arg_sort_prim, + int32_t* atomic_buffer, const int topk_length, const int topk) { int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; int start = threadIdx.x % THREADS_PER_EXPERT_2; @@ -300,11 +295,11 @@ void get_grouped_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); // TODO this is an alternative way to block kernels - constexpr bool multi_expert_blocks = true; + constexpr bool multi_expert_blocks = false; if constexpr (multi_expert_blocks) { int num_blocks = (num_experts + 3) / 4; int num_threads = THREADS_PER_EXPERT_2 * 4; - compute_problem_sizes_2<<>>( + compute_problem_sizes_multi_expert<<>>( (const int32_t*)topk_ids.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), (int32_t*)problem_sizes2.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), @@ -313,7 +308,7 @@ void get_grouped_mm_data_caller( (const int32_t*)problem_sizes1.data_ptr(), (int32_t*)expert_offsets.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), num_experts); - compute_arg_sorts_2<<>>( + compute_arg_sorts_multi_expert<<>>( (const int32_t*)topk_ids.data_ptr(), (int32_t*)arg_sort.data_ptr(), (int32_t*)arg_sort_prim.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), topk_ids.numel(), topk_ids.size(1)); diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 27f4fe7dfded..2d608b4a10ed 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -15,16 +15,6 @@ using namespace cute; #define ENABLE_SM90_KERNEL_LEVEL 1 #endif -// for debugging -// __global__ void print_elements(int64_t* tensor, int64_t elements) { -// if (threadIdx.x == 0) { -// for (int64_t i = 0; i < elements; ++i) { -// printf("%ld ", tensor[i]); -// } -// printf("\n---\n"); -// } -// } - __global__ void get_group_gemm_starts( int32_t* expert_offsets, int64_t* a_offsets, int64_t* b_offsets, int64_t* out_offsets, int64_t* a_scales_offsets, int64_t* b_scales_offsets, @@ -106,32 +96,6 @@ struct cutlass_3x_group_gemm { struct GemmKernel : public KernelType {}; }; -template -struct ItemDeleter { - void operator()(T* ptr) { - cudaFree(ptr); // noexcept - } -}; - -template -cutlass::platform::unique_ptr> make_device_ptr( - std::vector& data_host) { - T* data_device; - int count = data_host.size(); - cudaMalloc(&data_device, count * sizeof(T)); - cudaMemcpy(data_device, data_host.data(), count * sizeof(T), - cudaMemcpyHostToDevice); - return cutlass::platform::unique_ptr>(data_device); -} - -template -cutlass::platform::unique_ptr> allocate_device_ptr( - int count) { - T* data_device; - cudaMalloc(&data_device, count * sizeof(T)); - return cutlass::platform::unique_ptr>(data_device); -} - template void cutlass_group_gemm_caller( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 588ac009ee58..51a11504d85f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -331,7 +331,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); - // CUTLASS w8a8 grouped GEMM // TODO complete this + // CUTLASS w8a8 grouped GEMM ops.def( "cutlass_grouped_mm(Tensor! out_tensors," " Tensor a_tensors," diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ec6a0281a39a..009a2adb0da2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -224,6 +224,7 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -237,7 +238,6 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - from vllm.model_executor.layers.fused_moe import fused_experts return fused_experts(x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py deleted file mode 100644 index 3293e2482d6d..000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/GroupedMMLinearKernel.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch - - -#always symmetric for now -@dataclass -class GroupedMMLinearLayerConfig: - is_per_act_token: bool - is_per_out_ch: bool - is_static_input_scheme: bool - - -class GroupedMMLinearKernel(ABC): - - @classmethod - @abstractmethod - def get_min_capability(cls) -> int: - raise NotImplementedError - - @classmethod - @abstractmethod - def can_implement( - cls, c: GroupedMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: - raise NotImplementedError - - def __init__(self, c: GroupedMMLinearLayerConfig, w_q_param_name: str, - w_s_param_name: str, i_s_param_name: str) -> None: - assert self.can_implement(c) - self.config = c - self.w_q_name = w_q_param_name - self.w_s_name = w_s_param_name - self.i_s_name = i_s_param_name - - @abstractmethod - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - raise NotImplementedError - - @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - raise NotImplementedError - - def _get_weight_params( - self, layer: torch.nn.Module) -> Tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - Optional[torch.Tensor], # input_scale, - ]: - return ( - getattr(layer, self.w_q_name), - getattr(layer, self.w_s_name), - getattr(layer, self.i_s_name), - ) From 5e8558745d5fda84c7b4c92ca198cae564e1ef20 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Mar 2025 15:18:47 +0000 Subject: [PATCH 36/58] precommit hook fix Signed-off-by: ElizaWszola --- .../compressed_tensors/compressed_tensors_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 009a2adb0da2..35a1ab46de39 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -257,9 +257,8 @@ def apply( class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): def __init__( - self, - quant_config: - "CompressedTensorsConfigGroupedCutlass" # type: ignore # noqa E501 + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 ): self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( From 8f5ac7768861e42cd50e883690b06dd7c0bc5345 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 12 Mar 2025 16:11:35 +0000 Subject: [PATCH 37/58] Post-merge fix, fallback to triton if not yet implemented features are used Signed-off-by: ElizaWszola --- .../cutlass_w8a8/scaled_mm_entry.cu | 6 +- .../layers/fused_moe/fused_moe.py | 5 +- .../compressed_tensors_moe.py | 138 ++++++++++-------- 3 files changed, 83 insertions(+), 66 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 1a6795f2cb75..83b4da3a6743 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -189,7 +189,7 @@ void cutlass_grouped_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { int32_t version_num = get_sm_version_num(); -#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); @@ -210,7 +210,7 @@ void get_grouped_mm_data(const torch::Tensor& topk_ids, const int64_t num_experts, const int64_t n, const int64_t k) { int32_t version_num = get_sm_version_num(); -#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 get_grouped_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, arg_sort, arg_sort_prim, num_experts, n, k); @@ -218,7 +218,7 @@ void get_grouped_mm_data(const torch::Tensor& topk_ids, #endif TORCH_CHECK_NOT_IMPLEMENTED( false, - "No compiled cutlass_grouped_mm for a compute capability less than " + "No compiled get_grouped_mm_data for a compute capability less than " "CUDA device capability: ", version_num); } diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d1568678dc18..21b296a7a399 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1584,6 +1584,7 @@ def cutlass_moe( c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, + intermediate_scale: Optional[torch.Tensor] = None, ): topk = topk_ids.shape[1] per_act_token = a_scale.numel() != 1 @@ -1620,7 +1621,9 @@ def cutlass_moe( torch.ops._C.silu_and_mul(intermediate, c1) intemediate_q, intermediate_scales = ops.scaled_fp8_quant( - intermediate, use_per_token_if_dynamic=per_act_token) + intermediate, + intermediate_scale, + use_per_token_if_dynamic=per_act_token) torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, intermediate_scales, w2_scale, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ea910be0f098..44ca3f8e5b14 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -374,33 +374,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, - requires_grad=False) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, - requires_grad=False) - # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. assert layer.w13_weight_scale is not None @@ -438,42 +411,83 @@ def apply( activation: str = "silu", ) -> torch.Tensor: - assert global_num_experts == layer.w13_weight.shape[0] - assert expert_map is None - assert activation == "silu" + if (global_num_experts == layer.w13_weight.shape[0] + and expert_map is None and activation == "silu"): + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe import cutlass_moe + x_q, x_scale = ops.scaled_fp8_quant(x, + layer.w13_input_scale, + use_per_token_if_dynamic=False) + return cutlass_moe( + x_q, + x_scale, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + x.shape[0], + layer.w2_weight.shape[2], + x.shape[1], + layer.w13_weight.shape[0], + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + intermediate_scale=layer.w2_input_scale, + ).to(x.dtype) - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - from vllm.model_executor.layers.fused_moe import cutlass_moe - x_q, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=False) - return cutlass_moe( - x_q, - x_scale, - layer.w13_weight.transpose(1, 2), - layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - x.shape[0], - layer.w2_weight.shape[2], - x.shape[1], - layer.w13_weight.shape[0], - self.ab_strides1, - self.c_strides1, - self.ab_strides2, - self.c_strides2, - ).to(x.dtype) + else: + if expert_map is not None: + logger.info_once("Expert map support has not been implemented " + "in CUTLASS MoE kernel yet. Falling back to " + "Triton kernel.") + elif activation != "silu": + logger.info_once( + "CUTLASS MoE kernel currently does not " + "support %s. Falling back to Triton kernel.", activation) + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): From 3a016165ba92de2d26124ea66be511b6393c98ba Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 17 Mar 2025 16:29:04 +0000 Subject: [PATCH 38/58] Lots of minor feedback changes, self-commenting names Signed-off-by: ElizaWszola --- CMakeLists.txt | 3 +- .../epilogue/broadcast_load_epilogue_c3x.hpp | 2 - .../epilogue/scaled_mm_epilogues_c3x.hpp | 4 +- csrc/ops.h | 15 +- .../cutlass_w8a8/grouped_mm_c3x.cu | 197 +--------------- .../cutlass_w8a8/grouped_mm_c3x.cuh | 52 ++--- csrc/quantization/cutlass_w8a8/moe_data.cu | 213 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 55 ++--- csrc/torch_bindings.cpp | 38 ++-- tests/kernels/test_cutlass.py | 18 +- tests/kernels/test_cutlass_moe.py | 8 + vllm/_custom_ops.py | 53 +++++ .../layers/fused_moe/fused_moe.py | 18 +- .../layers/quantization/utils/w8a8_utils.py | 10 + 14 files changed, 392 insertions(+), 294 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/moe_data.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index fea2a10f6dca..6d049f05d4db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -345,7 +345,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu" - "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu") + "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index de03889506ec..58b1e8ff159f 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -422,8 +422,6 @@ struct Sm90ColOrScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index bb2bf449125d..62b848a0a963 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -434,8 +434,8 @@ struct ScaledEpilogueArray using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; - static ArgumentType prepare_args(const float* const* a_scales_ptr, - const float* const* b_scales_ptr, + static ArgumentType prepare_args(float const* const* a_scales_ptr, + float const* const* b_scales_ptr, bool a_col_broadcast, bool b_row_broadcast) { auto a_args = SUPER::template args_from_tensor( a_scales_ptr, a_col_broadcast); diff --git a/csrc/ops.h b/csrc/ops.h index 5bfb2cb1e194..1ea9f465cf21 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -164,6 +164,7 @@ int64_t ggml_moe_get_block_size(int64_t type); bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); +bool cutlass_group_gemm_supported(int64_t cuda_device_capability); void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, @@ -175,20 +176,18 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm( +void cutlass_moe_mm( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_grouped_mm_data(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, - torch::Tensor& arg_sort_prim, - const int64_t num_experts, const int64_t n, - const int64_t k); +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index c64188d38ee6..51608da0f773 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -6,8 +6,6 @@ #include "cutlass/cutlass.h" #include "grouped_mm_c3x.cuh" -#include "grouped_mm_c3x.cuh" - using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 @@ -86,7 +84,7 @@ struct sm90_fp8_config_N8192 { } // namespace -void cutlass_grouped_mm_sm90( +void cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -136,196 +134,3 @@ void cutlass_grouped_mm_sm90( problem_sizes, a_strides, b_strides, c_strides); } } - -// basic correctness, currently unused, run with <<<1, num_experts>>> -__global__ void get_grouped_mm_data_kernel( - const int* __restrict__ topk_ids, int32_t* expert_offsets, - int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* arg_sort, - int32_t* arg_sort_prim, int topk_length, int n, int k, int topk) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - problem_sizes1[expert_id * 3] = occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - __syncthreads(); - - if (threadIdx.x == 0) { - int32_t tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += problem_sizes1[i * 3]; - expert_offsets[i + 1] = tot_offset; - } - } - - __syncthreads(); - - int start = expert_offsets[expert_id]; - int end = expert_offsets[expert_id + 1]; - for (int i = 0; i < topk_length; ++i) { - if (topk_ids[i] == expert_id) { - arg_sort[start] = i / topk; - arg_sort_prim[i] = start; - ++start; - if (start == end) { - break; - } - } - } -} - -constexpr int THREADS_PER_EXPERT = 512; - -__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, - int32_t* problem_sizes1, - int32_t* problem_sizes2, - int32_t* atomic_buffer, - const int topk_length, const int n, - const int k) { - int expert_id = blockIdx.x; - - int occurrences = 0; - for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - occurrences += (topk_ids[i] == expert_id); - } - atomicAdd(&atomic_buffer[expert_id], occurrences); - __syncthreads(); - - if (threadIdx.x == 0) { - int final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - } -} - -__global__ void compute_expert_offsets( - const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* atomic_buffer, const int num_experts) { - int32_t tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; - expert_offsets[i + 1] = tot_offset; - } -} - -__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, - int32_t* arg_sort, int32_t* arg_sort_prim, - int32_t* atomic_buffer, const int topk_length, - const int topk) { - int expert_id = blockIdx.x; - - for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - if (topk_ids[i] == expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); - arg_sort[start] = i / topk; - arg_sort_prim[i] = start; - } - } -} - -constexpr int THREADS_PER_EXPERT_2 = 32; - -// 1 warp per expert -// 4 experts per block -__global__ void compute_problem_sizes_multi_expert( - const int* __restrict__ topk_ids, int32_t* problem_sizes1, - int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length, - const int n, const int k) { - int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; - int start = threadIdx.x % THREADS_PER_EXPERT_2; - - int occurrences = 0; - for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_2) { - occurrences += (topk_ids[i] == expert_id); - } - atomicAdd(&atomic_buffer[expert_id], occurrences); - // we only need this if #threads/expert > warp_size - if constexpr (THREADS_PER_EXPERT_2 > 32) { - __syncthreads(); - } - - if (start == 0) { - int final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - } -} - -__global__ void compute_arg_sorts_multi_expert( - const int* __restrict__ topk_ids, int32_t* arg_sort, int32_t* arg_sort_prim, - int32_t* atomic_buffer, const int topk_length, const int topk) { - int expert_id = blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_2; - int start = threadIdx.x % THREADS_PER_EXPERT_2; - - for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_2) { - if (topk_ids[i] == expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); - arg_sort[start] = i / topk; - arg_sort_prim[i] = start; - } - } -} - -void get_grouped_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, - const int64_t num_experts, const int64_t n, const int64_t k) { - auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); - auto options_int32 = - torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); - torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); - - // TODO this is an alternative way to block kernels - constexpr bool multi_expert_blocks = false; - if constexpr (multi_expert_blocks) { - int num_blocks = (num_experts + 3) / 4; - int num_threads = THREADS_PER_EXPERT_2 * 4; - compute_problem_sizes_multi_expert<<>>( - (const int32_t*)topk_ids.data_ptr(), - (int32_t*)problem_sizes1.data_ptr(), - (int32_t*)problem_sizes2.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - topk_ids.numel(), n, k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - (const int32_t*)problem_sizes1.data_ptr(), - (int32_t*)expert_offsets.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - num_experts); - compute_arg_sorts_multi_expert<<>>( - (const int32_t*)topk_ids.data_ptr(), (int32_t*)arg_sort.data_ptr(), - (int32_t*)arg_sort_prim.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - topk_ids.numel(), topk_ids.size(1)); - return; - } - - int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( - (const int32_t*)topk_ids.data_ptr(), (int32_t*)problem_sizes1.data_ptr(), - (int32_t*)problem_sizes2.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - topk_ids.numel(), n, k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - (const int32_t*)problem_sizes1.data_ptr(), - (int32_t*)expert_offsets.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - num_experts); - compute_arg_sorts<<>>( - (const int32_t*)topk_ids.data_ptr(), (int32_t*)arg_sort.data_ptr(), - (int32_t*)arg_sort_prim.data_ptr(), (int32_t*)atomic_buffer.data_ptr(), - topk_ids.numel(), topk_ids.size(1)); -} diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 2d608b4a10ed..0a8070588e10 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -67,8 +67,9 @@ struct cutlass_3x_group_gemm { using StrideC = cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; @@ -76,8 +77,8 @@ struct cutlass_3x_group_gemm { typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, 4, ElementD, LayoutC*, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); @@ -86,9 +87,9 @@ struct cutlass_3x_group_gemm { using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, - 16, ElementAccumulator, TileShape, ClusterShape, Stages, - KernelSchedule>::CollectiveOp; + ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, + LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; using KernelType = enable_sm90_or_later>; @@ -126,12 +127,12 @@ void cutlass_group_gemm_caller( torch::Tensor b_scales_ptrs = torch::empty(groups, options_int); get_group_gemm_starts<<<1, groups, 0, stream>>>( - reinterpret_cast(expert_offsets.data_ptr()), - reinterpret_cast(a_ptrs.data_ptr()), - reinterpret_cast(b_ptrs.data_ptr()), - reinterpret_cast(out_ptrs.data_ptr()), - reinterpret_cast(a_scales_ptrs.data_ptr()), - reinterpret_cast(b_scales_ptrs.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), reinterpret_cast(a_tensors.data_ptr()), reinterpret_cast(b_tensors.data_ptr()), reinterpret_cast(out_tensors.data_ptr()), @@ -146,27 +147,26 @@ void cutlass_group_gemm_caller( using StrideC = typename GemmKernel::InternalStrideC; ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = - reinterpret_cast( + static_cast( problem_sizes.data_ptr()); ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; typename GemmKernel::MainloopArguments mainloop_args{ - reinterpret_cast(a_ptrs.data_ptr()), - reinterpret_cast(a_strides.data_ptr()), - reinterpret_cast(b_ptrs.data_ptr()), - reinterpret_cast(b_strides.data_ptr())}; + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args(reinterpret_cast( - a_scales_ptrs.data_ptr()), - reinterpret_cast( - b_scales_ptrs.data_ptr()), - per_act_token, per_out_ch), - nullptr, reinterpret_cast(c_strides.data_ptr()), - reinterpret_cast(out_ptrs.data_ptr()), - reinterpret_cast(c_strides.data_ptr())}; + Gemm::Epilogue::prepare_args( + static_cast(a_scales_ptrs.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + per_act_token, per_out_ch), + nullptr, static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, diff --git a/csrc/quantization/cutlass_w8a8/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe_data.cu new file mode 100644 index 000000000000..d23b6fa5b2be --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe_data.cu @@ -0,0 +1,213 @@ +#include + +#include +#include + +#include + +// basic correctness, currently unused, run with <<<1, num_experts>>> +__global__ void get_grouped_mm_data_kernel( + const int* __restrict__ topk_ids, int32_t* expert_offsets, + int32_t* problem_sizes1, int32_t* problem_sizes2, + int32_t* input_permutation, int32_t* output_permutation, int topk_length, + int n, int k, int topk) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + problem_sizes1[expert_id * 3] = occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + __syncthreads(); + + if (threadIdx.x == 0) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } + } + + __syncthreads(); + + int start = expert_offsets[expert_id]; + int end = expert_offsets[expert_id + 1]; + for (int i = 0; i < topk_length; ++i) { + if (topk_ids[i] == expert_id) { + input_permutation[start] = i / topk; + output_permutation[i] = start; + ++start; + if (start == end) { + break; + } + } + } +} + +constexpr uint64_t THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets( + const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, + int32_t* atomic_buffer, const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } +} + +__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, const int topk_length, + const int topk) { + int expert_id = blockIdx.x; + + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + input_permutation[start] = i / topk; + output_permutation[i] = start; + } + } +} + +constexpr uint64_t THREADS_PER_EXPERT_MULTI_EXPERT = 32; + +// 1 warp per expert +// 4 experts per block +__global__ void compute_problem_sizes_multi_expert( + const int* __restrict__ topk_ids, int32_t* problem_sizes1, + int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length, + const int n, const int k) { + int expert_id = blockIdx.x * 4 + + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; + int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; + + int occurrences = 0; + for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_MULTI_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + // we only need this if #threads/expert > warp_size + if constexpr (THREADS_PER_EXPERT_MULTI_EXPERT > 32) { + __syncthreads(); + } + + if (start == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_arg_sorts_multi_expert( + const int* __restrict__ topk_ids, int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, const int topk_length, const int topk) { + int expert_id = blockIdx.x * 4 + + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; + int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; + + for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_MULTI_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + input_permutation[start] = i / topk; + output_permutation[i] = start; + } + } +} + +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + // This is an alternative way to block kernels (currently unused) + constexpr bool multi_expert_blocks = false; + if constexpr (multi_expert_blocks) { + int num_blocks = (num_experts + 3) / 4; + int num_threads = THREADS_PER_EXPERT_MULTI_EXPERT * 4; + compute_problem_sizes_multi_expert<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + compute_arg_sorts_multi_expert<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), topk_ids.size(1)); + return; + } + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + compute_arg_sorts<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), topk_ids.size(1)); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 83b4da3a6743..931418fb2255 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,17 +30,17 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_grouped_mm_sm90( +void cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_grouped_mm_data_caller( +void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& arg_sort, torch::Tensor& arg_sort_prim, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, const int64_t num_experts, const int64_t n, const int64_t k); #endif @@ -116,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { return false; } +bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { + // CUTLASS groped FP8 kernels need at least CUDA 12.0 + // and at least SM90 (Hopper) + +#if defined CUDA_VERSION + if (cuda_device_capability >= 90) { + return CUDA_VERSION >= 12000; + } +#endif + + return false; +} + void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, @@ -182,7 +195,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_grouped_mm( +void cutlass_moe_mm( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -190,37 +203,27 @@ void cutlass_grouped_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 - cutlass_grouped_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); + cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides); return; #endif TORCH_CHECK_NOT_IMPLEMENTED( false, - "No compiled cutlass_grouped_mm for a compute capability less than " + "No compiled cutlass_moe_mm for a compute capability less than 90. " "CUDA device capability: ", version_num); } -void get_grouped_mm_data(const torch::Tensor& topk_ids, - torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, torch::Tensor& arg_sort, - torch::Tensor& arg_sort_prim, - const int64_t num_experts, const int64_t n, - const int64_t k) { - int32_t version_num = get_sm_version_num(); -#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 - get_grouped_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, arg_sort, arg_sort_prim, - num_experts, n, k); +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, input_permutation, + output_permutation, num_experts, n, k); return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled get_grouped_mm_data for a compute capability less than " - "CUDA device capability: ", - version_num); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0961edd66050..1704c0dce6cd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -365,22 +365,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"); ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); - // CUTLASS w8a8 grouped GEMM - ops.def( - "cutlass_grouped_mm(Tensor! out_tensors," - " Tensor a_tensors," - " Tensor b_tensors, Tensor a_scales, " - " Tensor b_scales, Tensor expert_offsets, " - " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()"); - ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm); + // Check if cutlass grouped gemm is supported for CUDA devices of the given + // capability + ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool"); + ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported); + // CUTLASS w8a8 grouped GEMM ops.def( - "get_grouped_mm_data(Tensor topk_ids, Tensor! expert_offsets, " - " Tensor! problem_sizes1, Tensor! problem_sizes2, " - " Tensor! arg_sort, Tensor! arg_sort_prim, " - " SymInt num_experts, SymInt n, SymInt k) -> ()"); - ops.impl("get_grouped_mm_data", torch::kCUDA, &get_grouped_mm_data); + "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " + " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor c_strides) -> ()"); + ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); + + // A function that computes data required to run fused MoE with w8a8 grouped + // GEMM. It takes topk_ids as an input, and computes expert_offsets + // (token start indices of each expert). In addition to this, it computes + // problem sizes for each expert's multiplication used by the two mms called + // from fused MoE operation, and arrays with permutations required to shuffle + // and de-shuffle the input/output of the fused operation. + ops.def( + "get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " + " Tensor! problem_sizes1, Tensor! problem_sizes2, " + " Tensor! input_permutation, " + " Tensor! output_permutation, SymInt num_experts, " + " SymInt n, SymInt k) -> ()"); + ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) ops.def( diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 5447318741f0..106e9f8e045c 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -515,8 +515,10 @@ def test_cutlass_support_opcheck(): @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): @@ -626,12 +628,10 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, device="cuda", dtype=torch.int64) - torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, - b_tensors_stacked, - a_scales_tensors_stacked, - b_scales_tensors_stacked, - expert_offsets[:-1], problem_sizes, - ab_strides, ab_strides, c_strides) + ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, a_scales_tensors_stacked, + b_scales_tensors_stacked, expert_offsets[:-1], + problem_sizes, ab_strides, ab_strides, c_strides) # Validate each group's result against the baseline for g in range(num_groups): @@ -640,4 +640,4 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, print(baseline) print(c) print("*") - torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 844aa1431eb1..cbd12d75ceb6 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -34,6 +34,10 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") def test_cutlass_moe_no_graph( m: int, n: int, @@ -122,6 +126,10 @@ def test_cutlass_moe_no_graph( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") def test_cutlass_moe_cuda_graph( m: int, n: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d68c097fbe84..701d9dc8209e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -575,6 +575,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: cuda_device_capability) +def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + def cutlass_sparse_compress(a: torch.Tensor) \ -> tuple[torch.Tensor, torch.Tensor]: """ @@ -665,6 +668,56 @@ def cutlass_scaled_sparse_mm( return out +def get_cutlass_moe_mm_data( + topk_ids: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, output_permutation: torch.Tensor, + num_experts: int, n: int, k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + """ + torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, output_permutation, + num_experts, n, k) + + +def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, a_strides: torch.Tensor, + b_strides: torch.Tensor, c_strides: torch.Tensor): + """ + A single grouped matrix multiplication used in CUTLASS-based fused MoE. + The function executes fp8-quantized OUT = AB matrix multiplication. + + - expert_offsets: Indices that mark at which token index each expert begins + its computation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + - a/b/c_strides: The data strides passed to grouped matrix multiplication. + """ + torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, + b_scales, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides) + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 21b296a7a399..5f8b40f8834f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1603,9 +1603,8 @@ def cutlass_moe( a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - torch.ops._C.get_grouped_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_groups, - n, k) + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, num_groups, n, k) rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) rep_a_scales = a_scale[a_map] if per_act_token else a_scale @@ -1613,9 +1612,9 @@ def cutlass_moe( c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) c2 = torch.empty((m * topk, k), device=device, dtype=torch.half) - torch.ops._C.cutlass_grouped_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, - expert_offsets[:-1], problem_sizes1, - ab_strides1, ab_strides1, c_strides1) + ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, + expert_offsets[:-1], problem_sizes1, ab_strides1, + ab_strides1, c_strides1) intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) torch.ops._C.silu_and_mul(intermediate, c1) @@ -1625,10 +1624,9 @@ def cutlass_moe( intermediate_scale, use_per_token_if_dynamic=per_act_token) - torch.ops._C.cutlass_grouped_mm(c2, intemediate_q, w2_q, - intermediate_scales, w2_scale, - expert_offsets[:-1], problem_sizes2, - ab_strides2, ab_strides2, c_strides2) + ops.cutlass_moe_mm(c2, intemediate_q, w2_q, intermediate_scales, w2_scale, + expert_offsets[:-1], problem_sizes2, ab_strides2, + ab_strides2, c_strides2) return (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).half()).sum(dim=1) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 9de8e453354c..c2bd4bce560e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -50,6 +50,16 @@ def cutlass_block_fp8_supported() -> bool: return ops.cutlass_scaled_mm_supports_block_fp8(capability) +def cutlass_group_gemm_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_group_gemm_supported(capability) + + CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() From 3159141077c2eb6053755a4a85aa9c641ac2267f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 17 Mar 2025 16:40:50 +0000 Subject: [PATCH 39/58] format Signed-off-by: ElizaWszola --- csrc/quantization/cutlass_w8a8/moe_data.cu | 55 +++++++++++----------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe_data.cu index d23b6fa5b2be..50006762298a 100644 --- a/csrc/quantization/cutlass_w8a8/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe_data.cu @@ -115,8 +115,8 @@ __global__ void compute_problem_sizes_multi_expert( const int* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length, const int n, const int k) { - int expert_id = blockIdx.x * 4 + - threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; + int expert_id = + blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; int occurrences = 0; @@ -140,12 +140,14 @@ __global__ void compute_problem_sizes_multi_expert( } } -__global__ void compute_arg_sorts_multi_expert( - const int* __restrict__ topk_ids, int32_t* input_permutation, - int32_t* output_permutation, - int32_t* atomic_buffer, const int topk_length, const int topk) { - int expert_id = blockIdx.x * 4 + - threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; +__global__ void compute_arg_sorts_multi_expert(const int* __restrict__ topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, + const int topk_length, + const int topk) { + int expert_id = + blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_MULTI_EXPERT) { @@ -176,38 +178,35 @@ void get_cutlass_moe_mm_data_caller( static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - topk_ids.numel(), n, k); + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); compute_expert_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - num_experts); + static_cast(atomic_buffer.data_ptr()), num_experts); compute_arg_sorts_multi_expert<<>>( static_cast(topk_ids.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - topk_ids.numel(), topk_ids.size(1)); + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), + topk_ids.size(1)); return; } int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - topk_ids.numel(), n, k); + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); compute_expert_offsets<<<1, 1, 0, stream>>>( - static_cast(problem_sizes1.data_ptr()), - static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - num_experts); + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), num_experts); compute_arg_sorts<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(input_permutation.data_ptr()), - static_cast(output_permutation.data_ptr()), - static_cast(atomic_buffer.data_ptr()), - topk_ids.numel(), topk_ids.size(1)); + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), + topk_ids.size(1)); } From baa503d89385fa46de261a51643434034017ddf4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 17 Mar 2025 17:23:24 +0000 Subject: [PATCH 40/58] Decide whether to use cutlass or triton in compressed tensors method init Signed-off-by: ElizaWszola --- .../compressed_tensors/compressed_tensors.py | 3 +- .../compressed_tensors_moe.py | 121 +++++++----------- 2 files changed, 45 insertions(+), 79 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4758fdb15016..7b9423b34205 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -96,7 +96,8 @@ def get_quant_method( if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod.get_moe_method(self) + return CompressedTensorsMoEMethod.get_moe_method( + self, layer.activation, layer.expert_map) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 44ca3f8e5b14..7f808fbb35f2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -40,7 +40,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): @staticmethod def get_moe_method( - quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + activation: str, + expert_map: Optional[torch.Tensor], ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -50,7 +52,8 @@ def get_moe_method( if quant_config._is_wNa16_group_channel(weight_quant, input_quant): return CompressedTensorsWNA16MoEMethod(quant_config) - elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + and activation == "silu" and expert_map is None): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -411,83 +414,45 @@ def apply( activation: str = "silu", ) -> torch.Tensor: - if (global_num_experts == layer.w13_weight.shape[0] - and expert_map is None and activation == "silu"): - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - from vllm.model_executor.layers.fused_moe import cutlass_moe - x_q, x_scale = ops.scaled_fp8_quant(x, - layer.w13_input_scale, - use_per_token_if_dynamic=False) - return cutlass_moe( - x_q, - x_scale, - layer.w13_weight.transpose(1, 2), - layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, - topk_weights, - topk_ids, - x.shape[0], - layer.w2_weight.shape[2], - x.shape[1], - layer.w13_weight.shape[0], - self.ab_strides1, - self.c_strides1, - self.ab_strides2, - self.c_strides2, - intermediate_scale=layer.w2_input_scale, - ).to(x.dtype) + assert activation == "silu" + assert global_num_experts == layer.w13_weight.shape[0] + assert expert_map is None - else: - if expert_map is not None: - logger.info_once("Expert map support has not been implemented " - "in CUTLASS MoE kernel yet. Falling back to " - "Triton kernel.") - elif activation != "silu": - logger.info_once( - "CUTLASS MoE kernel currently does not " - "support %s. Falling back to Triton kernel.", activation) - - from vllm.model_executor.layers.fused_moe import fused_experts - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - return fused_experts(x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe import cutlass_moe + x_q, x_scale = ops.scaled_fp8_quant(x, + layer.w13_input_scale, + use_per_token_if_dynamic=False) + return cutlass_moe( + x_q, + x_scale, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + layer.w13_weight_scale, + layer.w2_weight_scale, + topk_weights, + topk_ids, + x.shape[0], + layer.w2_weight.shape[2], + x.shape[1], + layer.w13_weight.shape[0], + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + intermediate_scale=layer.w2_input_scale, + ).to(x.dtype) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): From ed673cb813516db2f49be1e8bbb8f15c9b4244db Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Mar 2025 08:14:38 +0000 Subject: [PATCH 41/58] Docs, remove redundant args Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 42 ++++----- tests/kernels/test_cutlass.py | 37 ++++---- tests/kernels/test_cutlass_moe.py | 35 ++++---- .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/fused_moe.py | 90 ++++++++++++++++--- .../compressed_tensors_moe.py | 8 +- 6 files changed, 137 insertions(+), 79 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 85767fee465a..2998ab2a1a68 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, fused_experts, fused_topk) from vllm.utils import FlexibleArgumentParser @@ -17,8 +17,8 @@ ] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] -PER_ACT_TOKEN_OPTS = [False] #[False, True] -PER_OUT_CH_OPTS = [False] #[False, True] +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] def to_fp8(tensor: torch.Tensor): @@ -110,28 +110,27 @@ def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - m: int, n: int, k: int, num_experts: int, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, num_repeats: int): for _ in range(num_repeats): - cutlass_moe(a, a_scale, w1, w2, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, num_experts, ab_strides1, - c_strides1, ab_strides2, c_strides2) + cutlass_moe_fp8(a, a_scale, w1, w2, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, c_strides1, + ab_strides2, c_strides2) def run_cutlass_from_graph( a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, - k: int, e: int, ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, ab_strides2: torch.Tensor, - c_strides2: torch.Tensor): + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e, ab_strides1, - c_strides1, ab_strides2, c_strides2) + return cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, + w2_scale, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -159,9 +158,8 @@ def replay_graph(graph, num_repeats): cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): run_cutlass_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, num_experts, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + topk_weights, topk_ids, ab_strides1, c_strides1, + ab_strides2, c_strides2) torch.cuda.synchronize() triton_stream = torch.cuda.Stream() @@ -191,10 +189,6 @@ def replay_graph(graph, num_repeats): "w2_q": w2_q, "w1_scale": w1_scale, "w2_scale": w2_scale, - "m": m, - "n": n, - "k": k, - "num_experts": num_experts, "ab_strides1": ab_strides1, "c_strides1": c_strides1, "ab_strides2": ab_strides2, @@ -240,13 +234,13 @@ def replay_graph(graph, num_repeats): # Warmup run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, num_experts, ab_strides1, c_strides1, - ab_strides2, c_strides2, num_warmup) + topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, + num_warmup) results.append( benchmark.Timer( stmt= - "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 106e9f8e045c..605025614c10 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -511,7 +511,7 @@ def test_cutlass_support_opcheck(): # TODO add bias -@pytest.mark.parametrize("num_groups", [8, 64]) +@pytest.mark.parametrize("num_experts", [8, 64]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [False]) @@ -519,7 +519,7 @@ def test_cutlass_support_opcheck(): (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, +def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): # Device and dtype setup @@ -533,11 +533,11 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, b_scales_tensors = [] baseline_tensors = [] - expert_offsets = torch.zeros((num_groups + 1), + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) - problem_sizes = torch.zeros((num_groups, 3), + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) @@ -548,7 +548,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, # For variation, each group has dimensions n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) - for g in range(num_groups): + for g in range(num_experts): m_g = alignment * random.randint(1, 64) expert_offsets[g + 1] = expert_offsets[g] + m_g @@ -586,44 +586,45 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, None) baseline_tensors.append(baseline_g) - a_tensors_stacked = torch.empty((expert_offsets[num_groups], k_g), + a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn) - b_tensors_stacked = torch.empty((num_groups, n_g, k_g), + b_tensors_stacked = torch.empty((num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn) - for g in range(num_groups): + for g in range(num_experts): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] b_tensors_stacked[g] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.transpose(1, 2) if per_act_token: - a_scales_tensors_stacked = torch.empty((expert_offsets[num_groups], 1), - device=device, - dtype=torch.float32) - for g in range(num_groups): + a_scales_tensors_stacked = torch.empty( + (expert_offsets[num_experts], 1), + device=device, + dtype=torch.float32) + for g in range(num_experts): a_scales_tensors_stacked[ expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] else: a_scales_tensors_stacked = one_scale_a - b_scales_tensors_stacked = torch.empty((num_groups, n_b_scales), + b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales), device=device, dtype=torch.float32) - for g in range(num_groups): + for g in range(num_experts): b_scales_tensors_stacked[g] = b_scales_tensors[g] - out_tensors_stacked = torch.zeros((expert_offsets[num_groups], n_g), + out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), device=device, dtype=out_dtype) - ab_strides = torch.full((num_groups, ), + ab_strides = torch.full((num_experts, ), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64) - c_strides = torch.full((num_groups, ), + c_strides = torch.full((num_experts, ), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64) @@ -634,7 +635,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, problem_sizes, ab_strides, ab_strides, c_strides) # Validate each group's result against the baseline - for g in range(num_groups): + for g in range(num_experts): baseline = baseline_tensors[g] c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] print(baseline) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index cbd12d75ceb6..9cc5497955bc 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -5,7 +5,7 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, fused_experts, fused_topk) from vllm.platforms import current_platform @@ -16,15 +16,15 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, - k: int, e: int, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e, ab_strides1, - c_strides1, ab_strides2, c_strides2) + return cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, c_strides1, + ab_strides2, c_strides2) @pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) @@ -104,10 +104,10 @@ def test_cutlass_moe_no_graph( topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) torch_output = torch_moe(a_d, w1_d, w2_d, score, topk, None) - cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, - w2_scale, topk_weights, topk_ids, m, n, k, - e, ab_strides1, c_strides1, ab_strides2, - c_strides2) + cutlass_output = cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, + w2_scale, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) print(torch_output) print(cutlass_output) @@ -196,9 +196,8 @@ def test_cutlass_moe_cuda_graph( graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): cutlass_output = run(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, m, n, k, e, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() @@ -273,11 +272,11 @@ def test_cutlass_moe_profile( activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True) as prof_cutlass: with torch.profiler.record_function("cutlass_output"): - cutlass_output = cutlass_moe(a_q, a_scale, w1_q, w2_q, - w1_scale, w2_scale, topk_weights, - topk_ids, m, n, k, e, ab_strides1, - c_strides1, ab_strides2, - c_strides2) + cutlass_output = cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, + w1_scale, w2_scale, + topk_weights, topk_ids, + ab_strides1, c_strides1, + ab_strides2, c_strides2) print("profile cutlass:") print( prof_cutlass.key_averages(group_by_input_shape=True).table( diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index c29849ce707d..e096d14fc6f9 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,7 +36,7 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.fused_moe import ( - cutlass_moe, fused_experts, fused_moe, fused_topk, + cutlass_moe_fp8, fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ @@ -45,5 +45,5 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_experts", "get_config_file_name", "grouped_topk", - "cutlass_moe", + "cutlass_moe_fp8", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5f8b40f8834f..bdcc2213b4b3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1459,7 +1459,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, return out_hidden_states -#TODO make the grouped gemm kernel consistent with scaled gemm kernel def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1567,7 +1566,8 @@ def fused_moe( block_shape=block_shape) -def cutlass_moe( +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, @@ -1576,27 +1576,94 @@ def cutlass_moe( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - m: int, - n: int, - k: int, - num_groups: int, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, intermediate_scale: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a_q (torch.Tensor): The fp8-quantized input tensor to the MoE layer. + Shape: [M, K] + - a_scale (torch.Tensor): The fp32 scale to dequantize a_q. + Shape: scalar or [M] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - intermediate_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize and dequantize the intermediate result between the gemms. + Shape: scalar + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. + """ + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert a_q.dtype == torch.float8_e4m3fn + assert w1_q.dtype == torch.float8_e4m3fn + assert w2_q.dtype == torch.float8_e4m3fn + assert a_q.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert a_scale.dim() == 0 or a_scale.shape[0] == 1 or a_scale.shape[ + 0] == a_q.shape[0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1_q.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2_q.shape[2], "W2 scale shape mismatch" + assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert intermediate_scale is None or intermediate_scale.dim( + ) == 0 or intermediate_scale.shape[ + 0] == 1, "Intermediate scale shape mismatch" + assert ab_strides1.shape[0] == w1_q.shape[ + 0], "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1_q.shape[ + 0], "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2_q.shape[ + 0], "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2_q.shape[ + 0], "C Strides 2 expert number mismatch" + + num_experts = w1_q.shape[0] + m = a_q.shape[0] + k = w1_q.shape[1] + n = w2_q.shape[1] + topk = topk_ids.shape[1] per_act_token = a_scale.numel() != 1 device = a_q.device - expert_offsets = torch.empty((num_groups + 1), + expert_offsets = torch.empty((num_experts + 1), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((num_groups, 3), + problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((num_groups, 3), + problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) @@ -1604,7 +1671,8 @@ def cutlass_moe( c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_groups, n, k) + problem_sizes2, a_map, c_map, num_experts, n, + k) rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) rep_a_scales = a_scale[a_map] if per_act_token else a_scale diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7f808fbb35f2..2aa8dae01162 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -430,11 +430,11 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - from vllm.model_executor.layers.fused_moe import cutlass_moe + from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 x_q, x_scale = ops.scaled_fp8_quant(x, layer.w13_input_scale, use_per_token_if_dynamic=False) - return cutlass_moe( + return cutlass_moe_fp8( x_q, x_scale, layer.w13_weight.transpose(1, 2), @@ -443,10 +443,6 @@ def apply( layer.w2_weight_scale, topk_weights, topk_ids, - x.shape[0], - layer.w2_weight.shape[2], - x.shape[1], - layer.w13_weight.shape[0], self.ab_strides1, self.c_strides1, self.ab_strides2, From 5287681d0d297ad84c7eb8cef204a2d77efdf193 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Mar 2025 12:15:39 +0000 Subject: [PATCH 42/58] Changed CUDA version error message, added tp TODO to benchmark Signed-off-by: ElizaWszola --- benchmarks/kernels/benchmark_grouped_gemm_cutlass.py | 1 + csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 2998ab2a1a68..ce1280db5a3a 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -291,6 +291,7 @@ def main(args): compare.print() +# TODO add --tp-sizes argument if __name__ == "__main__": parser = FlexibleArgumentParser( description="Benchmark Marlin across specified models/shapes/batches") diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 931418fb2255..c71f7909e9e9 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -210,8 +210,7 @@ void cutlass_moe_mm( #endif TORCH_CHECK_NOT_IMPLEMENTED( false, - "No compiled cutlass_moe_mm for a compute capability less than 90. " - "CUDA device capability: ", + "cutlass_moe_mm requires capability 90. Current CUDA device capability: ", version_num); } From 42dc92ca3beb64a8a10e6349969e3386b569ae05 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Mar 2025 13:25:01 +0000 Subject: [PATCH 43/58] Add tp argument to benchmarks Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index ce1280db5a3a..dbbd40ef5725 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -16,6 +16,7 @@ "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" ] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] PER_ACT_TOKEN_OPTS = [False] PER_OUT_CH_OPTS = [False] @@ -268,30 +269,30 @@ def main(args): results: list[benchmark.Measurement] = [] for model in args.models: - for layer in WEIGHT_SHAPES_MOE[model]: - num_experts = layer[0] - topk = layer[1] - size_k = layer[2] - size_n = layer[3] - - if len(args.limit_k) > 0 and size_k not in args.limit_k: - continue - - if len(args.limit_n) > 0 and size_n not in args.limit_n: - continue - - for per_act_token in PER_ACT_TOKEN_OPTS: - for per_out_ch in PER_OUT_CH_OPTS: - for size_m in DEFAULT_BATCH_SIZES: - mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in DEFAULT_BATCH_SIZES: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) compare.print() -# TODO add --tp-sizes argument if __name__ == "__main__": parser = FlexibleArgumentParser( description="Benchmark Marlin across specified models/shapes/batches") @@ -302,6 +303,10 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES_MOE.keys(), ) + parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) parser.add_argument("--batch-sizes", nargs="+", type=int, From 83f708423914fa2f90e2c336b19f84428e7620b2 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Mar 2025 17:47:46 +0000 Subject: [PATCH 44/58] Add bfloat16 type to the kernel Signed-off-by: ElizaWszola --- .../cutlass_w8a8/grouped_mm_c3x.cu | 50 +++++++++++++++---- .../cutlass_w8a8/grouped_mm_c3x.cuh | 13 ++--- .../layers/fused_moe/fused_moe.py | 11 ++-- .../compressed_tensors_moe.py | 3 +- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu index 51608da0f773..517e92c90c4d 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu @@ -82,9 +82,8 @@ struct sm90_fp8_config_N8192 { KernelSchedule, EpilogueSchedule>; }; -} // namespace - -void cutlass_moe_mm_sm90( +template +void run_cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -99,18 +98,17 @@ void cutlass_moe_mm_sm90( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< - ElementAB_Type, ElementC_Type, - vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< - ElementAB_Type, ElementC_Type, - vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< - ElementAB_Type, ElementC_Type, - vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmDefault = typename sm90_fp8_config_default< - ElementAB_Type, ElementC_Type, - vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; uint32_t const m = a_tensors.size(0); uint32_t const n = out_tensors.size(1); @@ -134,3 +132,33 @@ void cutlass_moe_mm_sm90( problem_sizes, a_strides, b_strides, c_strides); } } + +void dispatch_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm90( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + run_cutlass_moe_mm_sm90( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } +} + +} // namespace + +void cutlass_moe_mm_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides); +} diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 0a8070588e10..3b105ad4b8d6 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -41,8 +41,6 @@ namespace { using ProblemShape = cutlass::gemm::GroupProblemShape>; -using ElementAB_Type = cutlass::float_e4m3_t; -using ElementC_Type = cutlass::half_t; using ElementAccumulator = float; using ArchTag = cutlass::arch::Sm90; @@ -105,7 +103,6 @@ void cutlass_group_gemm_caller( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { using ElementAB = typename Gemm::ElementAB; - using ElementC = typename Gemm::ElementC; using ElementD = typename Gemm::ElementD; int groups = (int)expert_offsets.size(0); @@ -138,8 +135,8 @@ void cutlass_group_gemm_caller( reinterpret_cast(out_tensors.data_ptr()), reinterpret_cast(a_scales.data_ptr()), reinterpret_cast(b_scales.data_ptr()), out_tensors.size(1), - a_tensors.size(1), per_act_token, per_out_ch, sizeof(ElementAB_Type), - sizeof(ElementC_Type), sizeof(ElementAccumulator)); + a_tensors.size(1), per_act_token, per_out_ch, sizeof(ElementAB), + sizeof(ElementD), sizeof(ElementAccumulator)); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -152,9 +149,9 @@ void cutlass_group_gemm_caller( ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; typename GemmKernel::MainloopArguments mainloop_args{ - static_cast(a_ptrs.data_ptr()), + static_cast(a_ptrs.data_ptr()), static_cast(a_strides.data_ptr()), - static_cast(b_ptrs.data_ptr()), + static_cast(b_ptrs.data_ptr()), static_cast(b_strides.data_ptr())}; // Currently, we are only able to do broadcast on either all or none a_scales @@ -165,7 +162,7 @@ void cutlass_group_gemm_caller( static_cast(b_scales_ptrs.data_ptr()), per_act_token, per_out_ch), nullptr, static_cast(c_strides.data_ptr()), - static_cast(out_ptrs.data_ptr()), + static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; typename GemmKernel::Arguments args{ diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1fac2f83830b..4e9b69ff9a66 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1586,6 +1586,7 @@ def cutlass_moe_fp8( ab_strides2: torch.Tensor, c_strides2: torch.Tensor, intermediate_scale: Optional[torch.Tensor] = None, + out_dtype: torch.Type = torch.half, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -1618,6 +1619,7 @@ def cutlass_moe_fp8( - intermediate_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize and dequantize the intermediate result between the gemms. Shape: scalar + - out_dtype (torch.Tensor): The output tensor type. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. @@ -1652,6 +1654,7 @@ def cutlass_moe_fp8( 0], "AB Strides 2 expert number mismatch" assert c_strides2.shape[0] == w2_q.shape[ 0], "C Strides 2 expert number mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" num_experts = w1_q.shape[0] m = a_q.shape[0] @@ -1682,14 +1685,14 @@ def cutlass_moe_fp8( rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) rep_a_scales = a_scale[a_map] if per_act_token else a_scale - c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) - c2 = torch.empty((m * topk, k), device=device, dtype=torch.half) + c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) + c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, ab_strides1, c_strides1) - intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) + intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) torch.ops._C.silu_and_mul(intermediate, c1) intemediate_q, intermediate_scales = ops.scaled_fp8_quant( @@ -1702,4 +1705,4 @@ def cutlass_moe_fp8( ab_strides2, c_strides2) return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).half()).sum(dim=1) + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 2aa8dae01162..3286ccd45019 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -448,7 +448,8 @@ def apply( self.ab_strides2, self.c_strides2, intermediate_scale=layer.w2_input_scale, - ).to(x.dtype) + out_dtype=x.dtype, + ) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): From be83180f5393603053c1121b8bf9b480b1bea023 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 19 Mar 2025 14:24:45 +0000 Subject: [PATCH 45/58] Rename groups to num_experts in kernel, make group starts kernel more readable and move it to separate file Signed-off-by: ElizaWszola --- .../cutlass_w8a8/get_group_starts.cuh | 90 +++++++++++++++++++ .../cutlass_w8a8/grouped_mm_c3x.cuh | 58 +++--------- .../cutlass_w8a8/scaled_mm_entry.cu | 6 +- 3 files changed, 106 insertions(+), 48 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/get_group_starts.cuh diff --git a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh new file mode 100644 index 000000000000..b7e4f57c2f1d --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + + +template +__global__ void get_group_gemm_starts( + int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, + ElementAB* a_base_as_int, ElementAB* b_base_as_int, + ElementC* out_base_as_int, ElementAccumulator* a_scales_base_as_int, + ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, + bool per_act_token, bool per_out_ch) { + int expert_id = threadIdx.x; + + int64_t expert_offset = expert_offsets[expert_id]; + + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = + a_scales_base_as_int + (per_act_token ? expert_offset : 0); + b_scales_offsets[expert_id] = + b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts<<< \ + 1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + reinterpret_cast(a_tensors.data_ptr()), \ + reinterpret_cast(b_tensors.data_ptr()), \ + reinterpret_cast(out_tensors.data_ptr()), \ + reinterpret_cast(a_scales.data_ptr()), \ + reinterpret_cast(b_scales.data_ptr()), \ + out_tensors.size(1), a_tensors.size(1), \ + per_act_token, per_out_ch); \ + } + +namespace { + + +void run_get_group_gemm_starts( + torch::Tensor const& expert_offsets, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor& out_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int num_experts = (int)expert_offsets.size(0); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if(false){ + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, + "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh index 3b105ad4b8d6..e0f5ae88d019 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh @@ -8,6 +8,7 @@ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/common.hpp" +#include "get_group_starts.cuh" using namespace cute; @@ -15,28 +16,6 @@ using namespace cute; #define ENABLE_SM90_KERNEL_LEVEL 1 #endif -__global__ void get_group_gemm_starts( - int32_t* expert_offsets, int64_t* a_offsets, int64_t* b_offsets, - int64_t* out_offsets, int64_t* a_scales_offsets, int64_t* b_scales_offsets, - const int64_t a_base_as_int, const int64_t b_base_as_int, - const int64_t out_base_as_int, const int64_t a_scales_base_as_int, - const int64_t b_scales_base_as_int, int64_t n, int64_t k, - bool per_act_token, bool per_out_ch, int64_t ab_size, int64_t c_size, - int64_t acc_size) { - int expert_id = threadIdx.x; - - int64_t expert_offset = expert_offsets[expert_id]; - - a_offsets[expert_id] = a_base_as_int + expert_offset * k * ab_size; - b_offsets[expert_id] = b_base_as_int + expert_id * k * n * ab_size; - out_offsets[expert_id] = out_base_as_int + expert_offset * n * c_size; - a_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0) * acc_size; - b_scales_offsets[expert_id] = - b_scales_base_as_int + - (per_out_ch ? n * expert_id : expert_id) * acc_size; -} - namespace { using ProblemShape = @@ -105,38 +84,27 @@ void cutlass_group_gemm_caller( using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; - int groups = (int)expert_offsets.size(0); + int num_experts = (int)expert_offsets.size(0); int k_size = a_tensors.size(1); int n_size = out_tensors.size(1); bool per_act_token = a_scales.numel() != 1; - bool per_out_ch = b_scales.numel() != groups; + bool per_out_ch = b_scales.numel() != num_experts; auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); - torch::Tensor a_ptrs = torch::empty(groups, options_int); - torch::Tensor b_ptrs = torch::empty(groups, options_int); - torch::Tensor out_ptrs = torch::empty(groups, options_int); - torch::Tensor a_scales_ptrs = torch::empty(groups, options_int); - torch::Tensor b_scales_ptrs = torch::empty(groups, options_int); - - get_group_gemm_starts<<<1, groups, 0, stream>>>( - static_cast(expert_offsets.data_ptr()), - static_cast(a_ptrs.data_ptr()), - static_cast(b_ptrs.data_ptr()), - static_cast(out_ptrs.data_ptr()), - static_cast(a_scales_ptrs.data_ptr()), - static_cast(b_scales_ptrs.data_ptr()), - reinterpret_cast(a_tensors.data_ptr()), - reinterpret_cast(b_tensors.data_ptr()), - reinterpret_cast(out_tensors.data_ptr()), - reinterpret_cast(a_scales.data_ptr()), - reinterpret_cast(b_scales.data_ptr()), out_tensors.size(1), - a_tensors.size(1), per_act_token, per_out_ch, sizeof(ElementAB), - sizeof(ElementD), sizeof(ElementAccumulator)); + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + + run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, + out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; @@ -146,7 +114,7 @@ void cutlass_group_gemm_caller( ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = static_cast( problem_sizes.data_ptr()); - ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr}; + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; typename GemmKernel::MainloopArguments mainloop_args{ static_cast(a_ptrs.data_ptr()), diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index c71f7909e9e9..4daf03adfb34 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -210,8 +210,9 @@ void cutlass_moe_mm( #endif TORCH_CHECK_NOT_IMPLEMENTED( false, - "cutlass_moe_mm requires capability 90. Current CUDA device capability: ", - version_num); + "No compiled cutlass_scaled_mm for a compute capability less than " + "CUDA device capability: ", + version_num, ". Required capability: 90"); } void get_cutlass_moe_mm_data( @@ -222,7 +223,6 @@ void get_cutlass_moe_mm_data( get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); - return; } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, From e6481c8d71807f40a684834a0f573b3ac58c490c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 19 Mar 2025 14:35:54 +0000 Subject: [PATCH 46/58] format Signed-off-by: ElizaWszola --- .../cutlass_w8a8/get_group_starts.cuh | 93 +++++++++---------- 1 file changed, 42 insertions(+), 51 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh index b7e4f57c2f1d..2f52670e33e6 100644 --- a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh @@ -8,14 +8,13 @@ #include "cutlass/bfloat16.h" #include "cutlass/float8.h" - template __global__ void get_group_gemm_starts( int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, - ElementAccumulator** b_scales_offsets, - ElementAB* a_base_as_int, ElementAB* b_base_as_int, - ElementC* out_base_as_int, ElementAccumulator* a_scales_base_as_int, + ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, bool per_act_token, bool per_out_ch) { int expert_id = threadIdx.x; @@ -31,60 +30,52 @@ __global__ void get_group_gemm_starts( b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); } -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_group_gemm_starts<<< \ - 1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ - static_cast(a_ptrs.data_ptr()), \ - static_cast(b_ptrs.data_ptr()), \ - static_cast(out_ptrs.data_ptr()), \ - static_cast(a_scales_ptrs.data_ptr()), \ - static_cast(b_scales_ptrs.data_ptr()), \ - reinterpret_cast(a_tensors.data_ptr()), \ - reinterpret_cast(b_tensors.data_ptr()), \ - reinterpret_cast(out_tensors.data_ptr()), \ - reinterpret_cast(a_scales.data_ptr()), \ - reinterpret_cast(b_scales.data_ptr()), \ - out_tensors.size(1), a_tensors.size(1), \ - per_act_token, per_out_ch); \ - } +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + reinterpret_cast(a_tensors.data_ptr()), \ + reinterpret_cast(b_tensors.data_ptr()), \ + reinterpret_cast(out_tensors.data_ptr()), \ + reinterpret_cast(a_scales.data_ptr()), \ + reinterpret_cast(b_scales.data_ptr()), \ + out_tensors.size(1), a_tensors.size(1), per_act_token, \ + per_out_ch); \ + } namespace { - void run_get_group_gemm_starts( - torch::Tensor const& expert_offsets, - torch::Tensor& a_ptrs, - torch::Tensor& b_ptrs, - torch::Tensor& out_ptrs, - torch::Tensor& a_scales_ptrs, - torch::Tensor& b_scales_ptrs, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor& out_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - int num_experts = (int)expert_offsets.size(0); - bool per_act_token = a_scales.numel() != 1; - bool per_out_ch = b_scales.numel() != num_experts; + int num_experts = (int)expert_offsets.size(0); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - if(false){ - } - __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) - else { - TORCH_CHECK(false, - "Invalid output type (must be float16 or bfloat16)"); - } + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } } } // namespace \ No newline at end of file From f0c2f06ab4651362f3c9b24858c2594271c4246e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 19 Mar 2025 14:45:31 +0000 Subject: [PATCH 47/58] format Signed-off-by: ElizaWszola --- csrc/quantization/cutlass_w8a8/get_group_starts.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh index 2f52670e33e6..320591d7498f 100644 --- a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh @@ -40,11 +40,11 @@ __global__ void get_group_gemm_starts( static_cast(out_ptrs.data_ptr()), \ static_cast(a_scales_ptrs.data_ptr()), \ static_cast(b_scales_ptrs.data_ptr()), \ - reinterpret_cast(a_tensors.data_ptr()), \ - reinterpret_cast(b_tensors.data_ptr()), \ - reinterpret_cast(out_tensors.data_ptr()), \ - reinterpret_cast(a_scales.data_ptr()), \ - reinterpret_cast(b_scales.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ out_tensors.size(1), a_tensors.size(1), per_act_token, \ per_out_ch); \ } From 8d0e70089d9160823496bb47a0f114d6619a5091 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 19 Mar 2025 14:51:05 +0000 Subject: [PATCH 48/58] format 3 Signed-off-by: ElizaWszola --- .../cutlass_w8a8/get_group_starts.cuh | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh index 320591d7498f..c00d08ec0e8b 100644 --- a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/get_group_starts.cuh @@ -30,23 +30,22 @@ __global__ void get_group_gemm_starts( b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); } -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_group_gemm_starts \ - <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ - static_cast(a_ptrs.data_ptr()), \ - static_cast(b_ptrs.data_ptr()), \ - static_cast(out_ptrs.data_ptr()), \ - static_cast(a_scales_ptrs.data_ptr()), \ - static_cast(b_scales_ptrs.data_ptr()), \ - static_cast(a_tensors.data_ptr()), \ - static_cast(b_tensors.data_ptr()), \ - static_cast(out_tensors.data_ptr()), \ - static_cast(a_scales.data_ptr()), \ - static_cast(b_scales.data_ptr()), \ - out_tensors.size(1), a_tensors.size(1), per_act_token, \ - per_out_ch); \ +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), out_tensors.size(1), \ + a_tensors.size(1), per_act_token, per_out_ch); \ } namespace { From 84dbc2aa230a094b1fec45f1014311d319ec879c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 21 Mar 2025 14:49:16 +0000 Subject: [PATCH 49/58] Add hack for accepting int input in weak_ref_tensors Signed-off-by: ElizaWszola --- vllm/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 79787303af5b..738a0dd3a213 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1560,18 +1560,21 @@ def contains(self, key: object, *, strict: bool = False) -> bool: return any(cls in self.data for cls in key.mro()) -def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: +def weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. The new tensor will share the same data as the original tensor, but will not keep the original tensor alive. """ - return torch.ops._C.weak_ref_tensor(tensor) + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor def weak_ref_tensors( tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] -) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]: +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. @@ -1584,7 +1587,6 @@ def weak_ref_tensors( return tuple(weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") - def is_in_doc_build() -> bool: try: from sphinx.ext.autodoc.mock import _MockModule From 5ad4b0be60b063f151929274cd82543a0d8c2a83 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 24 Mar 2025 17:05:40 +0000 Subject: [PATCH 50/58] Fixes Signed-off-by: ElizaWszola --- csrc/torch_bindings.cpp | 8 +++++--- vllm/model_executor/layers/fused_moe/fused_moe.py | 12 ++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b89e3f56c99e..60ad6430336a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -375,7 +375,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()"); + " Tensor b_strides, Tensor c_strides) -> ()", + {stride_tag}); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); // A function that computes data required to run fused MoE with w8a8 grouped @@ -388,8 +389,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, " " Tensor! problem_sizes1, Tensor! problem_sizes2, " " Tensor! input_permutation, " - " Tensor! output_permutation, SymInt num_experts, " - " SymInt n, SymInt k) -> ()"); + " Tensor! output_permutation, int num_experts, " + " int n, int k) -> ()", + {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4e9b69ff9a66..f44b9eb8154d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1586,7 +1586,7 @@ def cutlass_moe_fp8( ab_strides2: torch.Tensor, c_strides2: torch.Tensor, intermediate_scale: Optional[torch.Tensor] = None, - out_dtype: torch.Type = torch.half, + out_dtype: torch.dtype = torch.half, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -1656,12 +1656,12 @@ def cutlass_moe_fp8( 0], "C Strides 2 expert number mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - num_experts = w1_q.shape[0] - m = a_q.shape[0] - k = w1_q.shape[1] - n = w2_q.shape[1] + num_experts = w1_q.size(0) + m = a_q.size(0) + k = w1_q.size(1) + n = w2_q.size(1) - topk = topk_ids.shape[1] + topk = topk_ids.size(1) per_act_token = a_scale.numel() != 1 device = a_q.device From 41eb522bff074f14cb2926937650b4bb0bd3de19 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 24 Mar 2025 17:18:45 +0000 Subject: [PATCH 51/58] format utils.py Signed-off-by: ElizaWszola --- vllm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/utils.py b/vllm/utils.py index 738a0dd3a213..4ee1d80d5d27 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1587,6 +1587,7 @@ def weak_ref_tensors( return tuple(weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") + def is_in_doc_build() -> bool: try: from sphinx.ext.autodoc.mock import _MockModule From c6076b31a0675469760e11034ff2f38d9bf0a4e3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Mar 2025 06:57:36 +0000 Subject: [PATCH 52/58] Make handling of both input scales consistent in the code Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 44 +++-- tests/kernels/test_cutlass_moe.py | 181 ++++++------------ .../layers/fused_moe/fused_moe.py | 49 +++-- .../compressed_tensors_moe.py | 10 +- 4 files changed, 114 insertions(+), 170 deletions(-) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index dbbd40ef5725..bcdbf6c7551a 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -48,7 +48,7 @@ def bench_run(results: list[benchmark.Measurement], model: str, w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 - a_q, a_scale = ops.scaled_fp8_quant(a) + _, a_scale = ops.scaled_fp8_quant(a) w1_q = torch.empty((num_experts, 2 * n, k), device="cuda", @@ -115,12 +115,21 @@ def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, num_repeats: int): for _ in range(num_repeats): - cutlass_moe_fp8(a, a_scale, w1, w2, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) + cutlass_moe_fp8(a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) def run_cutlass_from_graph( - a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, @@ -128,10 +137,18 @@ def run_cutlass_from_graph( with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -158,7 +175,7 @@ def replay_graph(graph, num_repeats): cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2) torch.cuda.synchronize() @@ -176,7 +193,6 @@ def replay_graph(graph, num_repeats): globals = { # Baseline params - "a": a, "w1": w1, "w2": w2, "score": score, @@ -184,7 +200,6 @@ def replay_graph(graph, num_repeats): "w1_q_notransp": w1_q_notransp, "w2_q_notransp": w2_q_notransp, # Cutlass params - "a_q": a_q, "a_scale": a_scale, "w1_q": w1_q, "w2_q": w2_q, @@ -198,6 +213,7 @@ def replay_graph(graph, num_repeats): "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, # Gen params + "a": a, "topk_weights": topk_weights, "topk_ids": topk_ids, "num_runs": num_runs, @@ -234,14 +250,14 @@ def replay_graph(graph, num_repeats): ).blocked_autorange(min_run_time=min_run_time)) # Warmup - run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, + run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_warmup) results.append( benchmark.Timer( stmt= - "run_cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 9cc5497955bc..a6befbfb446e 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -2,7 +2,6 @@ import pytest import torch -from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, @@ -14,7 +13,7 @@ TOP_KS = [6, 8] -def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, +def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, @@ -22,12 +21,21 @@ def run(a_q: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) - - -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512]) @pytest.mark.parametrize("n", [1024, 2048, 3072]) @pytest.mark.parametrize("k", [1024, 1536, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -58,8 +66,14 @@ def test_cutlass_moe_no_graph( w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - a_q, a_scale = ops.scaled_fp8_quant( + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 @@ -87,7 +101,6 @@ def test_cutlass_moe_no_graph( w2[expert], use_per_token_if_dynamic=per_out_ch) w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) - a_d = (a_q.float() * a_scale).half() ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) @@ -103,17 +116,26 @@ def test_cutlass_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - torch_output = torch_moe(a_d, w1_d, w2_d, score, topk, None) - cutlass_output = cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) - - print(torch_output) + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + cutlass_output = cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1) + + print(triton_output) print(cutlass_output) print("*") - torch.testing.assert_close(torch_output, + torch.testing.assert_close(triton_output, cutlass_output, atol=5e-2, rtol=1e-2) @@ -150,8 +172,14 @@ def test_cutlass_moe_cuda_graph( w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - a_q, a_scale = ops.scaled_fp8_quant( + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 @@ -167,6 +195,11 @@ def test_cutlass_moe_cuda_graph( device="cuda", dtype=torch.float32) + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( w1[expert], use_per_token_if_dynamic=per_out_ch) @@ -174,7 +207,6 @@ def test_cutlass_moe_cuda_graph( w2[expert], use_per_token_if_dynamic=per_out_ch) w1_q = w1_q.transpose(1, 2) w2_q = w2_q.transpose(1, 2) - a_d = (a_q.float() * a_scale).half() ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) @@ -190,124 +222,23 @@ def test_cutlass_moe_cuda_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - torch_output = torch_moe(a_d, w1_d, w2_d, score, topk, None) + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, + cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - print(torch_output) + print(triton_output) print(cutlass_output) print("*") - torch.testing.assert_close(torch_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) - - -@pytest.mark.skip("profiling only") -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -def test_cutlass_moe_profile( - m: int, - n: int, - k: int, - e: int, - topk: int, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - a_q, a_scale = ops.scaled_fp8_quant(a) - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) - w2_scale = torch.empty((e, 1, 1), device="cuda", dtype=torch.float32) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) - w1_q_notransp = w1_q.clone() - w2_q_notransp = w2_q.clone() - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - # ruff: noqa: SIM117 - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA], - record_shapes=True) as prof_cutlass: - with torch.profiler.record_function("cutlass_output"): - cutlass_output = cutlass_moe_fp8(a_q, a_scale, w1_q, w2_q, - w1_scale, w2_scale, - topk_weights, topk_ids, - ab_strides1, c_strides1, - ab_strides2, c_strides2) - print("profile cutlass:") - print( - prof_cutlass.key_averages(group_by_input_shape=True).table( - sort_by="cuda_time_total", row_limit=50)) - - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA], - record_shapes=True) as prof_triton: - with torch.profiler.record_function("triton_output"): - triton_output = fused_experts(a, - w1_q_notransp, - w2_q_notransp, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - print("profile triton:") - print( - prof_triton.key_averages(group_by_input_shape=True).table( - sort_by="cuda_time_total", row_limit=50)) - - # Uncomment to produce trace files - # cutlass_trace_name = f"trace_cutlass-{m}x{n}x{k}-{e}x{topk}.json" - # triton_trace_name = f"trace_triton-{m}x{n}x{k}-{e}x{topk}.json" - # prof_cutlass.export_chrome_trace(cutlass_trace_name) - # prof_triton.export_chrome_trace(triton_trace_name) - torch.testing.assert_close(triton_output, cutlass_output, - atol=5e-2, + atol=9e-2, rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f44b9eb8154d..266bdd06475c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1573,8 +1573,7 @@ def fused_moe( #TODO make the grouped gemm kernel consistent with scaled gemm kernel def cutlass_moe_fp8( - a_q: torch.Tensor, - a_scale: torch.Tensor, + a: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, @@ -1585,7 +1584,8 @@ def cutlass_moe_fp8( c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - intermediate_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, ) -> torch.Tensor: """ @@ -1595,10 +1595,8 @@ def cutlass_moe_fp8( grouped gemm. Parameters: - - a_q (torch.Tensor): The fp8-quantized input tensor to the MoE layer. + - a (torch.Tensor): The input tensor to the MoE layer. Shape: [M, K] - - a_scale (torch.Tensor): The fp32 scale to dequantize a_q. - Shape: scalar or [M] - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. Shape: [num_experts, K, 2N] (the weights are passed transposed) - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. @@ -1616,9 +1614,11 @@ def cutlass_moe_fp8( - ab_strides2 (torch.Tensor): The input and weights strides of the second grouped gemm. - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. - - intermediate_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize and dequantize the intermediate result between the gemms. - Shape: scalar + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] - out_dtype (torch.Tensor): The output tensor type. Returns: @@ -1626,14 +1626,14 @@ def cutlass_moe_fp8( """ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert a_q.dtype == torch.float8_e4m3fn assert w1_q.dtype == torch.float8_e4m3fn assert w2_q.dtype == torch.float8_e4m3fn - assert a_q.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a_scale.dim() == 0 or a_scale.shape[0] == 1 or a_scale.shape[ - 0] == a_q.shape[0], "Input scale shape mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ + 0], "Input scale shape mismatch" assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ 1] == w1_q.shape[2], "W1 scale shape mismatch" assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ @@ -1643,9 +1643,7 @@ def cutlass_moe_fp8( 0], "w1 scales expert number mismatch" assert w1_q.shape[0] == w2_scale.shape[ 0], "w2 scales expert number mismatch" - assert intermediate_scale is None or intermediate_scale.dim( - ) == 0 or intermediate_scale.shape[ - 0] == 1, "Intermediate scale shape mismatch" + assert a2_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 assert ab_strides1.shape[0] == w1_q.shape[ 0], "AB Strides 1 expert number mismatch" assert c_strides1.shape[0] == w1_q.shape[ @@ -1657,12 +1655,15 @@ def cutlass_moe_fp8( assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" num_experts = w1_q.size(0) - m = a_q.size(0) + m = a.size(0) k = w1_q.size(1) n = w2_q.size(1) topk = topk_ids.size(1) - per_act_token = a_scale.numel() != 1 + per_act_token = a1_scale.numel() != 1 + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) device = a_q.device expert_offsets = torch.empty((num_experts + 1), @@ -1683,24 +1684,22 @@ def cutlass_moe_fp8( k) rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a_scales = a_scale[a_map] if per_act_token else a_scale + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a_scales, w1_scale, + ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, ab_strides1, c_strides1) intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) torch.ops._C.silu_and_mul(intermediate, c1) - intemediate_q, intermediate_scales = ops.scaled_fp8_quant( - intermediate, - intermediate_scale, - use_per_token_if_dynamic=per_act_token) + intemediate_q, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, intermediate_scales, w2_scale, + ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 3286ccd45019..2e14845ff2d6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -431,12 +431,9 @@ def apply( e_score_correction_bias=e_score_correction_bias) from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 - x_q, x_scale = ops.scaled_fp8_quant(x, - layer.w13_input_scale, - use_per_token_if_dynamic=False) + return cutlass_moe_fp8( - x_q, - x_scale, + x, layer.w13_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2), layer.w13_weight_scale, @@ -447,7 +444,8 @@ def apply( self.c_strides1, self.ab_strides2, self.c_strides2, - intermediate_scale=layer.w2_input_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, out_dtype=x.dtype, ) From c8f15678c0b7775355cab95765f9dd9e8ed9a09c Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Mar 2025 07:29:37 +0000 Subject: [PATCH 53/58] Fix handling optional vals Signed-off-by: ElizaWszola --- vllm/model_executor/layers/fused_moe/fused_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 266bdd06475c..584674f1c8f3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1643,7 +1643,7 @@ def cutlass_moe_fp8( 0], "w1 scales expert number mismatch" assert w1_q.shape[0] == w2_scale.shape[ 0], "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 assert ab_strides1.shape[0] == w1_q.shape[ 0], "AB Strides 1 expert number mismatch" assert c_strides1.shape[0] == w1_q.shape[ @@ -1660,7 +1660,8 @@ def cutlass_moe_fp8( n = w2_q.size(1) topk = topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) a_q, a1_scale = ops.scaled_fp8_quant( a, a1_scale, use_per_token_if_dynamic=per_act_token) From 96296cbab24fed83834543726444aeccabe664f4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Mar 2025 19:46:46 +0000 Subject: [PATCH 54/58] feedback: version checks, file structure Signed-off-by: ElizaWszola --- CMakeLists.txt | 31 ++- csrc/cutlass_extensions/common.hpp | 12 +- .../{ => moe}/get_group_starts.cuh | 2 +- .../cutlass_w8a8/{ => moe}/grouped_mm_c3x.cu | 2 +- .../cutlass_w8a8/{ => moe}/grouped_mm_c3x.cuh | 6 +- .../quantization/cutlass_w8a8/moe/moe_data.cu | 90 ++++++++ csrc/quantization/cutlass_w8a8/moe_data.cu | 212 ------------------ .../cutlass_w8a8/scaled_mm_entry.cu | 8 +- tests/kernels/test_cutlass.py | 1 - tests/kernels/test_cutlass_moe.py | 14 +- .../compressed_tensors/compressed_tensors.py | 25 ++- 11 files changed, 162 insertions(+), 241 deletions(-) rename csrc/quantization/cutlass_w8a8/{ => moe}/get_group_starts.cuh (98%) rename csrc/quantization/cutlass_w8a8/{ => moe}/grouped_mm_c3x.cu (99%) rename csrc/quantization/cutlass_w8a8/{ => moe}/grouped_mm_c3x.cuh (96%) create mode 100644 csrc/quantization/cutlass_w8a8/moe/moe_data.cu delete mode 100644 csrc/quantization/cutlass_w8a8/moe_data.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 07d4a2921c3e..6e6737e979aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -344,9 +344,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu" - "csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe_data.cu") + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -463,6 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() + # + # CUTLASS MoE kernels + + # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works + # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible + # to compile MoE kernels that use its output. + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" + "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + # # Machete kernels diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index febc4eccd956..dbe0e30f5cbf 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel { Kernel::operator()(std::forward(args)...); #endif } -}; \ No newline at end of file +}; + +template +struct enable_sm90_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh similarity index 98% rename from csrc/quantization/cutlass_w8a8/get_group_starts.cuh rename to csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index c00d08ec0e8b..6c6e89790847 100644 --- a/csrc/quantization/cutlass_w8a8/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -62,7 +62,7 @@ void run_get_group_gemm_starts( TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - int num_experts = (int)expert_offsets.size(0); + int num_experts = static_cast(expert_offsets.size(0)); bool per_act_token = a_scales.numel() != 1; bool per_out_ch = b_scales.numel() != num_experts; diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu similarity index 99% rename from csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu rename to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index 517e92c90c4d..3f480a0e062c 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -8,7 +8,7 @@ using namespace cute; -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 #define ENABLE_SM90_KERNEL_LEVEL 1 #endif diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh similarity index 96% rename from csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh rename to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index e0f5ae88d019..503f5c417fee 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -12,7 +12,7 @@ using namespace cute; -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 #define ENABLE_SM90_KERNEL_LEVEL 1 #endif @@ -68,7 +68,7 @@ struct cutlass_3x_group_gemm { LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp; - using KernelType = enable_sm90_or_later>; struct GemmKernel : public KernelType {}; @@ -84,7 +84,7 @@ void cutlass_group_gemm_caller( using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; - int num_experts = (int)expert_offsets.size(0); + int num_experts = static_cast(expert_offsets.size(0)); int k_size = a_tensors.size(1); int n_size = out_tensors.size(1); diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu new file mode 100644 index 000000000000..2fb0417ce6c4 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -0,0 +1,90 @@ +#include + +#include +#include + +#include + +constexpr uint64_t THREADS_PER_EXPERT = 512; + +__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const int topk_length, const int n, + const int k) { + int expert_id = blockIdx.x; + + int occurrences = 0; + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + occurrences += (topk_ids[i] == expert_id); + } + atomicAdd(&atomic_buffer[expert_id], occurrences); + __syncthreads(); + + if (threadIdx.x == 0) { + int final_occurrences = atomic_buffer[expert_id]; + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } +} + +__global__ void compute_expert_offsets( + const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, + int32_t* atomic_buffer, const int num_experts) { + int32_t tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + } +} + +__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, const int topk_length, + const int topk) { + int expert_id = blockIdx.x; + + for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { + if (topk_ids[i] == expert_id) { + int start = atomicAdd(&atomic_buffer[expert_id], 1); + input_permutation[start] = i / topk; + output_permutation[i] = start; + } + } +} + +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); + auto options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), num_experts); + compute_arg_sorts<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), + topk_ids.size(1)); +} diff --git a/csrc/quantization/cutlass_w8a8/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe_data.cu deleted file mode 100644 index 50006762298a..000000000000 --- a/csrc/quantization/cutlass_w8a8/moe_data.cu +++ /dev/null @@ -1,212 +0,0 @@ -#include - -#include -#include - -#include - -// basic correctness, currently unused, run with <<<1, num_experts>>> -__global__ void get_grouped_mm_data_kernel( - const int* __restrict__ topk_ids, int32_t* expert_offsets, - int32_t* problem_sizes1, int32_t* problem_sizes2, - int32_t* input_permutation, int32_t* output_permutation, int topk_length, - int n, int k, int topk) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - problem_sizes1[expert_id * 3] = occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - __syncthreads(); - - if (threadIdx.x == 0) { - int32_t tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += problem_sizes1[i * 3]; - expert_offsets[i + 1] = tot_offset; - } - } - - __syncthreads(); - - int start = expert_offsets[expert_id]; - int end = expert_offsets[expert_id + 1]; - for (int i = 0; i < topk_length; ++i) { - if (topk_ids[i] == expert_id) { - input_permutation[start] = i / topk; - output_permutation[i] = start; - ++start; - if (start == end) { - break; - } - } - } -} - -constexpr uint64_t THREADS_PER_EXPERT = 512; - -__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, - int32_t* problem_sizes1, - int32_t* problem_sizes2, - int32_t* atomic_buffer, - const int topk_length, const int n, - const int k) { - int expert_id = blockIdx.x; - - int occurrences = 0; - for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - occurrences += (topk_ids[i] == expert_id); - } - atomicAdd(&atomic_buffer[expert_id], occurrences); - __syncthreads(); - - if (threadIdx.x == 0) { - int final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - } -} - -__global__ void compute_expert_offsets( - const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* atomic_buffer, const int num_experts) { - int32_t tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; - expert_offsets[i + 1] = tot_offset; - } -} - -__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, - int32_t* input_permutation, - int32_t* output_permutation, - int32_t* atomic_buffer, const int topk_length, - const int topk) { - int expert_id = blockIdx.x; - - for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - if (topk_ids[i] == expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); - input_permutation[start] = i / topk; - output_permutation[i] = start; - } - } -} - -constexpr uint64_t THREADS_PER_EXPERT_MULTI_EXPERT = 32; - -// 1 warp per expert -// 4 experts per block -__global__ void compute_problem_sizes_multi_expert( - const int* __restrict__ topk_ids, int32_t* problem_sizes1, - int32_t* problem_sizes2, int32_t* atomic_buffer, const int topk_length, - const int n, const int k) { - int expert_id = - blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; - int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; - - int occurrences = 0; - for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_MULTI_EXPERT) { - occurrences += (topk_ids[i] == expert_id); - } - atomicAdd(&atomic_buffer[expert_id], occurrences); - // we only need this if #threads/expert > warp_size - if constexpr (THREADS_PER_EXPERT_MULTI_EXPERT > 32) { - __syncthreads(); - } - - if (start == 0) { - int final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; - } -} - -__global__ void compute_arg_sorts_multi_expert(const int* __restrict__ topk_ids, - int32_t* input_permutation, - int32_t* output_permutation, - int32_t* atomic_buffer, - const int topk_length, - const int topk) { - int expert_id = - blockIdx.x * 4 + threadIdx.x / THREADS_PER_EXPERT_MULTI_EXPERT; - int start = threadIdx.x % THREADS_PER_EXPERT_MULTI_EXPERT; - - for (int i = start; i < topk_length; i += THREADS_PER_EXPERT_MULTI_EXPERT) { - if (topk_ids[i] == expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); - input_permutation[start] = i / topk; - output_permutation[i] = start; - } - } -} - -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k) { - auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); - auto options_int32 = - torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); - torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); - - // This is an alternative way to block kernels (currently unused) - constexpr bool multi_expert_blocks = false; - if constexpr (multi_expert_blocks) { - int num_blocks = (num_experts + 3) / 4; - int num_threads = THREADS_PER_EXPERT_MULTI_EXPERT * 4; - compute_problem_sizes_multi_expert<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - static_cast(problem_sizes1.data_ptr()), - static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); - compute_arg_sorts_multi_expert<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(input_permutation.data_ptr()), - static_cast(output_permutation.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), - topk_ids.size(1)); - return; - } - - int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - static_cast(problem_sizes1.data_ptr()), - static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); - compute_arg_sorts<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(input_permutation.data_ptr()), - static_cast(output_permutation.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), - topk_ids.size(1)); -} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 4daf03adfb34..6755718cca5e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -117,12 +117,12 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS groped FP8 kernels need at least CUDA 12.0 - // and at least SM90 (Hopper) + // CUTLASS groped FP8 kernels need at least CUDA 12.3 + // and SM90 (Hopper) #if defined CUDA_VERSION - if (cuda_device_capability >= 90) { - return CUDA_VERSION >= 12000; + if (cuda_device_capability == 90) { + return CUDA_VERSION >= 12030; } #endif diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 605025614c10..f11ce6f45a98 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -510,7 +510,6 @@ def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) -# TODO add bias @pytest.mark.parametrize("num_experts", [8, 64]) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index a6befbfb446e..1652c72d86fe 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -9,7 +9,7 @@ fused_topk) from vllm.platforms import current_platform -NUM_EXPERTS = [32, 40, 64] +NUM_EXPERTS = [40, 64] TOP_KS = [6, 8] @@ -35,9 +35,9 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, a1_scale=a_scale) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512]) -@pytest.mark.parametrize("n", [1024, 2048, 3072]) -@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -141,9 +141,9 @@ def test_cutlass_moe_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) -@pytest.mark.parametrize("n", [1024, 2048, 3072]) -@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 7b9423b34205..4b2d7ca2bade 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -192,17 +192,26 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, - error: bool = True) -> bool: + error: bool = True, + match_exact: bool = False) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: capability = capability_tuple.to_int() - supported = capability >= min_capability - if error and not supported: - raise RuntimeError( - "Quantization scheme is not supported for ", - f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") + if match_exact: + supported = capability == min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + "the current GPU. Required capability: ", + f"{min_capability}. Current capability: {capability}.") + else: + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") return supported else: return False @@ -265,7 +274,7 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: - return (self._check_scheme_supported(90, error=False) + return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) def _is_fp8_w8a16(self, weight_quant: BaseModel, From 3977d674ecde7a76cf6f5f2ee1477b8cab9eeb73 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Mar 2025 20:18:14 +0000 Subject: [PATCH 55/58] Change cmake flag, remove unused code Signed-off-by: ElizaWszola --- CMakeLists.txt | 2 +- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu | 4 ---- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh | 4 ---- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e6737e979aa..e0f1fdf78d14 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -475,7 +475,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index 3f480a0e062c..2b8bc3fb0b26 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -8,10 +8,6 @@ using namespace cute; -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900 - #define ENABLE_SM90_KERNEL_LEVEL 1 -#endif - namespace { template Date: Wed, 26 Mar 2025 20:56:51 +0000 Subject: [PATCH 56/58] update kernel run conditions in scaled_mm_entry.cu Signed-off-by: ElizaWszola --- .../cutlass_w8a8/scaled_mm_entry.cu | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 6755718cca5e..54b63894e4cb 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -202,7 +202,7 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { int32_t version_num = get_sm_version_num(); -#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); @@ -210,9 +210,8 @@ void cutlass_moe_mm( #endif TORCH_CHECK_NOT_IMPLEMENTED( false, - "No compiled cutlass_scaled_mm for a compute capability less than " - "CUDA device capability: ", - version_num, ". Required capability: 90"); + "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, + ". Required capability: 90"); } void get_cutlass_moe_mm_data( @@ -220,9 +219,20 @@ void get_cutlass_moe_mm_data( torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, torch::Tensor& output_permutation, const int64_t num_experts, const int64_t n, const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " + "CUDA device capability: ", + version_num, ". Required capability: 90"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, From fbe2b80c9b0e6c1d0e2fe2071e58e6d9e4d0be14 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 27 Mar 2025 01:39:29 +0000 Subject: [PATCH 57/58] added channelwise, dynamic per token Signed-off-by: rshaw@neuralmagic.com --- benchmarks/backend_request_func.py | 4 +- .../openai_completion_client.py | 7 +- vllm/model_executor/layers/fused_moe/layer.py | 26 ----- .../compressed_tensors_moe.py | 107 +++++++++++------- 4 files changed, 73 insertions(+), 71 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 0f13c79ae234..90e58f6ec94d 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -252,7 +252,9 @@ async def async_request_openai_completions( "temperature": 0.0, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, - "stream": True, + "stream": False, + "use_beam_search": True, + "n": 3, "stream_options": { "include_usage": True, }, diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 06b93d7d1931..48953eeac650 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -4,7 +4,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:8001/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") @@ -21,9 +21,8 @@ model=model, prompt="A robot may not injure a human being", echo=False, - n=2, - stream=stream, - logprobs=3) + n=3, + extra_body={'use_beam_search': True}) print("Completion results:") if stream: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 917643134645..063ead8951fa 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -822,32 +822,6 @@ def make_expert_params_mapping( ] ] - def _load_fp8_scale(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: - param_data = param.data - - # Input scales can be loaded directly and should be equal. - if "input_scale" in weight_name: - if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. {loaded_weight}") - param_data[expert_id] = loaded_weight - # Weight scales - elif "weight_scale" in weight_name: - # If we are in merged column case (gate_up_proj) - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - else: - param_data[expert_id] = loaded_weight - def extra_repr(self) -> str: s = ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 2e14845ff2d6..bf32bee89e89 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -268,14 +268,23 @@ def __init__( self.input_quant = self.quant_config.target_scheme_map["Linear"].get( "input_activations") - if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy == QuantizationStrategy.TENSOR): + per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy + == QuantizationStrategy.TENSOR) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not (per_tensor or per_channel): raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales " - "for weights and activations are supported. Found " + "For FP8 Fused MoE layers, we require per tensor " + "or channelwise, dynamic per token quantization. Found " f"{self.weight_quant}, {self.input_quant}") self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization.") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: @@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR if (layer.w13_input_scale is None or layer.w2_input_scale is None): raise ValueError( "QuantConfig has static quantization, but found " @@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + # For Per-TENSOR case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) def apply( self, From e0bae3cee2fab6ce830a796f8a0724f7093b194d Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 27 Mar 2025 01:51:14 +0000 Subject: [PATCH 58/58] updated Signed-off-by: rshaw@neuralmagic.com --- benchmarks/backend_request_func.py | 4 +--- examples/online_serving/openai_completion_client.py | 7 ++++--- .../compressed_tensors/compressed_tensors_moe.py | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 90e58f6ec94d..0f13c79ae234 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -252,9 +252,7 @@ async def async_request_openai_completions( "temperature": 0.0, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, - "stream": False, - "use_beam_search": True, - "n": 3, + "stream": True, "stream_options": { "include_usage": True, }, diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 48953eeac650..06b93d7d1931 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -4,7 +4,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8001/v1" +openai_api_base = "http://localhost:8000/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") @@ -21,8 +21,9 @@ model=model, prompt="A robot may not injure a human being", echo=False, - n=3, - extra_body={'use_beam_search': True}) + n=2, + stream=stream, + logprobs=3) print("Completion results:") if stream: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 503dc43d6dc2..bf32bee89e89 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -32,7 +32,6 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", - "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsWNA16MoEMethod" ]