From 2e6623b532dadea1406efaaddd0de71fab26c0e9 Mon Sep 17 00:00:00 2001 From: MzeroMiko <45236610+MzeroMiko@users.noreply.github.com> Date: Sun, 4 Feb 2024 22:38:44 +0800 Subject: [PATCH 1/9] add feature nrows --- csrc/selective_scan/selective_scan.cpp | 36 +++++++++++++---- .../selective_scan_bwd_bf16_complex.cu | 2 +- .../selective_scan_bwd_bf16_real.cu | 2 +- .../selective_scan_bwd_fp16_complex.cu | 2 +- .../selective_scan_bwd_fp16_real.cu | 2 +- .../selective_scan_bwd_fp32_complex.cu | 2 +- .../selective_scan_bwd_fp32_real.cu | 2 +- .../selective_scan_bwd_kernel.cuh | 40 ++++++++++--------- csrc/selective_scan/selective_scan_fwd2.cu | 14 +++++++ csrc/selective_scan/selective_scan_fwd3.cu | 14 +++++++ csrc/selective_scan/selective_scan_fwd4.cu | 14 +++++++ .../selective_scan/selective_scan_fwd_bf16.cu | 4 +- .../selective_scan/selective_scan_fwd_fp16.cu | 4 +- .../selective_scan/selective_scan_fwd_fp32.cu | 4 +- .../selective_scan_fwd_kernel.cuh | 34 ++++++++-------- mamba_ssm/ops/selective_scan_interface.py | 16 +++++--- tests/ops/test_selective_scan.py | 5 ++- 17 files changed, 132 insertions(+), 65 deletions(-) create mode 100644 csrc/selective_scan/selective_scan_fwd2.cu create mode 100644 csrc/selective_scan/selective_scan_fwd3.cu create mode 100644 csrc/selective_scan/selective_scan_fwd4.cu diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index cde867cd3..c1bcbb6f6 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -8,6 +8,7 @@ #include #include "selective_scan.h" +#define MAX_DSTATE 256 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -50,10 +51,18 @@ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ } -template +#define INT_SWITCH(INT, NAME, ...) 
[&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); void set_ssm_params_fwd(SSMParamsBase ¶ms, @@ -229,7 +238,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus) { + bool delta_softplus, + int nrows + ) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -259,7 +270,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const int dstate = A.size(1); const int n_groups = is_variable_B ? 
B.size(1) : 1; - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); CHECK_SHAPE(u, batch_size, dim, seqlen); CHECK_SHAPE(delta, batch_size, dim, seqlen); @@ -327,7 +339,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); + INT_SWITCH(nrows, kNRows, [&] { + selective_scan_fwd_cuda(params, stream); + }); }); }); std::vector result = {out, x}; @@ -346,7 +360,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &out_, c10::optional &dz_, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + int nrows + ) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -379,7 +395,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, const int dstate = A.size(1); const int n_groups = is_variable_B ? 
B.size(1) : 1; - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); CHECK_SHAPE(u, batch_size, dim, seqlen); CHECK_SHAPE(delta, batch_size, dim, seqlen); @@ -482,7 +499,10 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { - selective_scan_bwd_cuda(params, stream); + constexpr int kNRows = 1; + // INT_SWITCH(nrows, kNRows, [&] { + selective_scan_bwd_cuda(params, stream); + // }); }); }); std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu index c55f0e858..268c904fd 100644 --- a/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +++ b/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_bwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/csrc/selective_scan/selective_scan_bwd_bf16_real.cu index 72adaf5cb..66ae72e15 100644 --- a/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +++ b/csrc/selective_scan/selective_scan_bwd_bf16_real.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void 
selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu index df126d7c8..2131f8f6b 100644 --- a/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +++ b/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_bwd_cuda<1, at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/csrc/selective_scan/selective_scan_bwd_fp16_real.cu index 3ff271b50..b5e0f7674 100644 --- a/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +++ b/csrc/selective_scan/selective_scan_bwd_fp16_real.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu index 555490234..32b79094d 100644 --- a/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +++ b/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_bwd_cuda<1, float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_real.cu b/csrc/selective_scan/selective_scan_bwd_fp32_real.cu index a7ed64223..7ef3e0ed9 100644 --- 
a/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +++ b/csrc/selective_scan/selective_scan_bwd_fp32_real.cu @@ -6,4 +6,4 @@ #include "selective_scan_bwd_kernel.cuh" -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 2ed101148..efb615189 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -31,6 +31,8 @@ struct Selective_Scan_bwd_kernel_traits { using weight_t = weight_t_; static constexpr int kNThreads = kNThreads_; static constexpr int kNItems = kNItems_; + // we are about to add kNRows here + static constexpr int MaxDState = MAX_DSTATE / 1; static constexpr int kNBytes = sizeof(input_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); @@ -89,8 +91,8 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { // Shared memory. 
extern __shared__ char smem_[]; // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); @@ -104,9 +106,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); - scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads); - weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); - weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + Ktraits::MaxDState); const int batch_id = blockIdx.x; const int dim_id = blockIdx.y; @@ -247,7 +249,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + smem_delta_a[threadIdx.x == 0 ? 
state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; } else { thread_reverse_data[i - 1].x = delta_a_exp; } @@ -258,8 +260,8 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { } __syncthreads(); thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; // Initialize running total scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); SSMScanPrefixCallbackOp prefix_op(running_prefix); @@ -335,7 +337,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; } else { thread_reverse_data[i - 1].x = delta_a_exp.real_; thread_reverse_data[i - 1].y = -delta_a_exp.imag_; @@ -349,8 +351,8 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { } __syncthreads(); complex_t delta_a_exp = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + ? (chunk == params.n_chunks - 1 ? 
1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; // Initialize running total @@ -488,7 +490,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { } } -template +template void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { @@ -498,7 +500,7 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { using Ktraits = Selective_Scan_bwd_kernel_traits; // using Ktraits = Selective_Scan_bwd_kernel_traits; // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); // printf("smem_size = %d\n", kSmemSize); dim3 grid(params.batch, params.dim); auto kernel = &selective_scan_bwd_kernel; @@ -515,17 +517,17 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { }); } -template +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { if (params.seqlen <= 128) { - selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); } else if (params.seqlen <= 256) { - selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); + selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); } else 
if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); } else { - selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); } } \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd2.cu b/csrc/selective_scan/selective_scan_fwd2.cu new file mode 100644 index 000000000..66286238a --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd2.cu @@ -0,0 +1,14 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd3.cu b/csrc/selective_scan/selective_scan_fwd3.cu new file mode 100644 index 000000000..6ca83c5a6 --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd3.cu @@ -0,0 +1,14 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd4.cu b/csrc/selective_scan/selective_scan_fwd4.cu new file mode 100644 index 000000000..442f575d4 --- /dev/null +++ b/csrc/selective_scan/selective_scan_fwd4.cu @@ -0,0 +1,14 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_bf16.cu b/csrc/selective_scan/selective_scan_fwd_bf16.cu index 2b8615b1d..fb8d1a502 100644 --- a/csrc/selective_scan/selective_scan_fwd_bf16.cu +++ b/csrc/selective_scan/selective_scan_fwd_bf16.cu @@ -6,5 +6,5 @@ #include "selective_scan_fwd_kernel.cuh" -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_fp16.cu b/csrc/selective_scan/selective_scan_fwd_fp16.cu index 015e2a0ef..0d3bc738f 100644 --- a/csrc/selective_scan/selective_scan_fwd_fp16.cu +++ b/csrc/selective_scan/selective_scan_fwd_fp16.cu @@ -6,5 +6,5 @@ #include "selective_scan_fwd_kernel.cuh" -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase 
¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_fp32.cu b/csrc/selective_scan/selective_scan_fwd_fp32.cu index c142fe020..80b1a552b 100644 --- a/csrc/selective_scan/selective_scan_fwd_fp32.cu +++ b/csrc/selective_scan/selective_scan_fwd_fp32.cu @@ -6,5 +6,5 @@ #include "selective_scan_fwd_kernel.cuh" -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 440a20910..2d18569a1 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -28,6 +28,7 @@ struct Selective_Scan_fwd_kernel_traits { static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; static constexpr int kNItems = kNItems_; static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows; static constexpr int kNBytes = sizeof(input_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); @@ -82,8 +83,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Shared memory. 
extern __shared__ char smem_[]; // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); @@ -91,12 +92,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { auto& smem_store = reinterpret_cast(smem_); auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + // weight_t *smem_bc = reinterpret_cast(smem_a + Ktraits::MaxDState); scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); const int batch_id = blockIdx.x; const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); // Mzero: fixbug here for nrow input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride @@ -236,10 +237,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { scan_t running_prefix; if constexpr (!kIsComplex) { // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); // running_prefix = chunk > 0 && threadIdx.x == 0 ? 
smem_running_prefix[state_idx] : make_float2(1.f, 0.f); } else { - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); } SSMScanPrefixCallbackOp prefix_op(running_prefix); @@ -249,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // There's a syncthreads in the scan op, so we don't need to sync here. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; + smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; // Mzero: fixbug here for nrow x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; } #pragma unroll @@ -302,18 +303,15 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } -template +template void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block - // processing 1 row. 
- constexpr int kNRows = 1; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { using Ktraits = Selective_Scan_fwd_kernel_traits; // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); // printf("smem_size = %d\n", kSmemSize); dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; @@ -329,17 +327,17 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { }); } -template +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); } } diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index b8f14dd0b..d55f6be6e 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ 
b/mamba_ssm/ops/selective_scan_interface.py @@ -15,7 +15,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, nrows=1): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -34,9 +34,10 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None + ctx.nrows = nrows last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) @@ -56,12 +57,13 @@ def backward(ctx, dout, *args): u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() + nrows = 1 # ctx.nrows # we have not implemented the nrows for bwd yet # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). # Here we just pass in None and dz will be allocated in the C++ code. 
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + False, nrows # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -159,7 +161,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, nrows=1): """ xz: (batch, dim, seqlen) """ @@ -215,8 +217,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows ) + ctx.nrows = nrows ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl @@ -243,6 +246,7 @@ def backward(ctx, dout): conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + nrows = 1 # ctx.nrows # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
dxz = torch.empty_like(xz) # (batch, dim, seqlen) @@ -252,7 +256,7 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True # option to recompute out_z + True, nrows # option to recompute out_z ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 8a834b3c9..dcbc8fbd7 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -35,8 +35,9 @@ @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("nrows", [1,3,4]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype): + delta_softplus, return_last_state, seqlen, itype, wtype, nrows): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -50,7 +51,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z # set seed torch.random.manual_seed(0) batch_size = 2 - dim = 4 + dim = 24 dstate = 8 is_complex = wtype == torch.complex64 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() From f6cac911b28cc9b78153ae3bed259e77fa38c41c Mon Sep 17 00:00:00 2001 From: MzeroMiko <45236610+MzeroMiko@users.noreply.github.com> Date: Sun, 4 Feb 2024 23:07:28 +0800 Subject: [PATCH 2/9] Update selective_scan_interface.py --- mamba_ssm/ops/selective_scan_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py 
b/mamba_ssm/ops/selective_scan_interface.py index d55f6be6e..68593d38a 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -77,12 +77,12 @@ def backward(ctx, dout, *args): def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, nrows=1): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, From e5879119cd1dd08cff4b289cf93f216a3089a4c7 Mon Sep 17 00:00:00 2001 From: MzeroMiko <45236610+MzeroMiko@users.noreply.github.com> Date: Sun, 4 Feb 2024 23:28:33 +0800 Subject: [PATCH 3/9] update --- mamba_ssm/ops/selective_scan_interface.py | 1 + setup.py | 3 + tests/ops/test_selective_scan.py | 2 +- tests/ops/test_selective_scan_.py | 311 ++++++++++++++++++++++ 4 files changed, 316 insertions(+), 1 deletion(-) create mode 100644 tests/ops/test_selective_scan_.py diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index 68593d38a..57a4a8c68 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -73,6 +73,7 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, + None, None) diff --git a/setup.py b/setup.py index d2a1f2f28..f290ef89c 100644 --- a/setup.py +++ b/setup.py @@ -133,6 +133,9 @@ def append_nvcc_threads(nvcc_extra_args): "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", 
"csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", + "csrc/selective_scan/selective_scan_fwd2.cu", + "csrc/selective_scan/selective_scan_fwd3.cu", + "csrc/selective_scan/selective_scan_fwd4.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"], diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index dcbc8fbd7..0c9f3314c 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -35,7 +35,7 @@ @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("nrows", [1,3,4]) +@pytest.mark.parametrize("nrows", [1, 2, 3, 4]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, delta_softplus, return_last_state, seqlen, itype, wtype, nrows): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): diff --git a/tests/ops/test_selective_scan_.py b/tests/ops/test_selective_scan_.py new file mode 100644 index 000000000..ed1e27bb5 --- /dev/null +++ b/tests/ops/test_selective_scan_.py @@ -0,0 +1,311 @@ +# Copyright (C) 2023, Tri Dao. +# here we have a simple test just verify the selective scan in csrc/ +# you should delete it when pull request... 
+ +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat +import selective_scan_cuda +# print(selective_scan_cuda) + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, nrows=1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + ctx.nrows = nrows + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + nrows = 1 # ctx.nrows # we have not implemented the nrows for bwd yet + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of 
selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, nrows # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, nrows=1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. + """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... 
L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +@pytest.mark.parametrize('wtype', [torch.float32]) +# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('itype', [torch.float32]) +# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +# @pytest.mark.parametrize('seqlen', [128]) +# @pytest.mark.parametrize("return_last_state", [False, True]) +@pytest.mark.parametrize("return_last_state", [True]) +# @pytest.mark.parametrize('has_delta_bias', [False, True]) 
+@pytest.mark.parametrize('has_delta_bias', [True]) +# @pytest.mark.parametrize('delta_softplus', [False, True]) +@pytest.mark.parametrize('delta_softplus', [True]) +# @pytest.mark.parametrize('has_z', [False, True]) +@pytest.mark.parametrize('has_z', [True]) +# @pytest.mark.parametrize('has_D', [False, True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +# @pytest.mark.parametrize("varBC_groups", [1]) +# @pytest.mark.parametrize("is_variable_C", [False, True]) +@pytest.mark.parametrize("is_variable_C", [True]) +# @pytest.mark.parametrize("is_variable_B", [False, True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("nrows", [1, 2, 3, 4]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, nrows): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 24 + dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not 
is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, 
atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + +""" +(mamba) (base) LiuYue@Turing17:~/Workspace/GITHUB/mamba$ pytest tests/ops/test_selective_scan_.py +========================================== test session starts =========================================== +platform linux -- Python 3.10.13, pytest-7.4.3, pluggy-1.0.0 +rootdir: /Workspace/LiuYue/GITHUB/mamba +plugins: 
anyio-4.2.0 +collected 48 items + +tests/ops/test_selective_scan_.py ................................................ [100%] + +========================================== 48 passed in 42.40s =========================================== +""" \ No newline at end of file From 66539012b4c5ce2e2442abd74331d270e46b666d Mon Sep 17 00:00:00 2001 From: Liu Yue <45236610+MzeroMiko@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:25:01 +0800 Subject: [PATCH 4/9] Update test_selective_scan_.py --- tests/ops/test_selective_scan_.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ops/test_selective_scan_.py b/tests/ops/test_selective_scan_.py index ed1e27bb5..d2dd95d30 100644 --- a/tests/ops/test_selective_scan_.py +++ b/tests/ops/test_selective_scan_.py @@ -246,7 +246,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z out, *rest = selective_scan_fn( u, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state + return_last_state=return_last_state, nrows=nrows ) if return_last_state: state = rest[0] @@ -308,4 +308,4 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z tests/ops/test_selective_scan_.py ................................................ 
[100%] ========================================== 48 passed in 42.40s =========================================== -""" \ No newline at end of file +""" From 28a3fbdaed211334ec94cb2349dbb1157df5b156 Mon Sep 17 00:00:00 2001 From: MzeroMiko <3496274007@qq.com> Date: Sat, 17 Feb 2024 09:52:46 +0800 Subject: [PATCH 5/9] update --- .../selective_scan_bwd.cu} | 20 +- .../selective_scan/cus/selective_scan_bwd2.cu | 11 + .../cus/selective_scan_bwd2_complex.cu | 11 + .../selective_scan_bwd3.cu} | 21 +- .../selective_scan_bwd3_complex.cu} | 21 +- .../selective_scan/cus/selective_scan_bwd4.cu | 11 + .../cus/selective_scan_bwd4_complex.cu | 11 + .../selective_scan_bwd_complex.cu} | 20 +- .../selective_scan_fwd.cu} | 21 +- .../selective_scan/cus/selective_scan_fwd2.cu | 11 + .../selective_scan_fwd2_complex.cu} | 23 +- .../selective_scan/cus/selective_scan_fwd3.cu | 11 + .../selective_scan_fwd3_complex.cu} | 23 +- .../selective_scan/cus/selective_scan_fwd4.cu | 11 + .../selective_scan_fwd4_complex.cu} | 23 +- .../cus/selective_scan_fwd_complex.cu | 11 + csrc/selective_scan/reverse_scan.cuh | 800 ++++++------ csrc/selective_scan/selective_scan.cpp | 1050 ++++++++-------- csrc/selective_scan/selective_scan.h | 202 +-- .../selective_scan_bwd_bf16_real.cu | 9 - .../selective_scan_bwd_fp16_complex.cu | 9 - .../selective_scan_bwd_fp16_real.cu | 9 - .../selective_scan_bwd_fp32_complex.cu | 9 - .../selective_scan_bwd_kernel.cuh | 1117 +++++++++-------- .../selective_scan_bwd_kernel.nrows.cuh | 586 +++++++++ .../selective_scan_bwd_kernel.ori.cuh | 533 ++++++++ .../selective_scan_bwd_kernel.stage1.cuh | 526 ++++++++ csrc/selective_scan/selective_scan_common.h | 442 +++---- .../selective_scan_fwd_kernel.cuh | 686 +++++----- csrc/selective_scan/static_switch.h | 50 +- csrc/selective_scan/uninitialized_copy.cuh | 138 +- setup.py | 28 +- ...can_.py => test_selective_scan_new2old.py} | 272 ++-- 33 files changed, 4280 insertions(+), 2446 deletions(-) rename 
csrc/selective_scan/{selective_scan_bwd_fp32_real.cu => cus/selective_scan_bwd.cu} (51%) create mode 100644 csrc/selective_scan/cus/selective_scan_bwd2.cu create mode 100644 csrc/selective_scan/cus/selective_scan_bwd2_complex.cu rename csrc/selective_scan/{selective_scan_fwd_fp16.cu => cus/selective_scan_bwd3.cu} (51%) rename csrc/selective_scan/{selective_scan_fwd_fp32.cu => cus/selective_scan_bwd3_complex.cu} (50%) create mode 100644 csrc/selective_scan/cus/selective_scan_bwd4.cu create mode 100644 csrc/selective_scan/cus/selective_scan_bwd4_complex.cu rename csrc/selective_scan/{selective_scan_bwd_bf16_complex.cu => cus/selective_scan_bwd_complex.cu} (50%) rename csrc/selective_scan/{selective_scan_fwd_bf16.cu => cus/selective_scan_fwd.cu} (57%) create mode 100644 csrc/selective_scan/cus/selective_scan_fwd2.cu rename csrc/selective_scan/{selective_scan_fwd2.cu => cus/selective_scan_fwd2_complex.cu} (61%) create mode 100644 csrc/selective_scan/cus/selective_scan_fwd3.cu rename csrc/selective_scan/{selective_scan_fwd3.cu => cus/selective_scan_fwd3_complex.cu} (61%) create mode 100644 csrc/selective_scan/cus/selective_scan_fwd4.cu rename csrc/selective_scan/{selective_scan_fwd4.cu => cus/selective_scan_fwd4_complex.cu} (61%) create mode 100644 csrc/selective_scan/cus/selective_scan_fwd_complex.cu delete mode 100644 csrc/selective_scan/selective_scan_bwd_bf16_real.cu delete mode 100644 csrc/selective_scan/selective_scan_bwd_fp16_complex.cu delete mode 100644 csrc/selective_scan/selective_scan_bwd_fp16_real.cu delete mode 100644 csrc/selective_scan/selective_scan_bwd_fp32_complex.cu create mode 100644 csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh create mode 100644 csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh create mode 100644 csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh rename tests/ops/{test_selective_scan_.py => test_selective_scan_new2old.py} (52%) diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_real.cu 
b/csrc/selective_scan/cus/selective_scan_bwd.cu similarity index 51% rename from csrc/selective_scan/selective_scan_bwd_fp32_real.cu rename to csrc/selective_scan/cus/selective_scan_bwd.cu index 7ef3e0ed9..c7d5ecf1d 100644 --- a/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +++ b/csrc/selective_scan/cus/selective_scan_bwd.cu @@ -1,9 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream); \ No newline at end of file +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/cus/selective_scan_bwd2.cu b/csrc/selective_scan/cus/selective_scan_bwd2.cu new file mode 100644 index 000000000..2af8f1e2c --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_bwd2.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu b/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu new file mode 100644 index 000000000..51bc14cd3 --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<2, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/selective_scan_fwd_fp16.cu b/csrc/selective_scan/cus/selective_scan_bwd3.cu similarity index 51% rename from csrc/selective_scan/selective_scan_fwd_fp16.cu rename to csrc/selective_scan/cus/selective_scan_bwd3.cu index 0d3bc738f..fe9ebcae1 100644 --- a/csrc/selective_scan/selective_scan_fwd_fp16.cu +++ b/csrc/selective_scan/cus/selective_scan_bwd3.cu @@ -1,10 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); -template void selective_scan_fwd_cuda<1, at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/selective_scan_fwd_fp32.cu b/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu similarity index 50% rename from csrc/selective_scan/selective_scan_fwd_fp32.cu rename to csrc/selective_scan/cus/selective_scan_bwd3_complex.cu index 80b1a552b..c58d3f974 100644 --- a/csrc/selective_scan/selective_scan_fwd_fp32.cu +++ b/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu @@ -1,10 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream); -template void selective_scan_fwd_cuda<1, float, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<3, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/cus/selective_scan_bwd4.cu b/csrc/selective_scan/cus/selective_scan_bwd4.cu new file mode 100644 index 000000000..36555d110 --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_bwd4.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu b/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu new file mode 100644 index 000000000..11417e17a --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<4, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/csrc/selective_scan/cus/selective_scan_bwd_complex.cu similarity index 50% rename from csrc/selective_scan/selective_scan_bwd_bf16_complex.cu rename to csrc/selective_scan/cus/selective_scan_bwd_complex.cu index 268c904fd..29e6a90d0 100644 --- a/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +++ b/csrc/selective_scan/cus/selective_scan_bwd_complex.cu @@ -1,9 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); \ No newline at end of file +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<1, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/csrc/selective_scan/selective_scan_fwd_bf16.cu b/csrc/selective_scan/cus/selective_scan_fwd.cu similarity index 57% rename from csrc/selective_scan/selective_scan_fwd_bf16.cu rename to csrc/selective_scan/cus/selective_scan_fwd.cu index fb8d1a502..1b19a9110 100644 --- a/csrc/selective_scan/selective_scan_fwd_bf16.cu +++ b/csrc/selective_scan/cus/selective_scan_fwd.cu @@ -1,10 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); -template void selective_scan_fwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); diff --git a/csrc/selective_scan/cus/selective_scan_fwd2.cu b/csrc/selective_scan/cus/selective_scan_fwd2.cu new file mode 100644 index 000000000..1b24ae355 --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_fwd2.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd2.cu b/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu similarity index 61% rename from csrc/selective_scan/selective_scan_fwd2.cu rename to csrc/selective_scan/cus/selective_scan_fwd2_complex.cu index 66286238a..e84a2c588 100644 --- a/csrc/selective_scan/selective_scan_fwd2.cu +++ b/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu @@ -1,14 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<2, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); template void selective_scan_fwd_cuda<2, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/cus/selective_scan_fwd3.cu b/csrc/selective_scan/cus/selective_scan_fwd3.cu new file mode 100644 index 000000000..cce00b4e2 --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_fwd3.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd3.cu b/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu similarity index 61% rename from csrc/selective_scan/selective_scan_fwd3.cu rename to csrc/selective_scan/cus/selective_scan_fwd3_complex.cu index 6ca83c5a6..a8dc76640 100644 --- a/csrc/selective_scan/selective_scan_fwd3.cu +++ b/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu @@ -1,14 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<3, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); template void selective_scan_fwd_cuda<3, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/cus/selective_scan_fwd4.cu b/csrc/selective_scan/cus/selective_scan_fwd4.cu new file mode 100644 index 000000000..74383e3a7 --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_fwd4.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_fwd4.cu b/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu similarity index 61% rename from csrc/selective_scan/selective_scan_fwd4.cu rename to csrc/selective_scan/cus/selective_scan_fwd4_complex.cu index 442f575d4..4dd204a49 100644 --- a/csrc/selective_scan/selective_scan_fwd4.cu +++ b/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu @@ -1,14 +1,11 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<4, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); template void selective_scan_fwd_cuda<4, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/cus/selective_scan_fwd_complex.cu b/csrc/selective_scan/cus/selective_scan_fwd_complex.cu new file mode 100644 index 000000000..20f1a86cb --- /dev/null +++ b/csrc/selective_scan/cus/selective_scan_fwd_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index d7e93174b..0baeebb05 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -1,401 +1,401 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include -// #include -#include "uninitialized_copy.cuh" - -/** - * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ReductionOp> -__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { - static_assert(LENGTH > 0); - T retval = input[LENGTH - 1]; - #pragma unroll - for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } - return retval; -} - -/** - * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. 
- */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanInclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - T inclusive = postfix; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(inclusive, input[i]); - output[i] = inclusive; - } -} - -/** - * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanExclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - // Careful, output maybe be aliased to input - T exclusive = postfix; - T inclusive; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(exclusive, input[i]); - output[i] = exclusive; - exclusive = inclusive; - } - return inclusive; -} - - -/** - * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. 
- * - * LOGICAL_WARP_THREADS must be a power-of-two - */ -template < - typename T, ///< Data type being scanned - int LOGICAL_WARP_THREADS ///< Number of threads per logical warp - > -struct WarpReverseScan { - //--------------------------------------------------------------------- - // Constants and type definitions - //--------------------------------------------------------------------- - - /// Whether the logical warp size and the PTX warp size coincide - static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); - /// The number of warp scan steps - static constexpr int STEPS = cub::Log2::VALUE; - static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); - - - //--------------------------------------------------------------------- - // Thread fields - //--------------------------------------------------------------------- - - /// Lane index in logical warp - unsigned int lane_id; - - /// Logical warp index in 32-thread physical warp - unsigned int warp_id; - - /// 32-thread physical warp member mask of logical warp - unsigned int member_mask; - - //--------------------------------------------------------------------- - // Construction - //--------------------------------------------------------------------- - - /// Constructor - explicit __device__ __forceinline__ - WarpReverseScan() - : lane_id(cub::LaneId()) - , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) - , member_mask(cub::WarpMask(warp_id)) - { - if (!IS_ARCH_WARP) { - lane_id = lane_id % LOGICAL_WARP_THREADS; - } - } - - - /// Broadcast - __device__ __forceinline__ T Broadcast( - T input, ///< [in] The value to broadcast - int src_lane) ///< [in] Which warp lane is to do the broadcasting - { - return cub::ShuffleIndex(input, src_lane, member_mask); - } - - - /// Inclusive scan - template - __device__ __forceinline__ void InclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's output item. 
May be aliased with \p input. - ScanOpT scan_op) ///< [in] Binary scan operator - { - inclusive_output = input; - #pragma unroll - for (int STEP = 0; STEP < STEPS; STEP++) { - int offset = 1 << STEP; - T temp = cub::ShuffleDown( - inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask - ); - // Perform scan op if from a valid peer - inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset - ? inclusive_output : scan_op(temp, inclusive_output); - } - } - - /// Exclusive scan - // Get exclusive from inclusive - template - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op, ///< [in] Binary scan operator - T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. - { - T inclusive_output; - InclusiveReverseScan(input, inclusive_output, scan_op); - warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - - /** - * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. - */ - template - __device__ __forceinline__ void ReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. - T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. 
- ScanOpT scan_op) ///< [in] Binary scan operator - { - InclusiveReverseScan(input, inclusive_output, scan_op); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - -}; - -/** - * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. - */ -template < - typename T, ///< Data type being scanned - int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension - bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure - > -struct BlockReverseScan { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - /// Constants - /// The thread block size in threads - static constexpr int BLOCK_THREADS = BLOCK_DIM_X; - - /// Layout type for padded thread block raking grid - using BlockRakingLayout = cub::BlockRakingLayout; - // The number of reduction elements is not a multiple of the number of raking threads for now - static_assert(BlockRakingLayout::UNGUARDED); - - /// Number of raking threads - static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; - /// Number of raking elements per warp synchronous raking thread - static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; - /// Cooperative work can be entirely warp synchronous - static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); - - /// WarpReverseScan utility type - using WarpReverseScan = WarpReverseScan; - - /// Shared memory storage layout type - struct _TempStorage { - typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid - }; - - - /// Alias wrapper allowing storage to be unioned - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - - - 
//--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - // Thread fields - _TempStorage &temp_storage; - unsigned int linear_tid; - T cached_segment[SEGMENT_LENGTH]; - - - //--------------------------------------------------------------------- - // Utility methods - //--------------------------------------------------------------------- - - /// Performs upsweep raking reduction, returning the aggregate - template - __device__ __forceinline__ T Upsweep(ScanOp scan_op) { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data into registers - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; - #pragma unroll - for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { - raking_partial = scan_op(raking_partial, cached_segment[i]); - } - return raking_partial; - } - - - /// Performs exclusive downsweep raking scan - template - __device__ __forceinline__ void ExclusiveDownsweep( - ScanOp scan_op, - T raking_partial) - { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data back into registers - if (!MEMOIZE) { - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - } - ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); - // Write data back to smem - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } - } - - - //--------------------------------------------------------------------- - // Constructors - //--------------------------------------------------------------------- - - /// Constructor - __device__ __forceinline__ BlockReverseScan( - TempStorage &temp_storage) - : - temp_storage(temp_storage.Alias()), - 
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) - {} - - - /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - template < - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item - T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan operator - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. - { - if (WARP_SYNCHRONOUS) { - // Short-circuit directly to warp-synchronous scan - T block_aggregate; - WarpReverseScan warp_scan; - warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); - // Obtain warp-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - exclusive_output = linear_tid == BLOCK_THREADS - 1 ? 
block_postfix : scan_op(block_postfix, exclusive_output); - } else { - // Place thread partial into shared memory raking grid - T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); - detail::uninitialized_copy(placement_ptr, input); - cub::CTA_SYNC(); - // Reduce parallelism down to just raking threads - if (linear_tid < RAKING_THREADS) { - WarpReverseScan warp_scan; - // Raking upsweep reduction across shared partials - T upsweep_partial = Upsweep(scan_op); - // Warp-synchronous scan - T exclusive_partial, block_aggregate; - warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); - // Obtain block-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - // Update postfix with warpscan exclusive partial - T downsweep_postfix = linear_tid == RAKING_THREADS - 1 - ? block_postfix : scan_op(block_postfix, exclusive_partial); - // Exclusive raking downsweep scan - ExclusiveDownsweep(scan_op, downsweep_postfix); - } - cub::CTA_SYNC(); - // Grab thread postfix from shared memory - exclusive_output = *placement_ptr; - - // // Compute warp scan in each warp. - // // The exclusive output from the last lane in each warp is invalid. - // T inclusive_output; - // WarpReverseScan warp_scan; - // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); - - // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 
- // T block_aggregate; - // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); - - // // Apply warp postfix to our lane's partial - // if (warp_id != 0) { - // exclusive_output = scan_op(warp_postfix, exclusive_output); - // if (lane_id == 0) { exclusive_output = warp_postfix; } - // } - - // // Use the first warp to determine the thread block postfix, returning the result in lane0 - // if (warp_id == 0) { - // T block_postfix = block_postfix_callback_op(block_aggregate); - // if (lane_id == 0) { - // // Share the postfix with all threads - // detail::uninitialized_copy(&temp_storage.block_postfix, - // block_postfix); - - // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 - // } - // } - - // cub::CTA_SYNC(); - - // // Incorporate thread block postfix into outputs - // T block_postfix = temp_storage.block_postfix; - // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } - } - } - - - /** - * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
- */ - template < - int ITEMS_PER_THREAD, - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void InclusiveReverseScan( - T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items - T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan functor - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. - { - // Reduce consecutive thread items in registers - T thread_postfix = ThreadReverseReduce(input, scan_op); - // Exclusive thread block-scan - ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); - // Inclusive scan in registers with postfix as seed - ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); - } - +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +// #include +#include "uninitialized_copy.cuh" + +/** + * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { + static_assert(LENGTH > 0); + T retval = input[LENGTH - 1]; + #pragma unroll + for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } + return retval; +} + +/** + * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. 
+ */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanInclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + T inclusive = postfix; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } +} + +/** + * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanExclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + // Careful, output maybe be aliased to input + T exclusive = postfix; + T inclusive; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + return inclusive; +} + + +/** + * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. 
+ * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS ///< Number of threads per logical warp + > +struct WarpReverseScan { + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + /// Whether the logical warp size and the PTX warp size coincide + static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); + /// The number of warp scan steps + static constexpr int STEPS = cub::Log2::VALUE; + static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + explicit __device__ __forceinline__ + WarpReverseScan() + : lane_id(cub::LaneId()) + , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) + , member_mask(cub::WarpMask(warp_id)) + { + if (!IS_ARCH_WARP) { + lane_id = lane_id % LOGICAL_WARP_THREADS; + } + } + + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return cub::ShuffleIndex(input, src_lane, member_mask); + } + + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. 
May be aliased with \p input. + ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) { + int offset = 1 << STEP; + T temp = cub::ShuffleDown( + inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask + ); + // Perform scan op if from a valid peer + inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset + ? inclusive_output : scan_op(temp, inclusive_output); + } + } + + /// Exclusive scan + // Get exclusive from inclusive + template + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T inclusive_output; + InclusiveReverseScan(input, inclusive_output, scan_op); + warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + + /** + * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. + */ + template + __device__ __forceinline__ void ReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. 
+ ScanOpT scan_op) ///< [in] Binary scan operator + { + InclusiveReverseScan(input, inclusive_output, scan_op); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + +}; + +/** + * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. + */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + > +struct BlockReverseScan { + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X; + + /// Layout type for padded thread block raking grid + using BlockRakingLayout = cub::BlockRakingLayout; + // The number of reduction elements is not a multiple of the number of raking threads for now + static_assert(BlockRakingLayout::UNGUARDED); + + /// Number of raking threads + static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; + /// Number of raking elements per warp synchronous raking thread + static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; + /// Cooperative work can be entirely warp synchronous + static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); + + /// WarpReverseScan utility type + using WarpReverseScan = WarpReverseScan; + + /// Shared memory storage layout type + struct _TempStorage { + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + + 
//--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep(ScanOp scan_op) { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data into registers + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; + #pragma unroll + for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { + raking_partial = scan_op(raking_partial, cached_segment[i]); + } + return raking_partial; + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data back into registers + if (!MEMOIZE) { + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + } + ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); + // Write data back to smem + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockReverseScan( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + 
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) + {} + + + /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpReverseScan warp_scan; + warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); + // Obtain warp-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + exclusive_output = linear_tid == BLOCK_THREADS - 1 ? 
block_postfix : scan_op(block_postfix, exclusive_output); + } else { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + detail::uninitialized_copy(placement_ptr, input); + cub::CTA_SYNC(); + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) { + WarpReverseScan warp_scan; + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + // Obtain block-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + // Update postfix with warpscan exclusive partial + T downsweep_postfix = linear_tid == RAKING_THREADS - 1 + ? block_postfix : scan_op(block_postfix, exclusive_partial); + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_postfix); + } + cub::CTA_SYNC(); + // Grab thread postfix from shared memory + exclusive_output = *placement_ptr; + + // // Compute warp scan in each warp. + // // The exclusive output from the last lane in each warp is invalid. + // T inclusive_output; + // WarpReverseScan warp_scan; + // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); + + // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 
+ // T block_aggregate; + // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); + + // // Apply warp postfix to our lane's partial + // if (warp_id != 0) { + // exclusive_output = scan_op(warp_postfix, exclusive_output); + // if (lane_id == 0) { exclusive_output = warp_postfix; } + // } + + // // Use the first warp to determine the thread block postfix, returning the result in lane0 + // if (warp_id == 0) { + // T block_postfix = block_postfix_callback_op(block_aggregate); + // if (lane_id == 0) { + // // Share the postfix with all threads + // detail::uninitialized_copy(&temp_storage.block_postfix, + // block_postfix); + + // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 + // } + // } + + // cub::CTA_SYNC(); + + // // Incorporate thread block postfix into outputs + // T block_postfix = temp_storage.block_postfix; + // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } + } + } + + + /** + * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
+ */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void InclusiveReverseScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. + { + // Reduce consecutive thread items in registers + T thread_postfix = ThreadReverseReduce(input, scan_op); + // Exclusive thread block-scan + ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); + // Inclusive scan in registers with postfix as seed + ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); + } + }; \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index c1bcbb6f6..bf06ba60d 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -1,517 +1,533 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include -#include - -#include "selective_scan.h" -#define MAX_DSTATE 256 - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::ComplexFloat) { \ - using weight_t = c10::complex; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -#define INT_SWITCH(INT, NAME, ...) 
[&] { \ - if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ - else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ - else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ - else {constexpr int NAME = 1; __VA_ARGS__(); } \ -}() \ - - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor out, - const at::Tensor z, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - bool has_z, - bool delta_softplus) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = u.data_ptr(); - params.delta_ptr = delta.data_ptr(); - params.A_ptr = A.data_ptr(); - params.B_ptr = B.data_ptr(); - params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; - params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; - params.z_ptr = has_z ? z.data_ptr() : nullptr; - params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; - // All stride are in elements, not bytes. 
- params.A_d_stride = A.stride(0); - params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); - } - params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); - } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); -} - -void set_ssm_params_bwd(SSMParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor z, - const at::Tensor out, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - const at::Tensor dout, - const at::Tensor du, - const at::Tensor ddelta, - const at::Tensor dA, - const at::Tensor dB, - const at::Tensor dC, - const at::Tensor dz, - void* dD_ptr, - void* ddelta_bias_ptr, - bool has_z, - bool delta_softplus, - bool recompute_out_z) { - // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z - set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, has_z ? out : dout, - has_z ? 
z : dout, - // If not recompute_out_z, pass dout instead of out_z. - // This won't be used by the bwd kernel - recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); - if (!recompute_out_z) { params.out_z_ptr = nullptr; } - - // Set the pointers and strides. - params.dout_ptr = dout.data_ptr(); - params.du_ptr = du.data_ptr(); - params.dA_ptr = dA.data_ptr(); - params.dB_ptr = dB.data_ptr(); - params.dC_ptr = dC.data_ptr(); - params.dD_ptr = dD_ptr; - params.ddelta_ptr = ddelta.data_ptr(); - params.ddelta_bias_ptr = ddelta_bias_ptr; - params.dz_ptr = has_z ? dz.data_ptr() : nullptr; - // All stride are in elements, not bytes. - params.dout_batch_stride = dout.stride(0); - params.dout_d_stride = dout.stride(1); - params.dA_d_stride = dA.stride(0); - params.dA_dstate_stride = dA.stride(1); - if (!is_variable_B) { - params.dB_d_stride = dB.stride(0); - } else { - params.dB_batch_stride = dB.stride(0); - params.dB_group_stride = dB.stride(1); - } - params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); - if (!is_variable_C) { - params.dC_d_stride = dC.stride(0); - } else { - params.dC_batch_stride = dC.stride(0); - params.dC_group_stride = dC.stride(1); - } - params.dC_dstate_stride = !is_variable_C ? 
dC.stride(1) : dC.stride(2); - params.du_batch_stride = du.stride(0); - params.du_d_stride = du.stride(1); - params.ddelta_batch_stride = ddelta.stride(0); - params.ddelta_d_stride = ddelta.stride(1); - if (has_z) { - params.dz_batch_stride = dz.stride(0); - params.dz_d_stride = dz.stride(1); - } -} - -std::vector -selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - bool delta_softplus, - int nrows - ) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? 
B.size(1) : 1; - - TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); - TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - at::Tensor out = torch::empty_like(delta); - at::Tensor x; - x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); - - SSMParamsBase params; - 
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, out, z, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.data_ptr(), - has_z, - delta_softplus); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { - INT_SWITCH(nrows, kNRows, [&] { - selective_scan_fwd_cuda(params, stream); - }); - }); - }); - std::vector result = {out, x}; - if (has_z) { result.push_back(out_z); } - return result; -} - -std::vector -selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - const at::Tensor &dout, - const c10::optional &x_, - const c10::optional &out_, - c10::optional &dz_, - bool delta_softplus, - bool recompute_out_z, - int nrows - ) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? 
weight_type : input_type)); - TORCH_CHECK(dout.scalar_type() == input_type); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); - TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out, dz, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - - TORCH_CHECK(out_.has_value()); - out = out_.value(); - TORCH_CHECK(out.scalar_type() == input_type); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); - CHECK_SHAPE(out, batch_size, dim, seqlen); - - if (dz_.has_value()) { - dz = dz_.value(); - TORCH_CHECK(dz.scalar_type() == input_type); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); - CHECK_SHAPE(dz, batch_size, dim, seqlen); - } else { - dz = torch::empty_like(z); - } - if (recompute_out_z) { - out_z = torch::empty_like(out); - } - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } - if (x_.has_value()) { - auto x = x_.value(); - TORCH_CHECK(x.scalar_type() == weight_type); - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(x.is_contiguous()); - CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); - } - - at::Tensor du = torch::empty_like(u); - at::Tensor ddelta = torch::empty_like(delta); - at::Tensor dA = torch::zeros_like(A); - at::Tensor dB = 
!is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); - at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); - at::Tensor dD; - if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } - at::Tensor ddelta_bias; - if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } - - SSMParamsBwd params; - set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, z, out, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x_.has_value() ? x_.value().data_ptr() : nullptr, - dout, du, ddelta, dA, dB, dC, dz, - D_.has_value() ? dD.data_ptr() : nullptr, - delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { - constexpr int kNRows = 1; - // INT_SWITCH(nrows, kNRows, [&] { - selective_scan_bwd_cuda(params, stream); - // }); - }); - }); - std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; - if (has_z) { result.push_back(dz); } - if (recompute_out_z) { result.push_back(out_z); } - return result; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fwd", &selective_scan_fwd, "Selective scan forward"); - m.def("bwd", &selective_scan_bwd, "Selective scan backward"); -} +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::ComplexFloat) { \ + using weight_t = c10::complex; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define INT_SWITCH_FWD(INT, NAME, ...) [&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + +#define INT_SWITCH_BWD(INT, NAME, ...) 
[&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + const at::Tensor z, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + // All stride are in elements, not bytes. 
+ params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor z, + const at::Tensor out, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + const at::Tensor dz, + void* dD_ptr, + void* ddelta_bias_ptr, + bool has_z, + bool delta_softplus, + bool recompute_out_z) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, has_z ? out : dout, + has_z ? 
z : dout, + // If not recompute_out_z, pass dout instead of out_z. + // This won't be used by the bwd kernel + recompute_out_z ? out_z : dout, + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + if (!recompute_out_z) { params.out_z_ptr = nullptr; } + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + params.dz_ptr = has_z ? dz.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + if (!is_variable_B) { + params.dB_d_stride = dB.stride(0); + } else { + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + } + params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); + if (!is_variable_C) { + params.dC_d_stride = dC.stride(0); + } else { + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + } + params.dC_dstate_stride = !is_variable_C ? 
dC.stride(1) : dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + if (has_z) { + params.dz_batch_stride = dz.stride(0); + params.dz_d_stride = dz.stride(1); + } +} + +template +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + // TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + using weight_t = std::conditional_t>; + TORCH_CHECK(weight_type == (is_complex ? at::ScalarType::ComplexFloat : at::ScalarType::Float)); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + // const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? 
B.size(1) : 1; + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + 
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + has_z, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + // DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { + INT_SWITCH_FWD(nrows, kNRows, [&] { + selective_scan_fwd_cuda(params, stream); + }); + // }); + }); + std::vector result = {out, x}; + if (has_z) { result.push_back(out_z); } + return result; +} + +template +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + const c10::optional &out_, + c10::optional &dz_, + bool delta_softplus, + bool recompute_out_z, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + // TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + using weight_t = std::conditional_t>; + TORCH_CHECK(weight_type == (is_complex ? 
at::ScalarType::ComplexFloat : at::ScalarType::Float)); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + // const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out, dz, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + + TORCH_CHECK(out_.has_value()); + out = out_.value(); + TORCH_CHECK(out.scalar_type() == input_type); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); + CHECK_SHAPE(out, batch_size, dim, seqlen); + + if (dz_.has_value()) { + dz = dz_.value(); + TORCH_CHECK(dz.scalar_type() == input_type); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); + CHECK_SHAPE(dz, batch_size, dim, seqlen); + } else { + dz = torch::empty_like(z); + } + if (recompute_out_z) { + out_z = torch::empty_like(out); + } + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = 
!is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, z, out, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, dz, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + has_z, delta_softplus, recompute_out_z); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + // DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { + INT_SWITCH_BWD(nrows, kNRows, [&] { + selective_scan_bwd_cuda(params, stream); + }); + // }); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + if (has_z) { result.push_back(dz); } + if (recompute_out_z) { result.push_back(out_z); } + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); + // m.def("fwdc", &selective_scan_fwd, "Selective scan forward for complex"); + // m.def("bwdc", &selective_scan_bwd, "Selective scan backward for complex"); +} diff 
--git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index e2c7bcdbd..86eaa220b 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -1,101 +1,101 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint32_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. 
- void *__restrict__ A_ptr; - void *__restrict__ B_ptr; - void *__restrict__ C_ptr; - void *__restrict__ D_ptr; - void *__restrict__ u_ptr; - void *__restrict__ delta_ptr; - void *__restrict__ delta_bias_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; - void *__restrict__ z_ptr; - void *__restrict__ out_z_ptr; -}; - -struct SSMParamsBwd: public SSMParamsBase { - index_t dout_batch_stride; - index_t dout_d_stride; - index_t dA_d_stride; - index_t dA_dstate_stride; - index_t dB_batch_stride; - index_t dB_group_stride; - index_t dB_d_stride; - index_t dB_dstate_stride; - index_t dC_batch_stride; - index_t dC_group_stride; - index_t dC_d_stride; - index_t dC_dstate_stride; - index_t du_batch_stride; - index_t du_d_stride; - index_t dz_batch_stride; - index_t dz_d_stride; - index_t ddelta_batch_stride; - index_t ddelta_d_stride; - - // Common data pointers. - void *__restrict__ dout_ptr; - void *__restrict__ dA_ptr; - void *__restrict__ dB_ptr; - void *__restrict__ dC_ptr; - void *__restrict__ dD_ptr; - void *__restrict__ du_ptr; - void *__restrict__ dz_ptr; - void *__restrict__ ddelta_ptr; - void *__restrict__ ddelta_bias_ptr; -}; +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. 
+ void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dA_dstate_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dB_dstate_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t dC_dstate_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t dz_batch_stride; + index_t dz_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. 
+ void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ dz_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/csrc/selective_scan/selective_scan_bwd_bf16_real.cu deleted file mode 100644 index 66ae72e15..000000000 --- a/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu deleted file mode 100644 index 2131f8f6b..000000000 --- a/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/csrc/selective_scan/selective_scan_bwd_fp16_real.cu deleted file mode 100644 index b5e0f7674..000000000 --- a/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu deleted file mode 100644 index 32b79094d..000000000 --- a/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<1, float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index efb615189..ef2af5ab3 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -1,533 +1,586 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK -#include // For atomicAdd on complex - -#include -#include -#include -#include - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "reverse_scan.cuh" -#include "static_switch.h" - -template __device__ __forceinline__ scalar_t conj(scalar_t x); -template<> __device__ __forceinline__ float conj(float x) { return x; } -template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } - -template -struct Selective_Scan_bwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kNItems = kNItems_; - // we are about to add kNRows here - static constexpr int MaxDState = MAX_DSTATE / 1; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; - static constexpr bool kHasZ = kHasZ_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. - // For complex this would lead to massive register spilling, so we keep it at 2. - static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockReverseScanT = BlockReverseScan; - using BlockReduceT = cub::BlockReduce; - using BlockReduceFloatT = cub::BlockReduce; - using BlockReduceComplexT = cub::BlockReduce; - using BlockExchangeT = cub::BlockExchange; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); - static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); - static 
constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_bwd_kernel(SSMParamsBwd params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); - auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); - auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); - auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); - auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); - auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + 
sizeof(typename Ktraits::BlockScanT::TempStorage)); - weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); - scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); - weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); - weight_t *smem_dbc = reinterpret_cast(smem_da + Ktraits::MaxDState); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * params.delta_d_stride; - input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; - weight_t *dB = reinterpret_cast(params.dB_ptr) - + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); - weight_t *dC = reinterpret_cast(params.dC_ptr) - + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); - float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; - float D_val = params.D_ptr == nullptr ? 
0 : reinterpret_cast(params.D_ptr)[dim_id]; - float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; - float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; - scan_t *x = params.x_ptr == nullptr - ? nullptr - : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; - float dD_val = 0; - float ddelta_bias_val = 0; - - constexpr int kChunkSize = kNThreads * kNItems; - u += (params.n_chunks - 1) * kChunkSize; - delta += (params.n_chunks - 1) * kChunkSize; - dout += (params.n_chunks - 1) * kChunkSize; - Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { - input_t u_vals[kNItems]; - input_t delta_vals_load[kNItems]; - input_t dout_vals_load[kNItems]; - __syncthreads(); - load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); - u -= kChunkSize; - __syncthreads(); - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - // Will reload delta at the same location if kDeltaSoftplus - if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } - __syncthreads(); - load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - dout -= kChunkSize; - - float dout_vals[kNItems], delta_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - delta_vals[i] = float(delta_vals_load[i]) + delta_bias; - if constexpr (kDeltaSoftplus) { - delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i]; - } - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * params.z_d_stride + chunk * kChunkSize; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * params.out_d_stride + chunk * kChunkSize; - input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride - + dim_id * params.dz_d_stride + chunk * kChunkSize; - input_t z_vals[kNItems], out_vals[kNItems]; - __syncthreads(); - load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - __syncthreads(); - load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); - float dz_vals[kNItems], z_silu_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); - z_silu_vals[i] = z_val * z_sigmoid_val; - dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val - * (1.0f + z_val * (1.0f - z_sigmoid_val)); - dout_vals[i] *= z_silu_vals[i]; - } - __syncthreads(); - store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); - if (params.out_z_ptr != nullptr) { // Recompute and store out_z - float out_z_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); - // } - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * params.out_z_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); - } - } - - float du_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } - #pragma unroll - for (int i = 
0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } - - float ddelta_vals[kNItems] = {0}; - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - const weight_t A_val = A[state_idx * params.A_dstate_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - weight_t A_scaled; - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_scaled = A_val * kLog2e; - } else { - A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); - } - weight_t B_val, C_val; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (!kIsVariableB) { - B_val = B[state_idx * params.B_dstate_stride]; - } else { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - if constexpr (!kIsVariableC) { - C_val = C[state_idx * params.C_dstate_stride]; - } else { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - // const weight_t A_val = smem_a[state_idx]; - scan_t thread_data[kNItems], thread_reverse_data[kNItems]; - if constexpr (!kIsComplex) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); - thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp; - } - thread_reverse_data[i].y = dout_vals[i] * - (!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? 
B_val * C_vals[i] : C_vals[i])); - } - __syncthreads(); - thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) - : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float dx = thread_reverse_data[i].y; - const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; - du_vals[i] += ddelta_u * delta_vals[i]; - const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; - dA_val += dx * delta_vals[i] * a; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += dout_vals[i] * thread_data[i].y; - } - } - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - if constexpr (kIsVariableB) { - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); - } - if constexpr (kIsVariableC) { - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); - } - const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; - weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; - weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float2 dA_dBC_val = make_float2(dA_val, dBC_val); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = dA_dBC_val.x; - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx]; - } - } else { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); - weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp.real_; - thread_reverse_data[i - 1].y = -delta_a_exp.imag_; - } - complex_t dout_BC = 2 * dout_vals[i] - * conj(!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - thread_reverse_data[i].z = dout_BC.real_; - thread_reverse_data[i].w = dout_BC.imag_; - } - __syncthreads(); - complex_t delta_a_exp = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) - : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; - thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; - thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - complex_t x = complex_t(thread_data[i].z, thread_data[i].w); - complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); - float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += (2 * dout_vals[i]) * conj(x); - } - } - const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); - du_vals[i] += ddelta_u * delta_vals[i]; - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; - dA_val += delta_vals[i] * dx * a_conj; - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? 
x * B_val : x); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; - if constexpr (kIsVariableB) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dB_vals_f[i * 2] = dB_vals[i].real_; - dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; - } - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); - } - if constexpr (kIsVariableC) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dC_vals_f[i * 2] = dC_vals[i].real_; - dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; - } - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); - } - const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; - float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems * 2; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); - dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } - } - - if constexpr (kDeltaSoftplus) { - __syncthreads(); - input_t delta_vals_load[kNItems]; - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - delta -= kChunkSize; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float delta_val = float(delta_vals_load[i]) + delta_bias; - float delta_val_neg_exp = expf(-delta_val); - ddelta_vals[i] = delta_val <= 20.f - ? ddelta_vals[i] / (1.f + delta_val_neg_exp) - : ddelta_vals[i]; - } - } - for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } - - input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride - + dim_id * params.du_d_stride + chunk * kChunkSize; - input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride - + dim_id * params.ddelta_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); - __syncthreads(); - store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); - - Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); - Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); - } - if (params.dD_ptr != nullptr) { - dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); - if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } - } - if (params.ddelta_bias_ptr != nullptr) { - __syncthreads(); - ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); - if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } - } - for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); - weight_t dBC_val; - if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } - if constexpr (!kIsVariableB) { - gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), - !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); - } - if constexpr (!kIsVariableC) { - gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), - !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); - } - } -} - -template -void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits; - // using Ktraits = Selective_Scan_bwd_kernel_traits; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim); - auto kernel = &selective_scan_bwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - 
kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); -} - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); - } else { - selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); - } +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static 
constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + 
weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? 
nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNRows][kNItems]; + input_t delta_vals_load[kNRows][kNItems]; + input_t dout_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + load_input(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[r][i] = float(dout_vals_load[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 
0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * kNRows * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out + r * params.out_d_stride, out_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + } + float dz_vals[kNRows][kNItems], z_silu_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[r][i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[r][i] = z_val * z_sigmoid_val; + dz_vals[r][i] = dout_vals[r][i] * float(out_vals[r][i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[r][i] *= z_silu_vals[r][i]; + } + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(dz + r * params.dz_d_stride, dz_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[r][i] = float(out_vals[r][i]) * z_silu_vals[r][i]; } + } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * 
params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_z_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + } + + float du_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr ? 0 : D_val[r]) * dout_vals[r][i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); } + } + + float ddelta_vals[kNRows][kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + weight_t A_scaled[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled[r] = A_val[r] * kLog2e; + } else { + A_scaled[r] = complex_t(A_val[r].real_ * kLog2e, A_val[r].imag_);; + } + } + weight_t B_val[kNRows], C_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + B_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + C_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } else { + auto &smem_load_weight_C = !kIsVariableB ? 
smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? (state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + kNRows * 2 * Ktraits::MaxDState)] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[r][i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a; + dA_val += dx * delta_vals[r][i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[r][i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[r][i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[r][i] * (!kIsVariableB ? thread_data[i].y * B_val[r] : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + // !!!!! + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_scaled[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : B_vals[i] * delta_vals[r][i] * float(u_vals[r][i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? 
(state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : threadIdx.x + kNRows * 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[r][i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[r][i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[r][i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i])); + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + (dx * conj(A_val[r]) * a_conj).real_; + dA_val += delta_vals[r][i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[r][i]) * conj(!kIsVariableB ? 
x * B_val[r] : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + delta -= kChunkSize; + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[r][i] = delta_val <= 20.f + ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp) + : ddelta_vals[r][i]; + } + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; } + } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * kNRows * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * kNRows * params.ddelta_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + + if (params.dD_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); } + } + } + if (params.ddelta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); } + } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx + r * Ktraits::MaxDState]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride + r * params.dB_d_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride + r * params.C_d_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride + r * params.dC_d_stride]), + !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride + r * params.B_d_stride]) : dBC_val); + } + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } } \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh new file mode 100644 index 000000000..a1de9d741 --- /dev/null +++ 
b/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh @@ -0,0 +1,586 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr 
int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows 
* params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNRows][kNItems]; + input_t delta_vals_load[kNRows][kNItems]; + input_t dout_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + load_input(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[r][i] = float(dout_vals_load[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * kNRows * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out + r * params.out_d_stride, out_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + } + float dz_vals[kNRows][kNItems], z_silu_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[r][i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[r][i] = z_val * z_sigmoid_val; + dz_vals[r][i] = dout_vals[r][i] * float(out_vals[r][i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[r][i] *= z_silu_vals[r][i]; + } + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(dz + r * params.dz_d_stride, dz_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[r][i] = float(out_vals[r][i]) * z_silu_vals[r][i]; } + } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll 
+ for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_z_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + } + + float du_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr ? 0 : D_val[r]) * dout_vals[r][i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); } + } + + float ddelta_vals[kNRows][kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + weight_t A_scaled[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled[r] = A_val[r] * kLog2e; + } else { + A_scaled[r] = complex_t(A_val[r].real_ * kLog2e, A_val[r].imag_);; + } + } + weight_t B_val[kNRows], C_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + B_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + C_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? (state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + 2 * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState)] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[r][i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a; + dA_val += dx * delta_vals[r][i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[r][i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[r][i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[r][i] * (!kIsVariableB ? thread_data[i].y * B_val[r] : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + // !!!!! + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_scaled[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : B_vals[i] * delta_vals[r][i] * float(u_vals[r][i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? 
(state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + 2 * Ktraits::MaxDState) + r * 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[r][i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[r][i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[r][i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i])); + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + (dx * conj(A_val[r]) * a_conj).real_; + dA_val += delta_vals[r][i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[r][i]) * conj(!kIsVariableB ? 
x * B_val[r] : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + delta -= kChunkSize; + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[r][i] = delta_val <= 20.f + ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp) + : ddelta_vals[r][i]; + } + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; } + } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * kNRows * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * kNRows * params.ddelta_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + + if (params.dD_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); } + } + } + if (params.ddelta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); } + } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx + r * Ktraits::MaxDState]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride + r * params.dB_d_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride + r * params.C_d_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride + r * params.dC_d_stride]), + !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride + r * params.B_d_stride]) : dBC_val); + } + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh new file mode 100644 index 000000000..a06077bf8 --- /dev/null +++ 
b/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh @@ -0,0 +1,533 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + // we are about to add kNRows here + static constexpr int MaxDState = MAX_DSTATE / 1; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using 
input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = 
reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); + // } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 
0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? 
B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? 
x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + 
kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, 1, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, 1, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh new file mode 100644 index 000000000..2ef58c4fe --- /dev/null +++ b/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh @@ -0,0 +1,526 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr 
int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows 
* params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + int r = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = (D_val == nullptr ? 
0 : D_val[r]) * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? 
B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? 
x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val[r]); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val[r]); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + 
C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, 1, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, 1, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/csrc/selective_scan/selective_scan_common.h b/csrc/selective_scan/selective_scan_common.h index 9140dcdf3..3c12af500 100644 --- a/csrc/selective_scan/selective_scan_common.h +++ b/csrc/selective_scan/selective_scan_common.h @@ -1,221 +1,221 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include -#include -#include // For scalar_value_type - -#define MAX_DSTATE 256 - -using complex_t = c10::complex; - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -inline __device__ float3 operator+(const float3 &a, const float3 &b) { - return {a.x + b.x, a.y + b.y, a.z + b.z}; -} - -inline __device__ float4 operator+(const float4 & a, const float4 & b){ - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { dst[i] = src[i]; } - } -}; - -template -struct Converter{ - static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } - } -}; - -#if __CUDA_ARCH__ >= 800 -template -struct Converter{ - static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { - 
static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp -// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 -__device__ __forceinline__ complex_t cexp2f(complex_t z) { - float t = exp2f(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -__device__ __forceinline__ complex_t cexpf(complex_t z) { - float t = expf(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -template struct SSMScanOp; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } -}; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { - complex_t a0 = complex_t(ab0.x, ab0.y); - complex_t b0 = complex_t(ab0.z, ab0.w); - complex_t a1 = complex_t(ab1.x, ab1.y); - complex_t b1 = complex_t(ab1.z, ab1.w); - complex_t out_a = a1 * a0; - complex_t out_b = a1 * b0 + b1; - return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); - } -}; - -// A stateful callback functor that maintains a running prefix to be applied -// during consecutive scan operations. -template struct SSMScanPrefixCallbackOp { - using scan_t = std::conditional_t, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
- __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp()(running_prefix, block_aggregate); - return old_prefix; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void load_input(typename Ktraits::input_t *u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage &smem_load, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = reinterpret_cast(smem_load); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadVecT(smem_load_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - ); - } else { - Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } -} - -template -inline __device__ void load_weight(typename Ktraits::input_t *Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - if constexpr (!Ktraits::kIsComplex) { - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter::to_float(B_vals_load, B_vals); - } else { - typename Ktraits::input_t B_vals_load[kNItems * 2]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - 
reinterpret_cast(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } - } -} - -template -inline __device__ void store_output(typename Ktraits::input_t *out, - const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage &smem_store, - int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; - #pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = reinterpret_cast(smem_store); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockStoreVecT(smem_store_vec).Store( - reinterpret_cast(out), - reinterpret_cast(write_vals) - ); - } else { - Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } -} +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For scalar_value_type + +#define MAX_DSTATE 256 + +using complex_t = c10::complex; + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + 
static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp +// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 +__device__ __forceinline__ complex_t cexp2f(complex_t z) { + float t = exp2f(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +__device__ __forceinline__ complex_t cexpf(complex_t z) { + float t = expf(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { + complex_t a0 = complex_t(ab0.x, ab0.y); + complex_t b0 = complex_t(ab0.z, ab0.w); + complex_t a1 = complex_t(ab1.x, ab1.y); + complex_t b1 = complex_t(ab1.z, ab1.w); + complex_t out_a = a1 * a0; + complex_t out_b = a1 * b0 + b1; + return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + if constexpr (!Ktraits::kIsComplex) { + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); + } else { + typename Ktraits::input_t B_vals_load[kNItems * 2]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + 
reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } + } +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 2d18569a1..d7126e9d4 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -1,343 +1,343 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include -#include -#include - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "static_switch.h" - -template -struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 
5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int MaxDState = MAX_DSTATE / kNRows; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr 
bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast(smem_a + Ktraits::MaxDState); - scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); // Mzero: fixbug here for nrow - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * kNRows * params.delta_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = 
reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; - } - } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; - } - } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = 
delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; - } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } - } - - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_val[r] *= kLog2e; - } else { - A_val[r].real_ *= kLog2e; - } - } - // This variable holds B * C if both B and C are constant across seqlen. If only B varies - // across seqlen, this holds C. If only C varies across seqlen, this holds B. - // If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - } - if constexpr (kIsVariableC) { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); - if constexpr (!kIsVariableB) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } - } - } - if constexpr (!kIsVariableB && !kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { __syncthreads(); } // Scan could be using the same smem - scan_t thread_data[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if constexpr (!kIsComplex) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } - } - } else { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); - weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); - } - } - } - } - // Initialize running total - scan_t running_prefix; - if constexpr (!kIsComplex) { - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? 
smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - } else { - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - } - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; // Mzero: fixbug here for nrow - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); - if constexpr (!kIsComplex) { - out_vals[r][i] += thread_data[i].y * C_val; - } else { - out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; - } - } - } - } - - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - } - - Bvar += kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += kChunkSize * (!kIsComplex ? 
1 : 2); - } -} - -template -void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); -} - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); - } -} +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) 
* sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + Ktraits::MaxDState); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); // Mzero: fixbug here for nrow + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll 
+ for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_val[r] *= kLog2e; + } else { + A_val[r].real_ *= kLog2e; + } + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if constexpr (!kIsComplex) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } else { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); + } + } + } + } + // Initialize running total + scan_t running_prefix; + if constexpr (!kIsComplex) { + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? 
smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + } else { + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + } + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; // Mzero: fixbug here for nrow + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); + if constexpr (!kIsComplex) { + out_vals[r][i] += thread_data[i].y * C_val; + } else { + out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; + } + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += kChunkSize * (!kIsComplex ? 
1 : 2); + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h index 7920ac045..1d52adf8d 100644 --- a/csrc/selective_scan/static_switch.h +++ b/csrc/selective_scan/static_switch.h @@ -1,25 +1,25 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - 
-#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/selective_scan/uninitialized_copy.cuh b/csrc/selective_scan/uninitialized_copy.cuh index 630622ddd..77863ff8d 100644 --- a/csrc/selective_scan/uninitialized_copy.cuh +++ b/csrc/selective_scan/uninitialized_copy.cuh @@ -1,69 +1,69 @@ -/****************************************************************************** - * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#include - - -namespace detail -{ - -#if defined(_NVHPC_CUDA) -template -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - // NVBug 3384810 - new (ptr) T(::cuda::std::forward(val)); -} -#else -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - *ptr = ::cuda::std::forward(val); -} - -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - new (ptr) T(::cuda::std::forward(val)); -} -#endif - -} // namespace detail +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. 
All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +#include + +#include + + +namespace detail +{ + +#if defined(_NVHPC_CUDA) +template +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + // NVBug 3384810 + new (ptr) T(::cuda::std::forward(val)); +} +#else +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = ::cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(::cuda::std::forward(val)); +} +#endif + +} // namespace detail diff --git a/setup.py b/setup.py index f290ef89c..552427126 100644 --- a/setup.py +++ b/setup.py @@ -124,18 +124,22 @@ def append_nvcc_threads(nvcc_extra_args): name="selective_scan_cuda", sources=[ "csrc/selective_scan/selective_scan.cpp", - "csrc/selective_scan/selective_scan_fwd_fp32.cu", - "csrc/selective_scan/selective_scan_fwd_fp16.cu", - "csrc/selective_scan/selective_scan_fwd_bf16.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", - "csrc/selective_scan/selective_scan_fwd2.cu", - "csrc/selective_scan/selective_scan_fwd3.cu", - "csrc/selective_scan/selective_scan_fwd4.cu", + "csrc/selective_scan/cus/selective_scan_fwd.cu", + "csrc/selective_scan/cus/selective_scan_fwd2.cu", + "csrc/selective_scan/cus/selective_scan_fwd3.cu", + "csrc/selective_scan/cus/selective_scan_fwd4.cu", + "csrc/selective_scan/cus/selective_scan_bwd.cu", + "csrc/selective_scan/cus/selective_scan_bwd2.cu", + "csrc/selective_scan/cus/selective_scan_bwd3.cu", + "csrc/selective_scan/cus/selective_scan_bwd4.cu", + # 
"csrc/selective_scan/cus/selective_scan_fwd_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd2_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd3_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd4_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd2_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd3_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd4_complex.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"], diff --git a/tests/ops/test_selective_scan_.py b/tests/ops/test_selective_scan_new2old.py similarity index 52% rename from tests/ops/test_selective_scan_.py rename to tests/ops/test_selective_scan_new2old.py index d2dd95d30..2141e0ba0 100644 --- a/tests/ops/test_selective_scan_.py +++ b/tests/ops/test_selective_scan_new2old.py @@ -1,96 +1,144 @@ -# Copyright (C) 2023, Tri Dao. -# here we have a simple test just verify the selective scan in csrc/ -# you should delete it when pull request... +# Modified by Mzero #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. 
import math - import torch import torch.nn.functional as F import pytest - -from einops import rearrange - import torch import torch.nn.functional as F from torch.cuda.amp import custom_bwd, custom_fwd from einops import rearrange, repeat -import selective_scan_cuda -# print(selective_scan_cuda) -class SelectiveScanFn(torch.autograd.Function): +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm"): + MODE = mode - @staticmethod - def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, nrows=1): - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if z is not None and z.stride(-1) != 1: - z = z.contiguous() - if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") - ctx.squeeze_B = True - if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") - ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) - ctx.delta_softplus = delta_softplus - ctx.has_z = z is not None - ctx.nrows = nrows - last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out if not return_last_state else (out, last_state) - else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) - out_z = rest[0] - return out_z if not return_last_state else (out_z, last_state) + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = 
B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() - @staticmethod - def backward(ctx, dout, *args): - if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - z = None - out = None - else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - nrows = 1 # ctx.nrows # we have not implemented the nrows for bwd yet - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. 
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False, nrows # option to recompute out_z, not used here - ) - dz = rest[0] if ctx.has_z else None - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, - None, - None) + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) -def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, nrows=1): - """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state is - not considered in the backward pass. 
- """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows) + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + else: + raise NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
+ if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + return selective_scan_fn def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, @@ -162,25 +210,53 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta return out if not return_last_state else (out, last_state) +# MODE = "mamba_ssm" +# MODE = "sscore" +# MODE = "sstest" +MODE = "mamba_ssm_sscore" # 1344 items pass +MODE = "mamba_ssm_sstest" # 1344 items pass + +if MODE in ["mamba_ssm"]: + import selective_scan_cuda as selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sscore"]: + import selective_scan_cuda_core + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sstest"]: + import selective_scan_cuda_test + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["mamba_ssm_sscore"]: + import selective_scan_cuda_core + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscore") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +elif MODE in ["mamba_ssm_sstest"]: + import selective_scan_cuda_test + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode="sstest") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +else: + raise NotImplementedError + +print("use MODE:", MODE) +import time; time.sleep(10) + + # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) @pytest.mark.parametrize('wtype', [torch.float32]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) 
-@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -# @pytest.mark.parametrize("return_last_state", [False, True]) +@pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 2048, 4096]) @pytest.mark.parametrize("return_last_state", [True]) -# @pytest.mark.parametrize('has_delta_bias', [False, True]) -@pytest.mark.parametrize('has_delta_bias', [True]) -# @pytest.mark.parametrize('delta_softplus', [False, True]) -@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_delta_bias', [False, True]) +@pytest.mark.parametrize('delta_softplus', [False, True]) # @pytest.mark.parametrize('has_z', [False, True]) -@pytest.mark.parametrize('has_z', [True]) -# @pytest.mark.parametrize('has_D', [False, True]) -@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize('has_z', [False]) +@pytest.mark.parametrize('has_D', [False, True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) -# @pytest.mark.parametrize("varBC_groups", [1]) # @pytest.mark.parametrize("is_variable_C", [False, True]) @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @@ -188,6 +264,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta @pytest.mark.parametrize("nrows", [1, 2, 3, 4]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, delta_softplus, return_last_state, seqlen, itype, wtype, nrows): + print(f'method: {selective_scan_cuda}') if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -297,15 +374,4 @@ 
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z if has_delta_bias: assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) -""" -(mamba) (base) LiuYue@Turing17:~/Workspace/GITHUB/mamba$ pytest tests/ops/test_selective_scan_.py -========================================== test session starts =========================================== -platform linux -- Python 3.10.13, pytest-7.4.3, pluggy-1.0.0 -rootdir: /Workspace/LiuYue/GITHUB/mamba -plugins: anyio-4.2.0 -collected 48 items - -tests/ops/test_selective_scan_.py ................................................ [100%] -========================================== 48 passed in 42.40s =========================================== -""" From 3a854f056baaaf8c7b1e2a195bc8b59016ac0b8d Mon Sep 17 00:00:00 2001 From: MzeroMiko <3496274007@qq.com> Date: Sat, 17 Feb 2024 19:47:30 +0800 Subject: [PATCH 6/9] update --- tests/ops/test_selective_scan_speed.py | 334 +++++++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 tests/ops/test_selective_scan_speed.py diff --git a/tests/ops/test_selective_scan_speed.py b/tests/ops/test_selective_scan_speed.py new file mode 100644 index 000000000..6110b29dc --- /dev/null +++ b/tests/ops/test_selective_scan_speed.py @@ -0,0 +1,334 @@ +# Modified by Mzero #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. 
+ +import math +import torch +import torch.nn.functional as F +import pytest +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat +import time +from functools import partial + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + else: + raise 
NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
+ if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" + print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) + + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +def test_speed(): + wtype = torch.float32 + itype = torch.float32 + is_variable_B = True + is_variable_C = True + has_D = True + has_z = False # sscore not support z + has_delta_bias = True + varBC_groups = 2 + seqlen = 4096 + seqlen = 128 + seqlen = 64 + batch_size = 128 + dim = 24 + dim = 96 + dim = 384 + dim = 768 + dstate = 8 + # dstate = 24 + delta_softplus = True + is_complex = wtype == torch.complex64 + device = 'cuda' + TIMES = 1000 + import selective_scan_cuda_core + import selective_scan_cuda_test + import selective_scan_cuda + # copied from test_selective_scan ====================== + torch.random.manual_seed(0) + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + 
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + # ================================ + starts = [] + ends = [] + tests = [ + partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", 
tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=3), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=4), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + # 
partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + + ] + + for test in tests: + s = time.time() + for _ in range(TIMES): + with torch.no_grad(): + test() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = time.time() + starts.append(s) + ends.append(e) + print("fwd", test.func, e - s, flush=True) + for test in tests: + s = time.time() + for _ in range(TIMES): + outs = test() + outs[0].sum().backward() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = time.time() + starts.append(s) + ends.append(e) + print("fwdbwd", test.func, e - s, flush=True) + +test_speed() \ No newline at end of file From 9d53b21082d82e7ab724df169e213e2617803037 Mon Sep 17 00:00:00 2001 From: MzeroMiko <3496274007@qq.com> 
Date: Sat, 17 Feb 2024 19:51:15 +0800 Subject: [PATCH 7/9] update --- .gitignore | 3 +- .../selective_scan/cus/selective_scan_bwd.cu | 11 + .../selective_scan/cus/selective_scan_bwd2.cu | 11 + .../cus/selective_scan_bwd2_complex.cu | 11 + .../selective_scan/cus/selective_scan_bwd3.cu | 11 + .../cus/selective_scan_bwd3_complex.cu | 11 + .../selective_scan/cus/selective_scan_bwd4.cu | 11 + .../cus/selective_scan_bwd4_complex.cu | 11 + .../cus/selective_scan_bwd_complex.cu | 11 + .../selective_scan/cus/selective_scan_fwd.cu | 11 + .../selective_scan/cus/selective_scan_fwd2.cu | 11 + .../cus/selective_scan_fwd2_complex.cu | 11 + .../selective_scan/cus/selective_scan_fwd3.cu | 11 + .../cus/selective_scan_fwd3_complex.cu | 11 + .../selective_scan/cus/selective_scan_fwd4.cu | 11 + .../cus/selective_scan_fwd4_complex.cu | 11 + .../cus/selective_scan_fwd_complex.cu | 11 + kernel/csrc/selective_scan/reverse_scan.cuh | 401 ++++++++++++ kernel/csrc/selective_scan/selective_scan.cpp | 533 ++++++++++++++++ kernel/csrc/selective_scan/selective_scan.h | 101 +++ .../selective_scan_bwd_kernel.cuh | 586 ++++++++++++++++++ .../selective_scan_bwd_kernel.nrows.cuh | 586 ++++++++++++++++++ .../selective_scan_bwd_kernel.ori.cuh | 533 ++++++++++++++++ .../selective_scan_bwd_kernel.stage1.cuh | 526 ++++++++++++++++ .../selective_scan/selective_scan_common.h | 221 +++++++ .../selective_scan_fwd_kernel.cuh | 343 ++++++++++ kernel/csrc/selective_scan/static_switch.h | 25 + .../selective_scan/uninitialized_copy.cuh | 69 +++ kernel/readme.md | 1 + kernel/setup.py | 238 +++++++ kernel/test_selective_scan_new2old.py | 377 +++++++++++ kernel/test_selective_scan_speed.py | 334 ++++++++++ 32 files changed, 5052 insertions(+), 1 deletion(-) create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd2.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu create mode 100644 
kernel/csrc/selective_scan/cus/selective_scan_bwd3.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd4.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_bwd_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd2.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd3.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd4.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu create mode 100644 kernel/csrc/selective_scan/cus/selective_scan_fwd_complex.cu create mode 100644 kernel/csrc/selective_scan/reverse_scan.cuh create mode 100644 kernel/csrc/selective_scan/selective_scan.cpp create mode 100644 kernel/csrc/selective_scan/selective_scan.h create mode 100644 kernel/csrc/selective_scan/selective_scan_bwd_kernel.cuh create mode 100644 kernel/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh create mode 100644 kernel/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh create mode 100644 kernel/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh create mode 100644 kernel/csrc/selective_scan/selective_scan_common.h create mode 100644 kernel/csrc/selective_scan/selective_scan_fwd_kernel.cuh create mode 100644 kernel/csrc/selective_scan/static_switch.h create mode 100644 kernel/csrc/selective_scan/uninitialized_copy.cuh create mode 100644 kernel/readme.md create mode 100644 kernel/setup.py create mode 100644 kernel/test_selective_scan_new2old.py create mode 100644 kernel/test_selective_scan_speed.py diff --git a/.gitignore b/.gitignore index 
e1cdc5a56..0ddc74416 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *__pycache__/ *.egg-info/ build/ -**.so +*.so +*.whl diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd.cu new file mode 100644 index 000000000..c7d5ecf1d --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd2.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd2.cu new file mode 100644 index 000000000..2af8f1e2c --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd2.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu new file mode 100644 index 000000000..51bc14cd3 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd2_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<2, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd3.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd3.cu new file mode 100644 index 000000000..fe9ebcae1 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd3.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu new file mode 100644 index 000000000..c58d3f974 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd3_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<3, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd4.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd4.cu new file mode 100644 index 000000000..36555d110 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd4.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu new file mode 100644 index 000000000..11417e17a --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd4_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<4, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_bwd_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_bwd_complex.cu new file mode 100644 index 000000000..29e6a90d0 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_bwd_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda<1, float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); +template void selective_scan_bwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd.cu new file mode 100644 index 000000000..1b19a9110 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd2.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd2.cu new file mode 100644 index 000000000..1b24ae355 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd2.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu new file mode 100644 index 000000000..e84a2c588 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd2_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<2, at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<2, float, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd3.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd3.cu new file mode 100644 index 000000000..cce00b4e2 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd3.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu new file mode 100644 index 000000000..a8dc76640 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd3_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<3, at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<3, float, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd4.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd4.cu new file mode 100644 index 000000000..74383e3a7 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd4.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu new file mode 100644 index 000000000..4dd204a49 --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd4_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<4, at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<4, float, complex_t>(SSMParamsBase &params, cudaStream_t stream); \ No newline at end of file diff --git a/kernel/csrc/selective_scan/cus/selective_scan_fwd_complex.cu b/kernel/csrc/selective_scan/cus/selective_scan_fwd_complex.cu new file mode 100644 index 000000000..20f1a86cb --- /dev/null +++ b/kernel/csrc/selective_scan/cus/selective_scan_fwd_complex.cu @@ -0,0 +1,11 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +// Split into multiple files to compile in parallel + +#include "../selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda<1, at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); +template void selective_scan_fwd_cuda<1, float, complex_t>(SSMParamsBase &params, cudaStream_t stream); diff --git a/kernel/csrc/selective_scan/reverse_scan.cuh b/kernel/csrc/selective_scan/reverse_scan.cuh new file mode 100644 index 000000000..0baeebb05 --- /dev/null +++ b/kernel/csrc/selective_scan/reverse_scan.cuh @@ -0,0 +1,401 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cub/config.cuh> + +#include <cub/util_ptx.cuh> +#include <cub/util_type.cuh> +#include <cub/block/block_raking_layout.cuh> +// #include <cub/detail/uninitialized_copy.cuh> +#include "uninitialized_copy.cuh" + +/** + * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { + static_assert(LENGTH > 0); + T retval = input[LENGTH - 1]; + #pragma unroll + for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } + return retval; +} + +/** + * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix.
+ */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ void ThreadReverseScanInclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + T inclusive = postfix; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } +} + +/** + * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanExclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + // Careful, output may be aliased to input + T exclusive = postfix; + T inclusive; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + return inclusive; +} + + +/** + * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
+ * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS ///< Number of threads per logical warp + > +struct WarpReverseScan { + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + /// Whether the logical warp size and the PTX warp size coincide + static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); + /// The number of warp scan steps + static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE; + static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + explicit __device__ __forceinline__ + WarpReverseScan() + : lane_id(cub::LaneId()) + , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) + , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id)) + { + if (!IS_ARCH_WARP) { + lane_id = lane_id % LOGICAL_WARP_THREADS; + } + } + + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask); + } + + + /// Inclusive scan + template <typename ScanOpT> + __device__ __forceinline__ void InclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item.
May be aliased with \p input. + ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) { + int offset = 1 << STEP; + T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>( + inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask + ); + // Perform scan op if from a valid peer + inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset + ? inclusive_output : scan_op(temp, inclusive_output); + } + } + + /// Exclusive scan + // Get exclusive from inclusive + template <typename ScanOpT> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T inclusive_output; + InclusiveReverseScan(input, inclusive_output, scan_op); + warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask); + // initial value unknown + exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + + /** + * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. + */ + template <typename ScanOpT> + __device__ __forceinline__ void ReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
+ ScanOpT scan_op) ///< [in] Binary scan operator + { + InclusiveReverseScan(input, inclusive_output, scan_op); + // initial value unknown + exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + +}; + +/** + * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. + */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + > +struct BlockReverseScan { + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X; + + /// Layout type for padded thread block raking grid + using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>; + // The number of reduction elements is not a multiple of the number of raking threads for now + static_assert(BlockRakingLayout::UNGUARDED); + + /// Number of raking threads + static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; + /// Number of raking elements per warp synchronous raking thread + static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; + /// Cooperative work can be entirely warp synchronous + static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); + + /// WarpReverseScan utility type + using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>; + + /// Shared memory storage layout type + struct _TempStorage { + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + 
//--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep(ScanOp scan_op) { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data into registers + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; + #pragma unroll + for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { + raking_partial = scan_op(raking_partial, cached_segment[i]); + } + return raking_partial; + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data back into registers + if (!MEMOIZE) { + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + } + ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); + // Write data back to smem + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockReverseScan( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + 
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) + {} + + + /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpReverseScan warp_scan; + warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); + // Obtain warp-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + exclusive_output = linear_tid == BLOCK_THREADS - 1 ? 
block_postfix : scan_op(block_postfix, exclusive_output); + } else { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + detail::uninitialized_copy(placement_ptr, input); + cub::CTA_SYNC(); + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) { + WarpReverseScan warp_scan; + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + // Obtain block-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + // Update postfix with warpscan exclusive partial + T downsweep_postfix = linear_tid == RAKING_THREADS - 1 + ? block_postfix : scan_op(block_postfix, exclusive_partial); + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_postfix); + } + cub::CTA_SYNC(); + // Grab thread postfix from shared memory + exclusive_output = *placement_ptr; + + // // Compute warp scan in each warp. + // // The exclusive output from the last lane in each warp is invalid. + // T inclusive_output; + // WarpReverseScan warp_scan; + // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); + + // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 
+ // T block_aggregate; + // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); + + // // Apply warp postfix to our lane's partial + // if (warp_id != 0) { + // exclusive_output = scan_op(warp_postfix, exclusive_output); + // if (lane_id == 0) { exclusive_output = warp_postfix; } + // } + + // // Use the first warp to determine the thread block postfix, returning the result in lane0 + // if (warp_id == 0) { + // T block_postfix = block_postfix_callback_op(block_aggregate); + // if (lane_id == 0) { + // // Share the postfix with all threads + // detail::uninitialized_copy(&temp_storage.block_postfix, + // block_postfix); + + // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 + // } + // } + + // cub::CTA_SYNC(); + + // // Incorporate thread block postfix into outputs + // T block_postfix = temp_storage.block_postfix; + // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } + } + } + + + /** + * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. 
+ */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void InclusiveReverseScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. + { + // Reduce consecutive thread items in registers + T thread_postfix = ThreadReverseReduce(input, scan_op); + // Exclusive thread block-scan + ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); + // Inclusive scan in registers with postfix as seed + ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); + } + +}; \ No newline at end of file diff --git a/kernel/csrc/selective_scan/selective_scan.cpp b/kernel/csrc/selective_scan/selective_scan.cpp new file mode 100644 index 000000000..bf06ba60d --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan.cpp @@ -0,0 +1,533 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" +#define MAX_DSTATE 256 + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::ComplexFloat) { \ + using weight_t = c10::complex; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define INT_SWITCH_FWD(INT, NAME, ...) [&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + +#define INT_SWITCH_BWD(INT, NAME, ...) 
[&] { \ + if (INT == 2) {constexpr int NAME = 2; __VA_ARGS__(); } \ + else if (INT == 3) {constexpr int NAME = 3; __VA_ARGS__(); } \ + else if (INT == 4) {constexpr int NAME = 4; __VA_ARGS__(); } \ + else {constexpr int NAME = 1; __VA_ARGS__(); } \ +}() \ + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + const at::Tensor z, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + // All stride are in elements, not bytes. 
+ params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor z, + const at::Tensor out, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + const at::Tensor dz, + void* dD_ptr, + void* ddelta_bias_ptr, + bool has_z, + bool delta_softplus, + bool recompute_out_z) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, has_z ? out : dout, + has_z ? 
z : dout, + // If not recompute_out_z, pass dout instead of out_z. + // This won't be used by the bwd kernel + recompute_out_z ? out_z : dout, + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + if (!recompute_out_z) { params.out_z_ptr = nullptr; } + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + params.dz_ptr = has_z ? dz.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + if (!is_variable_B) { + params.dB_d_stride = dB.stride(0); + } else { + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + } + params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); + if (!is_variable_C) { + params.dC_d_stride = dC.stride(0); + } else { + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + } + params.dC_dstate_stride = !is_variable_C ? 
dC.stride(1) : dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + if (has_z) { + params.dz_batch_stride = dz.stride(0); + params.dz_d_stride = dz.stride(1); + } +} + +template +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + // TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + using weight_t = std::conditional_t>; + TORCH_CHECK(weight_type == (is_complex ? at::ScalarType::ComplexFloat : at::ScalarType::Float)); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + // const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? 
B.size(1) : 1; + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + 
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + has_z, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + // DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { + INT_SWITCH_FWD(nrows, kNRows, [&] { + selective_scan_fwd_cuda(params, stream); + }); + // }); + }); + std::vector result = {out, x}; + if (has_z) { result.push_back(out_z); } + return result; +} + +template +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + const c10::optional &out_, + c10::optional &dz_, + bool delta_softplus, + bool recompute_out_z, + int nrows + ) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + // TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + using weight_t = std::conditional_t>; + TORCH_CHECK(weight_type == (is_complex ? 
at::ScalarType::ComplexFloat : at::ScalarType::Float)); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + // const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dim % (n_groups * nrows) == 0, "dims should be dividable by n_groups * nrows"); + TORCH_CHECK(dstate <= MAX_DSTATE / nrows, "selective_scan only supports state dimension <= 256 / nrows"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out, dz, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + + TORCH_CHECK(out_.has_value()); + out = out_.value(); + TORCH_CHECK(out.scalar_type() == input_type); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); + CHECK_SHAPE(out, batch_size, dim, seqlen); + + if (dz_.has_value()) { + dz = dz_.value(); + TORCH_CHECK(dz.scalar_type() == input_type); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); + CHECK_SHAPE(dz, batch_size, dim, seqlen); + } else { + dz = torch::empty_like(z); + } + if (recompute_out_z) { + out_z = torch::empty_like(out); + } + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = 
!is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, z, out, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, dz, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + has_z, delta_softplus, recompute_out_z); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + // DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { + INT_SWITCH_BWD(nrows, kNRows, [&] { + selective_scan_bwd_cuda(params, stream); + }); + // }); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + if (has_z) { result.push_back(dz); } + if (recompute_out_z) { result.push_back(out_z); } + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); + // m.def("fwdc", &selective_scan_fwd, "Selective scan forward for complex"); + // m.def("bwdc", &selective_scan_bwd, "Selective scan backward for complex"); +} diff 
--git a/kernel/csrc/selective_scan/selective_scan.h b/kernel/csrc/selective_scan/selective_scan.h new file mode 100644 index 000000000..86eaa220b --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan.h @@ -0,0 +1,101 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. + void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. 
+ void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dA_dstate_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dB_dstate_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t dC_dstate_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t dz_batch_stride; + index_t dz_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. + void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ dz_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/kernel/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.cuh new file mode 100644 index 000000000..ef2af5ab3 --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -0,0 +1,586 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr 
int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows 
* params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNRows][kNItems]; + input_t delta_vals_load[kNRows][kNItems]; + input_t dout_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + load_input(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[r][i] = float(dout_vals_load[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * kNRows * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out + r * params.out_d_stride, out_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + } + float dz_vals[kNRows][kNItems], z_silu_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[r][i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[r][i] = z_val * z_sigmoid_val; + dz_vals[r][i] = dout_vals[r][i] * float(out_vals[r][i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[r][i] *= z_silu_vals[r][i]; + } + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(dz + r * params.dz_d_stride, dz_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[r][i] = float(out_vals[r][i]) * z_silu_vals[r][i]; } + } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll 
+ for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_z_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + } + + float du_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr ? 0 : D_val[r]) * dout_vals[r][i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); } + } + + float ddelta_vals[kNRows][kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + weight_t A_scaled[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled[r] = A_val[r] * kLog2e; + } else { + A_scaled[r] = complex_t(A_val[r].real_ * kLog2e, A_val[r].imag_);; + } + } + weight_t B_val[kNRows], C_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + B_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + C_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? (state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + kNRows * 2 * Ktraits::MaxDState)] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[r][i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a; + dA_val += dx * delta_vals[r][i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[r][i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[r][i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[r][i] * (!kIsVariableB ? thread_data[i].y * B_val[r] : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + // !!!!! + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_scaled[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : B_vals[i] * delta_vals[r][i] * float(u_vals[r][i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? 
(state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : threadIdx.x + kNRows * 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[r][i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[r][i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[r][i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i])); + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + (dx * conj(A_val[r]) * a_conj).real_; + dA_val += delta_vals[r][i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[r][i]) * conj(!kIsVariableB ? 
x * B_val[r] : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + delta -= kChunkSize; + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[r][i] = delta_val <= 20.f + ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp) + : ddelta_vals[r][i]; + } + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; } + } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * kNRows * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * kNRows * params.ddelta_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + + if (params.dD_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); } + } + } + if (params.ddelta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); } + } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx + r * Ktraits::MaxDState]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride + r * params.dB_d_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride + r * params.C_d_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride + r * params.dC_d_stride]), + !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride + r * params.B_d_stride]) : dBC_val); + } + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/kernel/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh new file mode 100644 index 000000000..a1de9d741 --- /dev/null +++ 
b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.nrows.cuh @@ -0,0 +1,586 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr 
int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows 
* params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNRows][kNItems]; + input_t delta_vals_load[kNRows][kNItems]; + input_t dout_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + load_input(dout + r * params.dout_d_stride, dout_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u -= kChunkSize; + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + dout -= kChunkSize; + + float dout_vals[kNRows][kNItems], delta_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[r][i] = float(dout_vals_load[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * kNRows * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out + r * params.out_d_stride, out_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + } + float dz_vals[kNRows][kNItems], z_silu_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[r][i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[r][i] = z_val * z_sigmoid_val; + dz_vals[r][i] = dout_vals[r][i] * float(out_vals[r][i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[r][i] *= z_silu_vals[r][i]; + } + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(dz + r * params.dz_d_stride, dz_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[r][i] = float(out_vals[r][i]) * z_silu_vals[r][i]; } + } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll 
+ for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_z_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + } + + float du_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[r][i] = (D_val == nullptr ? 0 : D_val[r]) * dout_vals[r][i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[r][i] * float(u_vals[r][i]); } + } + + float ddelta_vals[kNRows][kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + weight_t A_scaled[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled[r] = A_val[r] * kLog2e; + } else { + A_scaled[r] = complex_t(A_val[r].real_ * kLog2e, A_val[r].imag_);; + } + } + weight_t B_val[kNRows], C_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + B_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + C_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); + } + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[r][i] * A_scaled[r]); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? (state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + 2 * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState)] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[r][i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + kNRows * 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i]); + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + dx * A_val[r] * a; + dA_val += dx * delta_vals[r][i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[r][i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[r][i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[r][i] * (!kIsVariableB ? thread_data[i].y * B_val[r] : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + // !!!!! + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_scaled[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : B_vals[i] * delta_vals[r][i] * float(u_vals[r][i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? 
(state_idx + (chunk % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState) : (threadIdx.x + 2 * Ktraits::MaxDState) + r * 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[r][i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val[r] * C_val[r] : C_val[r]) + : (!kIsVariableB ? B_val[r] * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState + r * 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1 + r * params.n_chunks) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx + r * Ktraits::MaxDState] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[r][i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[r][i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[r][i] * float(u_vals[r][i]) : delta_vals[r][i] * float(u_vals[r][i]) * B_vals[i])); + du_vals[r][i] += ddelta_u * delta_vals[r][i]; + ddelta_vals[r][i] += ddelta_u * float(u_vals[r][i]) + (dx * conj(A_val[r]) * a_conj).real_; + dA_val += delta_vals[r][i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[r][i] * float(u_vals[r][i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[r][i]) * conj(!kIsVariableB ? 
x * B_val[r] : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx + r * Ktraits::MaxDState]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx + r * Ktraits::MaxDState] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx + r * Ktraits::MaxDState]; + } + } + } + } + + if constexpr (kDeltaSoftplus) { + input_t delta_vals_load[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + delta -= kChunkSize; + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[r][i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[r][i] = delta_val <= 20.f + ? ddelta_vals[r][i] / (1.f + delta_val_neg_exp) + : ddelta_vals[r][i]; + } + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[r][i]; } + } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * kNRows * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * kNRows * params.ddelta_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + store_output(du + r * params.du_d_stride, du_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta + r * params.ddelta_d_stride, ddelta_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + + if (params.dD_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(dD[r]), dD_val[r]); } + } + } + if (params.ddelta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(&(ddelta_bias[r]), ddelta_bias_val[r]); } + } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride + r * params.dA_d_stride]), smem_da[state_idx + r * Ktraits::MaxDState]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx + r * Ktraits::MaxDState]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride + r * params.dB_d_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride + r * params.C_d_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride + r * params.dC_d_stride]), + !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride + r * params.B_d_stride]) : dBC_val); + } + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/kernel/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh new file mode 100644 index 000000000..a06077bf8 --- /dev/null +++ 
b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.ori.cuh @@ -0,0 +1,533 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + // we are about to add kNRows here + static constexpr int MaxDState = MAX_DSTATE / 1; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using 
input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = 
reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); + // } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 
0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? 
B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? 
x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + 
kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, 1, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, 1, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/kernel/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh new file mode 100644 index 000000000..2ef58c4fe --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan_bwd_kernel.stage1.cuh @@ -0,0 +1,526 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr 
int kNRows = Ktraits::kNRows; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + kNRows * 2 * Ktraits::MaxDState + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + kNRows * Ktraits::MaxDState); + weight_t *smem_dbc = reinterpret_cast(smem_da + kNRows * Ktraits::MaxDState); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * kNRows 
* params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * kNRows * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * kNRows * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * kNRows * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id * kNRows; + float *D_val = params.D_ptr == nullptr ? nullptr : reinterpret_cast(params.D_ptr) + dim_id * kNRows; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id * kNRows; + float *delta_bias = params.delta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.delta_bias_ptr) + dim_id * kNRows; + scan_t *x = params.x_ptr == nullptr + ? 
nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * (params.n_chunks) * params.dstate; + float dD_val[kNRows] = {0}; + float ddelta_bias_val[kNRows] = {0}; + int r = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = (D_val == nullptr ? 
0 : D_val[r]) * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val[r] += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? 
B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * Ktraits::MaxDState : threadIdx.x + 2 * Ktraits::MaxDState] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * Ktraits::MaxDState]) + : smem_delta_a[threadIdx.x + 1 + 2 * Ktraits::MaxDState]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? 
smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? 
x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? 
dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + (delta_bias == nullptr ? 0 : delta_bias[r]); + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val[r] += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val[r]); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val[r] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val[r]); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val[r]); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t) + (kNThreads + kNRows * 4 * Ktraits::MaxDState) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + 
C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, 1, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, 1, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, 1, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/kernel/csrc/selective_scan/selective_scan_common.h b/kernel/csrc/selective_scan/selective_scan_common.h new file mode 100644 index 000000000..3c12af500 --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan_common.h @@ -0,0 +1,221 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include // For scalar_value_type + +#define MAX_DSTATE 256 + +using complex_t = c10::complex; + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + 
static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp +// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 +__device__ __forceinline__ complex_t cexp2f(complex_t z) { + float t = exp2f(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +__device__ __forceinline__ complex_t cexpf(complex_t z) { + float t = expf(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { + complex_t a0 = complex_t(ab0.x, ab0.y); + complex_t b0 = complex_t(ab0.z, ab0.w); + complex_t a1 = complex_t(ab1.x, ab1.y); + complex_t b1 = complex_t(ab1.z, ab1.w); + complex_t out_a = a1 * a0; + complex_t out_b = a1 * b0 + b1; + return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + if constexpr (!Ktraits::kIsComplex) { + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); + } else { + typename Ktraits::input_t B_vals_load[kNItems * 2]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + 
reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } + } +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/kernel/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/kernel/csrc/selective_scan/selective_scan_fwd_kernel.cuh new file mode 100644 index 000000000..d7126e9d4 --- /dev/null +++ b/kernel/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -0,0 +1,343 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 
5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int MaxDState = MAX_DSTATE / kNRows; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr 
bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * Ktraits::MaxDState * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + Ktraits::MaxDState); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id * kNRows / (params.dim_ngroups_ratio); // Mzero: fixbug here for nrow + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = 
reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = 
delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_val[r] *= kLog2e; + } else { + A_val[r].real_ *= kLog2e; + } + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if constexpr (!kIsComplex) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } else { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); + } + } + } + } + // Initialize running total + scan_t running_prefix; + if constexpr (!kIsComplex) { + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? 
smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + } else { + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float4(1.f, 0.f, 0.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + } + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; // Mzero: fixbug here for nrow + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); + if constexpr (!kIsComplex) { + out_vals[r][i] += thread_data[i].y * C_val; + } else { + out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; + } + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += kChunkSize * (!kIsComplex ? 
1 : 2); + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); + } +} diff --git a/kernel/csrc/selective_scan/static_switch.h b/kernel/csrc/selective_scan/static_switch.h new file mode 100644 index 000000000..1d52adf8d --- /dev/null +++ b/kernel/csrc/selective_scan/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + 
+#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/kernel/csrc/selective_scan/uninitialized_copy.cuh b/kernel/csrc/selective_scan/uninitialized_copy.cuh new file mode 100644 index 000000000..77863ff8d --- /dev/null +++ b/kernel/csrc/selective_scan/uninitialized_copy.cuh @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include + + +namespace detail +{ + +#if defined(_NVHPC_CUDA) +template +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + // NVBug 3384810 + new (ptr) T(::cuda::std::forward(val)); +} +#else +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = ::cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(::cuda::std::forward(val)); +} +#endif + +} // namespace detail diff --git a/kernel/readme.md b/kernel/readme.md new file mode 100644 index 000000000..a3af55056 --- /dev/null +++ b/kernel/readme.md @@ -0,0 +1 @@ +this is `selective_scan` in `mamba_ssm` \ No newline at end of file diff --git a/kernel/setup.py b/kernel/setup.py new file mode 100644 index 000000000..c9e95e92f --- /dev/null +++ b/kernel/setup.py @@ -0,0 +1,238 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. 
+import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform +import shutil + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + + +# with open("README.md", "r", encoding="utf-8") as fh: +# long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "mamba_ssm" + +BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. 
+ """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +cmdclass = {} +ext_modules = [] + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none(PACKAGE_NAME) + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." 
+ ) + + cc_flag.append("-gencode") + cc_flag.append("arch=compute_70,code=sm_70") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + ext_modules.append( + CUDAExtension( + name="selective_scan_cuda_test", + sources=[ + "csrc/selective_scan/selective_scan.cpp", + "csrc/selective_scan/cus/selective_scan_fwd.cu", + "csrc/selective_scan/cus/selective_scan_fwd2.cu", + "csrc/selective_scan/cus/selective_scan_fwd3.cu", + "csrc/selective_scan/cus/selective_scan_fwd4.cu", + "csrc/selective_scan/cus/selective_scan_bwd.cu", + "csrc/selective_scan/cus/selective_scan_bwd2.cu", + "csrc/selective_scan/cus/selective_scan_bwd3.cu", + "csrc/selective_scan/cus/selective_scan_bwd4.cu", + # "csrc/selective_scan/cus/selective_scan_fwd_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd2_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd3_complex.cu", + # "csrc/selective_scan/cus/selective_scan_fwd4_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd2_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd3_complex.cu", + # "csrc/selective_scan/cus/selective_scan_bwd4_complex.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + 
"--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + + ["--threads", "4"] + }, + include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("MAMBA_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +def get_wheel_url(): + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + torch_version_raw = parse(torch.__version__) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # to save CI time. Minor versions should be compatible. 
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + mamba_ssm_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename + ) + return wheel_url, wheel_filename + +setup( + name="selective_scan_test", + version="0.0.1", + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + "mamba_ssm.egg-info", + ) + ), + author="Tri Dao, Albert Gu", + author_email="tri@tridao.me, agu@cs.cmu.edu", + description="Mamba state-space model", + long_description=None, + long_description_content_type="text/markdown", + url="https://github.com/state-spaces/mamba", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"bdist_wheel": _bdist_wheel, "build_ext": BuildExtension}, + python_requires=">=3.7", + install_requires=[ + "torch", + "packaging", + "ninja", + # "einops", + ], +) diff --git a/kernel/test_selective_scan_new2old.py b/kernel/test_selective_scan_new2old.py new file mode 100644 index 000000000..2141e0ba0 --- /dev/null +++ b/kernel/test_selective_scan_new2old.py @@ -0,0 +1,377 @@ +# Modified by Mzero #20240123 +# Copyright (C) 2023, Tri Dao, 
Albert Gu. + +import math +import torch +import torch.nn.functional as F +import pytest +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm"): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + else: + raise NotImplementedError + + ctx.delta_softplus = 
delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
+ if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +# MODE = "mamba_ssm" +# MODE = "sscore" +# MODE = "sstest" +MODE = "mamba_ssm_sscore" # 1344 items pass +MODE = "mamba_ssm_sstest" # 1344 items pass + +if MODE in ["mamba_ssm"]: + import selective_scan_cuda as selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sscore"]: + import selective_scan_cuda_core + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["sstest"]: + import selective_scan_cuda_test + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode=MODE) + selective_scan_ref = selective_scan_ref +elif MODE in ["mamba_ssm_sscore"]: + import selective_scan_cuda_core + 
import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscore") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +elif MODE in ["mamba_ssm_sstest"]: + import selective_scan_cuda_test + import selective_scan_cuda + selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode="sstest") + selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") +else: + raise NotImplementedError + +print("use MODE:", MODE) +import time; time.sleep(10) + + +# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [False, True]) +@pytest.mark.parametrize('delta_softplus', [False, True]) +# @pytest.mark.parametrize('has_z', [False, True]) +@pytest.mark.parametrize('has_z', [False]) +@pytest.mark.parametrize('has_D', [False, True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +# @pytest.mark.parametrize("is_variable_C", [False, True]) +@pytest.mark.parametrize("is_variable_C", [True]) +# @pytest.mark.parametrize("is_variable_B", [False, True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("nrows", [1, 2, 3, 4]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, nrows): + print(f'method: {selective_scan_cuda}') + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol 
= 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 24 + dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = 
delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state, nrows=nrows + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + 
assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + + diff --git a/kernel/test_selective_scan_speed.py b/kernel/test_selective_scan_speed.py new file mode 100644 index 000000000..6110b29dc --- /dev/null +++ b/kernel/test_selective_scan_speed.py @@ -0,0 +1,334 @@ +# Modified by Mzero #20240123 +# Copyright (C) 2023, Tri Dao, Albert Gu. + +import math +import torch +import torch.nn.functional as F +import pytest +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from einops import rearrange, repeat +import time +from functools import partial + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and 
(delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + else: + raise NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. 
+ if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" + print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) + + return selective_scan_fn + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +def test_speed(): + wtype = torch.float32 + itype = torch.float32 + is_variable_B = True + is_variable_C = True + has_D = True + has_z = False # sscore not support z + has_delta_bias = True + varBC_groups = 2 + seqlen = 4096 + seqlen = 128 + seqlen = 64 + batch_size = 128 + dim = 24 + dim = 96 + dim = 384 + dim = 768 + dstate = 8 + # dstate = 24 + delta_softplus = True + is_complex = wtype == torch.complex64 + device = 'cuda' + TIMES = 1000 + import selective_scan_cuda_core + import selective_scan_cuda_test + import selective_scan_cuda + # copied from test_selective_scan ====================== + torch.random.manual_seed(0) + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + 
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + # ================================ + starts = [] + ends = [] + tests = [ + partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", 
tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=3), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=4), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + # 
partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + + ] + + for test in tests: + s = time.time() + for _ in range(TIMES): + with torch.no_grad(): + test() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = time.time() + starts.append(s) + ends.append(e) + print("fwd", test.func, e - s, flush=True) + for test in tests: + s = time.time() + for _ in range(TIMES): + outs = test() + outs[0].sum().backward() + torch.cuda.synchronize() + torch.cuda.empty_cache() + e = time.time() + starts.append(s) + ends.append(e) + print("fwdbwd", test.func, e - s, flush=True) + +test_speed() \ No newline at end of file From d07701675afa30f6bb3d8ff11fd971b2e8d3ec75 Mon Sep 17 00:00:00 2001 From: MzeroMiko <3496274007@qq.com> 
Date: Sat, 17 Feb 2024 21:45:18 +0800 Subject: [PATCH 8/9] update --- kernel/test_selective_scan_benchmark.log | 76 ++++ kernel/test_selective_scan_benchmark.py | 440 +++++++++++++++++++++++ kernel/test_selective_scan_speed.log | 40 +++ kernel/test_selective_scan_speed.py | 24 +- tests/ops/test_selective_scan_new2old.py | 377 ------------------- tests/ops/test_selective_scan_speed.py | 334 ----------------- 6 files changed, 568 insertions(+), 723 deletions(-) create mode 100644 kernel/test_selective_scan_benchmark.log create mode 100644 kernel/test_selective_scan_benchmark.py create mode 100644 kernel/test_selective_scan_speed.log delete mode 100644 tests/ops/test_selective_scan_new2old.py delete mode 100644 tests/ops/test_selective_scan_speed.py diff --git a/kernel/test_selective_scan_benchmark.log b/kernel/test_selective_scan_benchmark.log new file mode 100644 index 000000000..414d8dffb --- /dev/null +++ b/kernel/test_selective_scan_benchmark.log @@ -0,0 +1,76 @@ +name mamba-16-0 mamba-16-1 mamba-16-2 mamba-16-4 mamba-2-0 mamba-2-1 mamba-2-2 mamba-2-4 +dtype padding batch_size nheads seq_len +torch.bfloat16 False 16 12 50 0.114 0.113 0.098 0.102 0.044 0.043 0.042 0.048 + 100 0.120 0.119 0.103 0.111 0.055 0.054 0.052 0.057 + 1000 0.685 0.685 1.407 1.448 0.294 0.292 0.424 0.604 + 1600 1.353 1.364 2.720 2.834 0.463 0.466 0.767 1.029 + 3200 2.746 2.746 5.430 5.623 0.894 0.893 1.506 2.033 + 6400 5.483 5.483 10.871 11.151 1.737 1.736 2.993 4.036 + 16 50 0.149 0.149 0.126 0.140 0.052 0.052 0.050 0.062 + 100 0.155 0.156 0.132 0.151 0.064 0.065 0.064 0.072 + 1000 0.919 0.914 1.881 1.933 0.381 0.383 0.558 0.802 + 1600 1.826 1.834 3.653 3.762 0.598 0.596 1.016 1.359 + 3200 3.673 3.673 7.285 7.549 1.163 1.162 2.013 2.695 + 6400 7.333 7.348 14.495 14.877 2.301 2.301 3.997 5.370 + 32 50 0.287 0.287 0.236 0.247 0.084 0.084 0.081 0.101 + 100 0.300 0.301 0.248 0.264 0.109 0.108 0.105 0.120 + 1000 1.846 1.881 3.781 3.886 0.736 0.736 1.070 1.551 + 1600 3.702 3.704 7.316 7.542 1.168 
1.170 1.999 2.674 + 3200 7.445 7.492 14.806 14.833 2.299 2.297 4.025 5.339 + 6400 14.733 14.736 29.709 29.961 4.567 4.560 8.067 10.670 + True 16 12 50 0.114 0.114 0.098 0.103 0.045 0.044 0.042 0.048 + 100 0.120 0.119 0.103 0.111 0.055 0.054 0.050 0.057 + 1000 0.686 0.684 1.409 1.447 0.290 0.292 0.424 0.604 + 1600 1.357 1.358 2.713 2.843 0.460 0.465 0.767 1.029 + 3200 2.732 2.743 5.417 5.602 0.884 0.895 1.508 2.030 + 6400 5.481 5.469 10.847 11.212 1.739 1.735 2.983 4.038 + 16 50 0.149 0.148 0.125 0.140 0.051 0.051 0.051 0.062 + 100 0.155 0.155 0.132 0.151 0.065 0.065 0.064 0.072 + 1000 0.915 0.914 1.881 1.929 0.381 0.384 0.556 0.801 + 1600 1.827 1.824 3.645 3.772 0.602 0.598 1.015 1.362 + 3200 3.687 3.697 7.253 7.501 1.164 1.162 2.006 2.695 + 6400 7.338 7.317 14.495 14.875 2.300 2.298 3.985 5.366 + 32 50 0.288 0.287 0.237 0.247 0.084 0.084 0.081 0.100 + 100 0.301 0.300 0.249 0.264 0.107 0.107 0.104 0.120 + 1000 1.851 1.870 3.786 3.880 0.734 0.736 1.066 1.551 + 1600 3.708 3.694 7.345 7.566 1.170 1.165 2.028 2.673 + 3200 7.479 7.504 14.754 14.846 2.295 2.295 4.007 5.339 + 6400 14.716 14.728 29.698 29.781 4.561 4.566 8.094 10.670 +name mamba-16-0-0 mamba-16-1-0 mamba-16-1-1 mamba-16-1-2 mamba-16-1-4 mamba-16-2-0 mamba-16-2-1 mamba-16-2-2 mamba-16-2-4 mamba-16-4-0 mamba-16-4-1 mamba-16-4-2 mamba-16-4-4 mamba-2-0-0 mamba-2-1-0 mamba-2-1-1 mamba-2-1-2 mamba-2-1-4 mamba-2-2-0 mamba-2-2-1 mamba-2-2-2 mamba-2-2-4 mamba-2-4-0 mamba-2-4-1 mamba-2-4-2 mamba-2-4-4 +dtype padding batch_size nheads seq_len +torch.bfloat16 False 16 12 50 0.415 0.421 0.422 0.421 0.421 0.372 0.373 0.372 0.373 0.408 0.408 0.408 0.408 0.176 0.181 0.182 0.182 0.183 0.206 0.192 0.191 0.192 0.212 0.197 0.198 0.196 + 100 0.479 0.482 0.482 0.481 0.482 0.447 0.447 0.446 0.448 0.477 0.477 0.476 0.477 0.196 0.199 0.200 0.199 0.200 0.198 0.199 0.201 0.198 0.200 0.202 0.204 0.203 + 1000 3.315 3.335 3.333 3.337 3.333 4.802 4.799 4.801 4.799 7.084 7.080 7.086 7.078 1.292 1.317 1.309 1.317 1.310 1.561 1.559 1.560 
1.558 2.821 2.805 2.823 2.806 + 1600 8.076 7.999 7.996 8.002 7.997 14.112 14.159 14.110 14.112 19.099 20.105 19.105 20.113 2.462 2.474 2.476 2.474 2.475 4.065 4.063 4.064 4.061 7.474 7.504 7.467 7.509 + 3200 16.183 16.030 16.027 16.029 16.062 28.133 28.137 28.133 28.139 40.612 40.616 40.547 40.579 4.888 4.896 4.893 4.895 4.895 8.002 8.002 8.001 7.999 15.010 15.007 15.006 15.018 + 6400 32.399 32.091 32.094 32.090 32.088 56.230 56.209 56.214 56.229 86.655 86.646 86.586 86.632 9.704 9.701 9.702 9.703 9.702 15.866 15.868 15.864 15.866 30.362 30.378 30.362 30.364 + 16 50 0.536 0.545 0.545 0.544 0.545 0.470 0.471 0.470 0.471 0.523 0.522 0.523 0.522 0.195 0.200 0.199 0.198 0.201 0.198 0.197 0.198 0.196 0.208 0.212 0.209 0.259 + 100 0.621 0.624 0.625 0.624 0.625 0.569 0.568 0.569 0.567 0.608 0.610 0.608 0.608 0.276 0.283 0.282 0.282 0.282 0.280 0.281 0.283 0.281 0.283 0.283 0.295 0.276 + 1000 4.418 4.438 4.453 4.436 4.446 7.449 7.452 7.450 7.453 10.456 10.456 10.456 10.459 1.704 1.738 1.750 1.736 1.754 2.274 2.275 2.273 2.275 4.419 4.461 4.456 4.465 + 1600 10.763 10.659 10.659 10.659 10.659 18.767 18.766 18.769 18.766 26.192 26.193 26.186 26.206 3.271 3.284 3.283 3.285 3.284 5.377 5.378 5.378 5.377 9.932 9.926 9.930 9.933 + 3200 21.582 21.364 21.364 21.364 21.367 37.479 37.482 37.476 37.474 54.407 54.370 54.480 54.369 6.482 6.487 6.490 6.488 6.488 10.608 10.609 10.613 10.639 19.947 19.936 19.930 19.948 + 6400 43.175 42.754 42.766 42.748 42.782 74.893 74.880 74.867 74.869 115.408 115.261 115.425 115.330 12.980 12.917 12.914 12.913 12.914 21.105 21.103 21.100 21.099 40.462 40.453 40.451 40.466 + 32 50 1.013 1.029 1.028 1.028 1.028 0.865 0.864 0.865 0.864 0.952 0.953 0.951 0.953 0.314 0.324 0.324 0.325 0.336 0.302 0.323 0.301 0.302 0.345 0.345 0.345 0.345 + 100 1.236 1.251 1.252 1.251 1.252 1.102 1.102 1.102 1.103 1.150 1.150 1.150 1.151 0.444 0.455 0.454 0.456 0.455 0.423 0.422 0.424 0.421 0.472 0.471 0.473 0.471 + 1000 8.846 8.908 8.899 8.905 8.922 14.882 14.885 14.885 
14.883 20.556 20.560 20.557 20.562 3.386 3.456 3.461 3.458 3.471 4.444 4.458 4.443 4.458 8.837 8.833 8.835 8.830 + 1600 21.503 21.286 21.287 21.286 21.287 37.399 37.403 37.395 37.393 50.549 50.578 50.577 50.600 6.499 6.527 6.527 6.528 6.526 10.608 10.610 10.608 10.610 19.651 19.673 19.675 19.650 + 3200 43.065 42.635 42.636 42.632 42.633 74.712 74.724 74.714 74.703 107.471 107.462 107.335 107.306 13.019 12.939 12.940 12.933 12.936 21.058 21.054 21.063 21.055 39.677 39.668 39.667 39.693 + 6400 86.494 85.411 85.368 85.399 85.411 149.455 149.429 149.427 149.444 230.277 230.410 230.479 230.439 25.945 25.782 25.783 25.789 25.789 42.018 42.008 42.009 42.019 80.824 80.785 80.774 80.813 + True 16 12 50 0.416 0.422 0.421 0.422 0.421 0.370 0.368 0.370 0.369 0.404 0.402 0.404 0.402 0.196 0.200 0.197 0.198 0.196 0.198 0.197 0.196 0.196 0.196 0.197 0.195 0.198 + 100 0.479 0.482 0.481 0.482 0.480 0.451 0.442 0.451 0.443 0.478 0.475 0.479 0.477 0.214 0.197 0.198 0.202 0.201 0.201 0.202 0.200 0.202 0.204 0.205 0.203 0.203 + 1000 3.316 3.331 3.327 3.332 3.331 4.814 4.812 4.813 4.814 7.082 7.101 7.088 7.101 1.282 1.319 1.309 1.316 1.310 1.554 1.547 1.554 1.547 2.819 2.810 2.818 2.809 + 1600 8.079 8.001 7.994 8.000 7.994 14.112 14.111 14.112 14.111 18.866 19.829 18.841 19.829 2.428 2.444 2.446 2.444 2.446 3.384 3.391 3.384 3.391 7.460 7.510 7.465 7.522 + 3200 16.183 16.022 16.022 16.021 16.021 28.137 28.136 28.136 28.143 40.612 40.628 40.625 40.598 4.885 4.894 4.897 4.894 4.898 7.999 8.004 8.000 8.008 15.056 15.059 15.061 15.051 + 6400 32.388 32.084 32.089 32.092 32.083 56.235 56.237 56.227 56.208 86.597 86.568 86.594 86.591 9.716 9.714 9.710 9.714 9.712 15.880 15.876 15.876 15.874 30.248 30.218 30.251 30.220 + 16 50 0.537 0.545 0.545 0.545 0.545 0.473 0.475 0.474 0.474 0.533 0.535 0.534 0.532 0.191 0.195 0.198 0.197 0.221 0.202 0.195 0.200 0.202 0.203 0.204 0.204 0.201 + 100 0.622 0.626 0.629 0.626 0.629 0.576 0.564 0.576 0.563 0.619 0.609 0.617 0.610 0.268 0.274 0.276 0.276 0.276 
0.276 0.274 0.274 0.281 0.274 0.277 0.278 0.298 + 1000 4.436 4.452 4.464 4.452 4.462 7.462 7.464 7.462 7.463 10.465 10.473 10.470 10.472 1.704 1.737 1.747 1.735 1.748 2.302 2.305 2.303 2.306 4.443 4.487 4.452 4.460 + 1600 10.761 10.657 10.656 10.658 10.657 18.772 18.772 18.772 18.773 26.155 26.182 26.178 26.190 3.288 3.291 3.294 3.291 3.294 5.396 5.415 5.390 5.413 9.931 9.955 9.932 9.921 + 3200 21.572 21.358 21.356 21.361 21.364 37.471 37.477 37.463 37.474 54.483 54.453 54.416 54.450 6.488 6.490 6.494 6.490 6.495 10.611 10.621 10.609 10.625 19.945 19.940 19.941 19.942 + 6400 43.167 42.743 42.747 42.746 42.741 74.866 74.879 74.877 74.885 115.341 115.347 115.415 115.327 12.971 12.926 12.920 12.926 12.925 21.108 21.107 21.104 21.103 40.306 40.257 40.309 40.269 + 32 50 1.012 1.028 1.027 1.028 1.028 0.859 0.860 0.860 0.860 0.946 0.948 0.946 0.950 0.321 0.326 0.324 0.326 0.324 0.316 0.305 0.315 0.305 0.345 0.344 0.345 0.344 + 100 1.230 1.246 1.244 1.243 1.243 1.070 1.077 1.070 1.077 1.136 1.139 1.136 1.140 0.454 0.461 0.480 0.463 0.479 0.444 0.473 0.453 0.473 0.477 0.494 0.477 0.493 + 1000 8.882 8.923 8.933 8.927 8.928 14.839 14.834 14.840 14.834 20.510 20.505 20.512 20.503 3.389 3.459 3.460 3.459 3.462 4.462 4.510 4.463 4.510 8.802 8.817 8.778 8.821 + 1600 21.462 21.253 21.247 21.260 21.250 37.379 37.379 37.379 37.377 50.481 50.494 50.449 50.483 6.513 6.534 6.537 6.535 6.537 10.615 10.637 10.616 10.634 19.660 19.645 19.642 19.642 + 3200 43.062 42.631 42.644 42.652 42.653 74.733 74.746 74.721 74.722 107.493 107.382 107.457 107.486 12.975 12.934 12.936 12.933 12.936 21.070 21.079 21.063 21.076 39.648 39.646 39.640 39.631 + 6400 86.455 85.370 85.351 85.364 85.369 149.424 149.458 149.431 149.432 230.291 230.444 230.411 230.662 25.851 25.791 25.790 25.798 25.790 42.027 42.017 42.023 42.009 80.586 80.449 80.594 80.521 diff --git a/kernel/test_selective_scan_benchmark.py b/kernel/test_selective_scan_benchmark.py new file mode 100644 index 000000000..4c063be30 --- /dev/null +++ 
b/kernel/test_selective_scan_benchmark.py @@ -0,0 +1,440 @@ +import itertools +from math import sqrt + +import pandas +import torch +from tqdm import tqdm +import triton + +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func +except: + pass + + + + +def benchmark_mamba(batch, head, length, dim_head, d_state, selective_scan_cuda, *args): + from einops import rearrange, repeat + + d_model = dim_head * head + expand = 2 + d_inner = d_model * expand + device = "cuda" + + # S4D real initialization + A = repeat( + torch.arange(1, d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + + x = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + z = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + delta = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + delta_bias = torch.randn(d_inner).to("cuda").requires_grad_(True) + A = -torch.exp(A_log.float()) # (d_inner, d_state) + B = ( + torch.randn(batch, 1, d_state, length) + .to("cuda") + .to(torch.bfloat16) + .requires_grad_(True) + ) + C = ( + torch.randn(batch, 1, d_state, length) + .to("cuda") + .to(torch.bfloat16) + .requires_grad_(True) + ) + D = torch.ones(d_inner, device=device) # Keep in fp32 + delta_softplus = True + + ms = triton.testing.do_bench( + lambda: selective_scan_cuda.fwd( + x, delta, A, B, C, D, z, delta_bias, delta_softplus, *args + ), + warmup=100, + ) + return ms + + +def get_inputs(B, H, L, E=64, ret_padding_mask=False, dtype=torch.float32): + q = torch.rand((B, H, L, E), device="cuda", dtype=dtype) + k = torch.rand((B, H, L, E), device="cuda", dtype=dtype) + v = torch.rand((B, H, L, E), device="cuda", dtype=dtype) + + input_lengths = torch.randint(1, L, (B,), 
device=q.device).long() + input_lengths[-1] = L + padding_mask = torch.zeros((B, L), dtype=q.dtype, device=q.device) + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + input_lengths - 1, + ) + ] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + if not ret_padding_mask: + padding_mask = None + return (q, k, v), padding_mask + + +def flash_attn_forward(queries, keys, values, padding_mask=None): + qkv = torch.stack([queries, keys, values], dim=2) + qkv = qkv.permute(0, 3, 2, 1, 4) + B, T, _, H, D = qkv.shape + scale = 1.0 / sqrt(D) + + if padding_mask is not None: + # unpad_input expectes True to correspond to valid indices and False to invalid + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, ~padding_mask) + packed_res = flash_attn_varlen_qkvpacked_func( + qkv, + cu_q_lens, + max_s, + dropout_p=0.0, + softmax_scale=scale, + causal=False, + alibi_slopes=None, + deterministic=False, + ) + res = pad_input(packed_res, indices, B, T) + res = res.transpose(1, 2) + else: + res = flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=scale, + causal=False, + alibi_slopes=None, + deterministic=False, + ) + res = res.transpose(1, 2) # B x T x H x D -> B x H x T x D + return res + + +def benchmark_flash(q, k, v, padding_mask): + dim_E = q.shape[-1] + H = q.shape[1] + E = dim_E * H + ms = triton.testing.do_bench( + lambda: flash_attn_forward(q, k, v, padding_mask=padding_mask), warmup=100 + ) + return ms + + +def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): + MODE = mode + + class SelectiveScanFn(torch.autograd.Function): + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = 
B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + if D is not None and (D.dtype != torch.float): + ctx._d_dtype = D.dtype + D = D.float() + if delta_bias is not None and (delta_bias.dtype != torch.float): + ctx._delta_bias_dtype = delta_bias.dtype + delta_bias = delta_bias.float() + + assert u.shape[1] % (B.shape[1] * nrows) == 0 + assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile + + if backnrows > 0: + assert u.shape[1] % (B.shape[1] * backnrows) == 0 + assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile + else: + backnrows = nrows + ctx.backnrows = backnrows + + if MODE in ["mamba_ssm"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + + elif MODE in ["sscore"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) + elif MODE in ["sstest"]: + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) + else: + raise NotImplementedError + + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + if MODE in ["mamba_ssm", "sstest"]: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + elif MODE in ["sscore"]: + return out if not return_last_state else (out, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, 
z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + if MODE in ["mamba_ssm"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + elif MODE in ["sstest"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False, ctx.backnrows # option to recompute out_z, not used here + ) + elif MODE in ["sscore"]: + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows + ) + else: + raise NotImplementedError + + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + + _dD = None + if D is not None: + if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): + _dD = dD.to(ctx._d_dtype) + else: + _dD = dD + + _ddelta_bias = None + if delta_bias is not None: + if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): + _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) + else: + _ddelta_bias = ddelta_bias + + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, None, None, None) + + def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). 
Note that the gradient of the last state is + not considered in the backward pass. + """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) + + selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" + # print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) + + return selective_scan_fn + + +def benchmark_mamba_fwdbwd(batch, head, length, dim_head, d_state, selective_scan_fn, *args): + from einops import rearrange, repeat + + d_model = dim_head * head + expand = 2 + d_inner = d_model * expand + device = "cuda" + + # S4D real initialization + A = repeat( + torch.arange(1, d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + + x = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + z = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + delta = torch.rand( + (batch, d_inner, length), device=device, dtype=torch.bfloat16 + ).requires_grad_(True) + delta_bias = torch.randn(d_inner).to("cuda").requires_grad_(True) + A = -torch.exp(A_log.float()) # (d_inner, d_state) + B = ( + torch.randn(batch, 1, d_state, length) + .to("cuda") + .to(torch.bfloat16) + .requires_grad_(True) + ) + C = ( + torch.randn(batch, 1, d_state, length) + .to("cuda") + .to(torch.bfloat16) + .requires_grad_(True) + ) + D = torch.ones(d_inner, device=device) # Keep in fp32 + delta_softplus = True + + ms = triton.testing.do_bench( + lambda: selective_scan_fn( + x, delta, A, B, C, D, z, delta_bias, delta_softplus, False, *args + )[0].sum().backward(), + warmup=100, + ) + return ms + + +def test_bench(with_flash=False, with_mamba_fwd=False, with_mamba_fwdbwd=False): + batch_sizes = [16] + heads = [12, 16, 32] + time_steps = [50, 100, 1000, 1600, 3200, 6400] + get_padding_masks = [True, False] + # d_states = [2, 4, 8, 
16] + d_states = [2, 16] # to save space, otherwise, too many columns would display + dtypes = [torch.bfloat16] + E = 64 + fwdnrows = [0, 1, 2, 4] # 64 // 3 != 0 + bwdnrows = [0, 1, 2, 4] # 64 // 3 != 0 + + results = [] + + if with_flash: + for B, H, L, pm, dtype in tqdm( + itertools.product(batch_sizes, heads, time_steps, get_padding_masks, dtypes) + ): + (q, k, v), padding_mask = get_inputs( + B, H, L, E=64, ret_padding_mask=pm, dtype=dtype + ) + ms = benchmark_flash(q, k, v, padding_mask) + results.append( + { + "name": "flash", + "batch_size": B, + "nheads": H, + "seq_len": L, + "dim": H * E, + "padding": pm, + "dtype": dtype, + "ms": ms, + } + ) + + if with_mamba_fwd: + for B, H, L, pm, d_state, dtype, fwdnrow in tqdm( + itertools.product( + batch_sizes, heads, time_steps, get_padding_masks, d_states, dtypes, fwdnrows + ) + ): + (q, k, v), padding_mask = get_inputs( + B, H, L, E=64, ret_padding_mask=pm, dtype=dtype + ) + + if fwdnrow == 0: + import selective_scan_cuda + ms = benchmark_mamba(B, H, L, E, d_state, selective_scan_cuda) + else: + import selective_scan_cuda_test + ms = benchmark_mamba(B, H, L, E, d_state, selective_scan_cuda_test, fwdnrow) + results.append( + { + "name": f"mamba-{d_state}-{fwdnrow}", + "batch_size": B, + "nheads": H, + "seq_len": L, + "dim": H * E, + "padding": pm, + "dtype": dtype, + "ms": ms, + } + ) + + if with_mamba_fwdbwd: + for B, H, L, pm, d_state, dtype, fwdnrow, bwdnrow in tqdm( + itertools.product( + batch_sizes, heads, time_steps, get_padding_masks, d_states, dtypes, fwdnrows, bwdnrows + ) + ): + (q, k, v), padding_mask = get_inputs( + B, H, L, E=64, ret_padding_mask=pm, dtype=dtype + ) + + if fwdnrow == 0: + if bwdnrow == 0: + import selective_scan_cuda + ms = benchmark_mamba_fwdbwd(B, H, L, E, d_state, build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm")) + else: + continue + else: + import selective_scan_cuda_test + ms = benchmark_mamba_fwdbwd(B, H, L, E, d_state, 
build_selective_scan_fn(selective_scan_cuda_test, mode="sstest"), fwdnrow) + results.append( + { + "name": f"mamba-{d_state}-{fwdnrow}-{bwdnrow}", + "batch_size": B, + "nheads": H, + "seq_len": L, + "dim": H * E, + "padding": pm, + "dtype": dtype, + "ms": ms, + } + ) + + df = pandas.DataFrame(results) + piv = df.pivot( + columns="name", + values="ms", + index=["dtype", "padding", "batch_size", "nheads", "seq_len"], + ) + pandas.set_option('display.width', 1000) + pandas.set_option('display.max_rows', None) + pandas.set_option('display.max_columns', None) + print(piv.sort_index().round(3)) + + +if __name__ == "__main__": + test_bench(with_mamba_fwd=True) + test_bench(with_mamba_fwdbwd=True) + + +""" +Thank you very much, @apoorv2904. +You are right, and I nearly failed to reproduce the results I have observed before. +These days, I kept working on it. (the environment I use is 4090 24G, with py310+cu121+torch2.2) +1. I added `nrow` feature in backward to better compare with different `nrow` settings. +2. I compared my code (`selective_scan_test` here, or `selective_scan_core` in VMamba) with `mamba_ssm` rather than `selective_scan_ref`, and keeps no difference (tested all pass with [test file](https://github.com/MzeroMiko/mamba/blob/main/kernel/test_selective_scan_new2old.py)). +3. I realised that the [issue]`https://github.com/alxndrTL/mamba.py/issues/8` proves nothing here, since raising `d_state` only inference the flops in SSM (nearly equals selective scan) while raising `d_model` or `seqlen` inferences the whole mamba model. As SSM is fast compared to `the whole model + data loading`, the speed difference is small and hard to observe (which is one possibility to that issue). +4. I used my newly written [`simple benchmark`](https://github.com/MzeroMiko/mamba/blob/main/kernel/test_selective_scan_speed.py), and found the results are consistent with yours. 
It seems that raising nrows would only make the code slower, until I finally realised that ***the test which showed that raising nrows raises the speed was done on 7x7 images, which means seqlen is 49! extremely small!***. Then I added `seqlen=64` in testing, and found in some `fwdnrow+bwdnrow` patterns, the speed is fast, see [log](https://github.com/MzeroMiko/mamba/blob/main/kernel/test_selective_scan_speed.log) for details. Though I still do not know how the bwd code influences the fwd procedure. +5. I modified your [`benchmark`](https://github.com/MzeroMiko/mamba/blob/main/kernel/test_selective_scan_benchmark.py), and the results are consistent with `test_selective_scan_speed`, see [log](https://github.com/MzeroMiko/mamba/blob/main/kernel/test_selective_scan_benchmark.log) for details. +To conclude, with short `seqlen`, bigger `nrow` may lead to faster speed, but the reason remains unknown. +""" \ No newline at end of file diff --git a/kernel/test_selective_scan_speed.log b/kernel/test_selective_scan_speed.log new file mode 100644 index 000000000..efd9cfc14 --- /dev/null +++ b/kernel/test_selective_scan_speed.log @@ -0,0 +1,40 @@ +fwd selective_scan_fn | mamba_ssm | ori 0.2595360279083252 +fwd selective_scan_fn | sstest | f1b1 0.25940918922424316 +fwd selective_scan_fn | sstest | f2b1 0.232133150100708 +fwd selective_scan_fn | sstest | f3b1 0.2424156665802002 +fwd selective_scan_fn | sstest | f4b1 0.27019643783569336 +fwd selective_scan_fn | sstest | f1b2 0.22879815101623535 +fwd selective_scan_fn | sstest | f1b3 0.22960782051086426 +fwd selective_scan_fn | sstest | f1b4 0.2311267852783203 +fwd selective_scan_fn | sstest | f2b2 0.19332456588745117 +fwd selective_scan_fn | sstest | f3b3 0.18274688720703125 +fwd selective_scan_fn | sstest | f4b4 0.1872847080230713 +fwd selective_scan_fn | sscore | f1b1 0.24498963356018066 +fwd selective_scan_fn | sscore | f2b1 0.1880788803100586 +fwd selective_scan_fn | sscore | 
f3b1 0.1852104663848877 +fwd selective_scan_fn | sscore | f4b1 0.19452261924743652 +fwd selective_scan_fn | sscore | f1b2 0.2343282699584961 +fwd selective_scan_fn | sscore | f2b2 0.18627405166625977 +fwd selective_scan_fn | sscore | f2b3 0.17929315567016602 +fwd selective_scan_fn | sscore | f4b4 0.19004416465759277 +fwd selective_scan_fn | mamba_ssm | ori 0.23628973960876465 +fwdbwd selective_scan_fn | mamba_ssm | ori 1.0447983741760254 +fwdbwd selective_scan_fn | sstest | f1b1 1.0994946956634521 +fwdbwd selective_scan_fn | sstest | f2b1 0.9461770057678223 +fwdbwd selective_scan_fn | sstest | f3b1 0.9414637088775635 +fwdbwd selective_scan_fn | sstest | f4b1 0.9588637351989746 +fwdbwd selective_scan_fn | sstest | f1b2 1.0666122436523438 +fwdbwd selective_scan_fn | sstest | f1b3 1.1968886852264404 +fwdbwd selective_scan_fn | sstest | f1b4 1.3072748184204102 +fwdbwd selective_scan_fn | sstest | f2b2 0.8528542518615723 +fwdbwd selective_scan_fn | sstest | f3b3 0.9282448291778564 +fwdbwd selective_scan_fn | sstest | f4b4 0.9394724369049072 +fwdbwd selective_scan_fn | sscore | f1b1 0.994654655456543 +fwdbwd selective_scan_fn | sscore | f2b1 0.9511115550994873 +fwdbwd selective_scan_fn | sscore | f3b1 0.9489989280700684 +fwdbwd selective_scan_fn | sscore | f4b1 0.9634220600128174 +fwdbwd selective_scan_fn | sscore | f1b2 0.9859137535095215 +fwdbwd selective_scan_fn | sscore | f2b2 0.923966646194458 +fwdbwd selective_scan_fn | sscore | f2b3 0.8890247344970703 +fwdbwd selective_scan_fn | sscore | f4b4 0.9850261211395264 +fwdbwd selective_scan_fn | mamba_ssm | ori 0.9858667850494385 diff --git a/kernel/test_selective_scan_speed.py b/kernel/test_selective_scan_speed.py index 6110b29dc..2c6ce45f6 100644 --- a/kernel/test_selective_scan_speed.py +++ b/kernel/test_selective_scan_speed.py @@ -141,7 +141,7 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, 
return_last_state, nrows, backnrows) selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" - print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) + # print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) return selective_scan_fn @@ -297,15 +297,15 @@ def test_speed(): partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, 
nrows=2, backnrows=2), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), - + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), + partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), + partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, 
delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), ] for test in tests: @@ -318,7 +318,7 @@ def test_speed(): e = time.time() starts.append(s) ends.append(e) - print("fwd", test.func, e - s, flush=True) + print("fwd", test.func.__repr__(), e - s, flush=True) for test in tests: s = time.time() for _ in range(TIMES): @@ -329,6 +329,6 @@ def test_speed(): e = time.time() starts.append(s) ends.append(e) - print("fwdbwd", test.func, e - s, flush=True) + print("fwdbwd", test.func.__repr__(), e - s, flush=True) test_speed() \ No newline at end of file diff --git a/tests/ops/test_selective_scan_new2old.py b/tests/ops/test_selective_scan_new2old.py deleted file mode 100644 index 2141e0ba0..000000000 --- a/tests/ops/test_selective_scan_new2old.py +++ /dev/null @@ -1,377 +0,0 @@ -# Modified by Mzero #20240123 -# Copyright (C) 2023, Tri Dao, Albert Gu. - -import math -import torch -import torch.nn.functional as F -import pytest -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd -from einops import rearrange, repeat - - -def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm"): - MODE = mode - - class SelectiveScanFn(torch.autograd.Function): - @staticmethod - def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if z is not None and z.stride(-1) != 1: - z = z.contiguous() - if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") - ctx.squeeze_B = True - if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") - ctx.squeeze_C = True - if D is not None and (D.dtype != torch.float): - ctx._d_dtype = D.dtype - D = D.float() - if delta_bias is not None 
and (delta_bias.dtype != torch.float): - ctx._delta_bias_dtype = delta_bias.dtype - delta_bias = delta_bias.float() - - assert u.shape[1] % (B.shape[1] * nrows) == 0 - assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile - - if backnrows > 0: - assert u.shape[1] % (B.shape[1] * backnrows) == 0 - assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile - else: - backnrows = nrows - ctx.backnrows = backnrows - - if MODE in ["mamba_ssm"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) - - elif MODE in ["sscore"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - elif MODE in ["sstest"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) - else: - raise NotImplementedError - - ctx.delta_softplus = delta_softplus - ctx.has_z = z is not None - - last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out if not return_last_state else (out, last_state) - else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) - if MODE in ["mamba_ssm", "sstest"]: - out_z = rest[0] - return out_z if not return_last_state else (out_z, last_state) - elif MODE in ["sscore"]: - return out if not return_last_state else (out, last_state) - - @staticmethod - def backward(ctx, dout, *args): - if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - z = None - out = None - else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. 
- if MODE in ["mamba_ssm"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here - ) - elif MODE in ["sstest"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False, ctx.backnrows # option to recompute out_z, not used here - ) - elif MODE in ["sscore"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows - ) - else: - raise NotImplementedError - - dz = rest[0] if ctx.has_z else None - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - - _dD = None - if D is not None: - if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): - _dD = dD.to(ctx._d_dtype) - else: - _dD = dD - - _ddelta_bias = None - if delta_bias is not None: - if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): - _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) - else: - _ddelta_bias = ddelta_bias - - return (du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, None, None, None) - - def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): - """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state is - not considered in the backward pass. 
- """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) - - return selective_scan_fn - - -def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): - """ - u: r(B D L) - delta: r(B D L) - A: c(D N) or r(D N) - B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - out: r(B D L) - last_state (optional): r(B D dstate) or c(B D dstate) - """ - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) - else: - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) - ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) - else: - if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) - else: - if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real * 2 - ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.to(dtype=dtype_in) - return out if not return_last_state else (out, last_state) - - -# MODE = "mamba_ssm" -# MODE = "sscore" -# MODE = "sstest" -MODE = "mamba_ssm_sscore" # 1344 items pass -MODE = "mamba_ssm_sstest" # 1344 items pass - -if MODE in ["mamba_ssm"]: - import selective_scan_cuda as selective_scan_cuda - selective_scan_fn = build_selective_scan_fn(selective_scan_cuda, mode=MODE) - selective_scan_ref = selective_scan_ref -elif MODE in ["sscore"]: - import selective_scan_cuda_core - selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode=MODE) - selective_scan_ref = selective_scan_ref -elif MODE in ["sstest"]: - import selective_scan_cuda_test - selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode=MODE) - selective_scan_ref = selective_scan_ref -elif MODE in ["mamba_ssm_sscore"]: - import selective_scan_cuda_core - 
import selective_scan_cuda - selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscore") - selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") -elif MODE in ["mamba_ssm_sstest"]: - import selective_scan_cuda_test - import selective_scan_cuda - selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode="sstest") - selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm") -else: - raise NotImplementedError - -print("use MODE:", MODE) -import time; time.sleep(10) - - -# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('itype', [torch.float32]) -@pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize("return_last_state", [True]) -@pytest.mark.parametrize('has_delta_bias', [False, True]) -@pytest.mark.parametrize('delta_softplus', [False, True]) -# @pytest.mark.parametrize('has_z', [False, True]) -@pytest.mark.parametrize('has_z', [False]) -@pytest.mark.parametrize('has_D', [False, True]) -@pytest.mark.parametrize("varBC_groups", [1, 2]) -# @pytest.mark.parametrize("is_variable_C", [False, True]) -@pytest.mark.parametrize("is_variable_C", [True]) -# @pytest.mark.parametrize("is_variable_B", [False, True]) -@pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("nrows", [1, 2, 3, 4]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype, nrows): - print(f'method: {selective_scan_cuda}') - if varBC_groups > 1 and (not is_variable_B or not is_variable_C): - pytest.skip() # This config is not applicable - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol 
= 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - if has_z: # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - dim = 24 - dstate = 8 - is_complex = wtype == torch.complex64 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - if not is_variable_B: - B_shape = (dim, dstate) - elif varBC_groups == 1: - B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, - requires_grad=True) - if not is_variable_C: - C_shape = (dim, dstate) - elif varBC_groups == 1: - C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, - requires_grad=True) - if has_D: - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - D = None - if has_z: - z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - else: - z = None - if has_delta_bias: - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - else: - delta_bias = None - u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() - C_ref = C.detach().clone().requires_grad_() - D_ref = D.detach().clone().requires_grad_() if D is not None else None - z_ref = z.detach().clone().requires_grad_() if z is not None else None - u_ref = u.detach().clone().requires_grad_() - delta_ref = 
delta.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out, *rest = selective_scan_fn( - u, delta, A, B, C, D, z=z, - delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state, nrows=nrows - ) - if return_last_state: - state = rest[0] - out_ref, *rest = selective_scan_ref( - u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, - delta_bias=delta_bias_ref, delta_softplus=delta_softplus, - return_last_state=return_last_state - ) - if return_last_state: - state_ref = rest[0] - # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - if return_last_state: - print(f'State max diff: {(state - state_ref).abs().max().item()}') - assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - out_ref.backward(g) - out.backward(g) - - print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') - print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') - print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') - print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') - print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') - if has_D: - print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') - if has_z: - print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') - if has_delta_bias: - print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') - - assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) - assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) - assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - 
assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, - atol=atolw if not is_variable_B else atol) - assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, - atol=atolw if not is_variable_C else atol) - if has_D: - assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - if has_z: - assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) - if has_delta_bias: - assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - - diff --git a/tests/ops/test_selective_scan_speed.py b/tests/ops/test_selective_scan_speed.py deleted file mode 100644 index 6110b29dc..000000000 --- a/tests/ops/test_selective_scan_speed.py +++ /dev/null @@ -1,334 +0,0 @@ -# Modified by Mzero #20240123 -# Copyright (C) 2023, Tri Dao, Albert Gu. - -import math -import torch -import torch.nn.functional as F -import pytest -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd -from einops import rearrange, repeat -import time -from functools import partial - - -def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None): - MODE = mode - - class SelectiveScanFn(torch.autograd.Function): - @staticmethod - def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if z is not None and z.stride(-1) != 1: - z = z.contiguous() - if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") - ctx.squeeze_B = True - if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") - ctx.squeeze_C = True - if D is not None and (D.dtype != torch.float): - ctx._d_dtype = D.dtype - D = D.float() - if delta_bias is not 
None and (delta_bias.dtype != torch.float): - ctx._delta_bias_dtype = delta_bias.dtype - delta_bias = delta_bias.float() - - assert u.shape[1] % (B.shape[1] * nrows) == 0 - assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile - - if backnrows > 0: - assert u.shape[1] % (B.shape[1] * backnrows) == 0 - assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile - else: - backnrows = nrows - ctx.backnrows = backnrows - - if MODE in ["mamba_ssm"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) - - elif MODE in ["sscore"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows) - elif MODE in ["sstest"]: - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows) - else: - raise NotImplementedError - - ctx.delta_softplus = delta_softplus - ctx.has_z = z is not None - - last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out if not return_last_state else (out, last_state) - else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) - if MODE in ["mamba_ssm", "sstest"]: - out_z = rest[0] - return out_z if not return_last_state else (out_z, last_state) - elif MODE in ["sscore"]: - return out if not return_last_state else (out, last_state) - - @staticmethod - def backward(ctx, dout, *args): - if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - z = None - out = None - else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. 
- if MODE in ["mamba_ssm"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here - ) - elif MODE in ["sstest"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False, ctx.backnrows # option to recompute out_z, not used here - ) - elif MODE in ["sscore"]: - du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows - ) - else: - raise NotImplementedError - - dz = rest[0] if ctx.has_z else None - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - - _dD = None - if D is not None: - if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype): - _dD = dD.to(ctx._d_dtype) - else: - _dD = dD - - _ddelta_bias = None - if delta_bias is not None: - if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype): - _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype) - else: - _ddelta_bias = ddelta_bias - - return (du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, None, None, None) - - def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1): - """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state is - not considered in the backward pass. 
- """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows) - - selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}" - print(repr(selective_scan_fn), "==", selective_scan_fn.__repr__()) - - return selective_scan_fn - - -def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): - """ - u: r(B D L) - delta: r(B D L) - A: c(D N) or r(D N) - B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - out: r(B D L) - last_state (optional): r(B D dstate) or c(B D dstate) - """ - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) - else: - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) - ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) - else: - if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) - else: - if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real * 2 - ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.to(dtype=dtype_in) - return out if not return_last_state else (out, last_state) - - -def test_speed(): - wtype = torch.float32 - itype = torch.float32 - is_variable_B = True - is_variable_C = True - has_D = True - has_z = False # sscore not support z - has_delta_bias = True - varBC_groups = 2 - seqlen = 4096 - seqlen = 128 - seqlen = 64 - batch_size = 128 - dim = 24 - dim = 96 - dim = 384 - dim = 768 - dstate = 8 - # dstate = 24 - delta_softplus = True - is_complex = wtype == torch.complex64 - device = 'cuda' - TIMES = 1000 - import selective_scan_cuda_core - import selective_scan_cuda_test - import selective_scan_cuda - # copied from test_selective_scan ====================== - torch.random.manual_seed(0) - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - if not is_variable_B: - B_shape = (dim, dstate) - elif varBC_groups == 1: - 
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, - requires_grad=True) - if not is_variable_C: - C_shape = (dim, dstate) - elif varBC_groups == 1: - C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, - requires_grad=True) - if has_D: - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - D = None - if has_z: - z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - else: - z = None - if has_delta_bias: - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - else: - delta_bias = None - u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() - C_ref = C.detach().clone().requires_grad_() - D_ref = D.detach().clone().requires_grad_() if D is not None else None - z_ref = z.detach().clone().requires_grad_() if z is not None else None - u_ref = u.detach().clone().requires_grad_() - delta_ref = delta.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - # ================================ - starts = [] - ends = [] - tests = [ - partial(build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm", tag="ori"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", 
tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=3), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f1b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=4), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f3b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), - partial(build_selective_scan_fn(selective_scan_cuda_test, mode="sstest", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=1), - # 
partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f3b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b1"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=1), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f1b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=1, backnrows=2), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b2"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=2, backnrows=2), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f2b3"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=3, backnrows=3), - # partial(build_selective_scan_fn(selective_scan_cuda_core, mode="sscore", tag="f4b4"), u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state=True, nrows=4, backnrows=4), - - ] - - for test in tests: - s = time.time() - for _ in range(TIMES): - with torch.no_grad(): - test() - torch.cuda.synchronize() - torch.cuda.empty_cache() - e = time.time() - starts.append(s) - ends.append(e) - print("fwd", test.func, e - s, flush=True) - for test in tests: - s = time.time() - for _ in range(TIMES): - outs = test() - outs[0].sum().backward() - torch.cuda.synchronize() - torch.cuda.empty_cache() - e = time.time() - starts.append(s) - ends.append(e) - print("fwdbwd", test.func, e - s, flush=True) - -test_speed() \ No newline at end of file From 45ac3f43e5c762565bce66302a7e622c45eb7677 Mon Sep 17 00:00:00 2001 From: MzeroMiko <3496274007@qq.com> 
Date: Sat, 17 Feb 2024 21:55:06 +0800 Subject: [PATCH 9/9] update --- kernel/test_selective_scan_benchmark.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/kernel/test_selective_scan_benchmark.py b/kernel/test_selective_scan_benchmark.py index 4c063be30..1e86eb349 100644 --- a/kernel/test_selective_scan_benchmark.py +++ b/kernel/test_selective_scan_benchmark.py @@ -5,6 +5,7 @@ import torch from tqdm import tqdm import triton +from einops import rearrange, repeat try: from flash_attn.bert_padding import pad_input, unpad_input @@ -13,11 +14,7 @@ pass - - def benchmark_mamba(batch, head, length, dim_head, d_state, selective_scan_cuda, *args): - from einops import rearrange, repeat - d_model = dim_head * head expand = 2 d_inner = d_model * expand @@ -264,8 +261,6 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ def benchmark_mamba_fwdbwd(batch, head, length, dim_head, d_state, selective_scan_fn, *args): - from einops import rearrange, repeat - d_model = dim_head * head expand = 2 d_inner = d_model * expand