-
Notifications
You must be signed in to change notification settings - Fork 413
Add CUTLASS-based W4A4 #1515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add CUTLASS-based W4A4 #1515
Changes from 9 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
1d350d6
add w4a4
gau-nernst 7e277df
add test
gau-nernst a44df9e
hook up to AQT
gau-nernst 2487eb9
Merge branch 'main' into w4a4
gau-nernst de167f0
fix quant api test
gau-nernst fe1f0eb
fix test
gau-nernst 908f464
Merge branch 'main' into w4a4
gau-nernst 883384b
make threadblockswizzle a template param
gau-nernst ee34bb2
re-use s8s4 cutlass template
gau-nernst f513523
add Alex's patch and some changes
gau-nernst 9a1ce25
fix aqt test
gau-nernst b9db0f1
remove int4_cutlass.cu
gau-nernst f42fc65
apply alex's patch
gau-nernst a43f804
Merge branch 'main' into w4a4
gau-nernst 5c30303
update benchmark script
gau-nernst d7c0896
ruff
gau-nernst 2c5f565
Merge branch 'main' into w4a4
gau-nernst fd8dc4e
add some tuning
gau-nernst 5449a56
reduce num_stages to fit shared memory of small GPUs (<100kb)
gau-nernst c421921
replace torch timer with triton do_bench
gau-nernst 81a0a13
ruff
gau-nernst 69e6777
Merge branch 'main' into w4a4
gau-nernst c736856
use ZeroPointDomain.NONE
gau-nernst bdcb85c
fix 3.7 typing
gau-nernst 0c85805
Merge branch 'main' into w4a4
gau-nernst 4a19634
merge Aleksandar changes
gau-nernst 496cec8
run ruff
gau-nernst 9a0ae7b
try replace torch/extension.h with torch/library.h
gau-nernst 9332ac4
Merge branch 'main' into w4a4
gau-nernst 37dc5f7
(alexsamardzic) improve error handling
gau-nernst c003018
ruff format
gau-nernst 3b0b32b
add note on cutlass naming
gau-nernst 4613503
Merge branch 'main' into w4a4
gau-nernst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,231 @@ | ||
| #include <torch/extension.h> | ||
| #include <ATen/cuda/CUDAContext.h> | ||
|
|
||
| // copied from s8s4_linear_cutlass.cu | ||
| #if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ | ||
| defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) | ||
| #define BUILD_INT4_MM_CUTLASS | ||
| #endif | ||
|
|
||
| #if defined(BUILD_INT4_MM_CUTLASS) | ||
| #include "cutlass/cutlass.h" | ||
| #include "cutlass/gemm/device/gemm_universal.h" | ||
| #include "cutlass/gemm/device/gemm.h" | ||
| #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" | ||
| #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" | ||
| #include "cutlass/gemm/device/gemm_universal_adapter.h" | ||
|
|
||
| #define CUTLASS_STATUS_CHECK(status) \ | ||
| { \ | ||
| TORCH_CHECK(status == cutlass::Status::kSuccess, \ | ||
| __func__, " : Got CUTLASS error: ", \ | ||
| cutlassGetStatusString(status)); \ | ||
| } | ||
| #endif | ||
|
|
||
| namespace torchao { | ||
|
|
||
| #if defined(BUILD_INT4_MM_CUTLASS) | ||
| // define common params | ||
| using ElementA = cutlass::int4b_t; | ||
| using ElementB = cutlass::int4b_t; | ||
| using ElementAccumulator = int32_t; | ||
| using OpClass = cutlass::arch::OpClassTensorOp; | ||
| using ArchTag = cutlass::arch::Sm80; | ||
|
|
||
| // how many elements to load at a time -> load 128-bit = 32 x 4-bit | ||
| constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; | ||
| constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; | ||
| #endif | ||
|
|
||
| // we will do input checks in python. A and B are stored as int8 | ||
| torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) { | ||
| #if defined(BUILD_INT4_MM_CUTLASS) | ||
| int M = A.size(0); | ||
| int K = A.size(1) * 2; | ||
| int N = B.size(1); | ||
| torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32)); | ||
|
|
||
| // some configs for int4 mma | ||
| // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu | ||
| // using default config. this can be tuned. | ||
| using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; | ||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | ||
| using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; | ||
| // static int const kStages = 3; | ||
| using ElementC = int32_t; | ||
| using Gemm = cutlass::gemm::device::Gemm< | ||
| ElementA, cutlass::layout::RowMajor, // A matrix | ||
| ElementB, cutlass::layout::ColumnMajor, // B matrix | ||
| ElementC, cutlass::layout::RowMajor, // C matrix | ||
| ElementAccumulator, OpClass, ArchTag, | ||
| ThreadblockShape, WarpShape, InstructionShape | ||
| >; | ||
| Gemm::Arguments args { | ||
| {M, N, K}, | ||
| {reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K}, | ||
| {reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K}, | ||
| {C.data_ptr<ElementC>(), N}, | ||
| {C.data_ptr<ElementC>(), N}, | ||
| {1, 0} // epilogue | ||
| }; | ||
| Gemm gemm_op; | ||
| CUTLASS_STATUS_CHECK(gemm_op(args)); | ||
| return C; | ||
| #else | ||
| TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); | ||
| return at::Tensor{}; | ||
| #endif | ||
| } | ||
|
|
||
| template< | ||
| typename ElementC, | ||
| typename ThreadblockShape, | ||
| typename WarpShape, | ||
| typename InstructionShape, | ||
| int numStages> | ||
| void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) { | ||
| // problem shape | ||
| int M = A.size(0); | ||
| int K = A.size(1) * 2; | ||
| int N = B.size(1); | ||
|
|
||
| constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // 8 for BF16/FP16 | ||
| using ElementEpilogue = float; | ||
| constexpr int numEpilogueStages = 1; | ||
|
|
||
| // build epilogue visitor tree | ||
| using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< | ||
| ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages | ||
| >; | ||
|
|
||
| using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; | ||
| constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest; | ||
| using Multiply = cutlass::epilogue::threadblock::VisitorCompute< | ||
| cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode | ||
| >; | ||
|
|
||
| // (1, N) | ||
| using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast< | ||
| OutputTileThreadMap, ElementC, | ||
| cute::Stride<cute::_0, cute::_1, int32_t> // MNL | ||
| >; | ||
| using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>; | ||
|
|
||
| // (M, 1) | ||
| using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast< | ||
| OutputTileThreadMap, ElementC, | ||
| cute::Stride<cute::_1, cute::_0, int32_t> // MNL | ||
| >; | ||
| using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>; | ||
|
|
||
| using Output = cutlass::epilogue::threadblock::VisitorAuxStore< | ||
| OutputTileThreadMap, ElementC, RoundMode, | ||
| cute::Stride<int64_t, cute::_1, int64_t> // MNL | ||
| >; | ||
| using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>; | ||
|
|
||
| using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< | ||
| ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA, | ||
| ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB, | ||
| ElementC, cutlass::layout::RowMajor, AlignmentC, | ||
| ElementAccumulator, ElementEpilogue, OpClass, ArchTag, | ||
| ThreadblockShape, WarpShape, InstructionShape, | ||
| EVTOutput, | ||
| cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, | ||
| numStages, | ||
| cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work | ||
| numEpilogueStages | ||
| >::GemmKernel; | ||
| using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>; | ||
|
|
||
| // col_scale, row_scale, and C must have the same dtype | ||
| const ElementA *A_ptr = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()); | ||
| const ElementB *B_ptr = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()); | ||
| const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr()); | ||
| const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr()); | ||
| ElementC *C_ptr = reinterpret_cast<ElementC *>(C.data_ptr()); | ||
|
|
||
| typename EVTOutput::Arguments callback_args{ | ||
| { | ||
| { | ||
| {}, // Accum | ||
| {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale | ||
| {} // Multiply | ||
| }, // EVTCompute0 | ||
| {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale | ||
| {} // Multiply | ||
| }, // EVTCompute1 | ||
| {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput | ||
| }; | ||
|
|
||
| typename DeviceGemm::Arguments args( | ||
| cutlass::gemm::GemmUniversalMode::kGemm, | ||
| cutlass::gemm::GemmCoord{M, N, K}, | ||
| 1, // batch_split | ||
| callback_args, | ||
| A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr | ||
| M * K, N * K, 0, 0, // batch_stride A, B, C, D | ||
| K, K, 0, 0 // stride A, B, C, D | ||
| ); | ||
|
|
||
| DeviceGemm gemm_op; | ||
| auto stream = at::cuda::getCurrentCUDAStream(); | ||
| CUTLASS_STATUS_CHECK(gemm_op.can_implement(args)); | ||
| CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream)); | ||
| } | ||
|
|
||
| // we will do input checks in python. A and B are stored as int8 | ||
| // this function is based on the following cutlass example | ||
| // https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu | ||
| // also with the help of emitted code from cutlass Python | ||
| torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) { | ||
| #if defined(BUILD_INT4_MM_CUTLASS) | ||
| int M = A.size(0); | ||
| int N = B.size(1); | ||
| torch::Tensor C = torch::empty({M, N}, row_scale.options()); | ||
|
|
||
| // some configs for int4 mma | ||
| // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu | ||
| // using default config. this can be tuned. | ||
| using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; | ||
| using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; | ||
| using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; | ||
| constexpr int numStages = 3; | ||
|
|
||
| AT_DISPATCH_SWITCH( | ||
drisspg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| row_scale.scalar_type(), | ||
| "scaled_int4_mm_cutlass", | ||
| AT_DISPATCH_CASE( | ||
| torch::ScalarType::Half, | ||
| [&]() { | ||
| using ElementC = cutlass::half_t; | ||
| scaled_int4_mm_cutlass_dispatch< | ||
| ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( | ||
| A, B, row_scale, col_scale, C); | ||
| } | ||
| ) | ||
| AT_DISPATCH_CASE( | ||
| torch::ScalarType::BFloat16, | ||
| [&]() { | ||
| using ElementC = cutlass::bfloat16_t; | ||
| scaled_int4_mm_cutlass_dispatch< | ||
| ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( | ||
| A, B, row_scale, col_scale, C); | ||
| } | ||
| ) | ||
| ); | ||
|
|
||
| return C; | ||
| #else | ||
| TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); | ||
| return at::Tensor{}; | ||
| #endif | ||
| } | ||
|
|
||
| TORCH_LIBRARY_IMPL(torchao, CUDA, m) { | ||
| m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass); | ||
| m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass); | ||
| } | ||
|
|
||
| } // namespace torchao | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know if the universal gemm api can be used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will look into it. I wrote this quite some time ago...