Commit abf02fa

Reduce combinations of bool switch to reduce wheel size
1 parent 2ca8db7 commit abf02fa

File tree

csrc/mamba/causal_conv1d/causal_conv1d.cu
csrc/mamba/mamba_ssm/selective_scan_fwd.cu

2 files changed: 31 additions & 43 deletions
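
For context: BOOL_SWITCH, used in both files below, is the usual static_switch.h-style dispatch helper. It turns a runtime boolean into a constexpr template parameter by compiling the guarded body once per value, so every nested switch doubles the number of kernel instantiations baked into the wheel. A minimal sketch of the idiom (written from the common formulation, not copied from this repo's header):

// Sketch of the BOOL_SWITCH idiom. COND is checked at runtime, but CONST_NAME
// is constexpr inside each branch, so the compiler emits the body, and any
// kernel template it instantiates, once for true and once for false.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()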

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 12 additions & 13 deletions

@@ -404,19 +404,18 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
 template<int kNThreads, int kWidth, typename input_t, typename weight_t>
 void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
     static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
-    BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] {
-        BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
-            using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize;
-            dim3 grid(params.batch, params.dim);
-            auto kernel = &causal_conv1d_fwd_kernel<Ktraits, kHasSeqPosIdx>;
-            if (kSmemSize >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            }
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-        });
+    constexpr bool kHasSeqPosIdx = false;
+    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
+        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
+        constexpr int kSmemSize = Ktraits::kSmemSize;
+        dim3 grid(params.batch, params.dim);
+        auto kernel = &causal_conv1d_fwd_kernel<Ktraits, kHasSeqPosIdx>;
+        if (kSmemSize >= 48 * 1024) {
+            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
 }
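
The effect on binary size is multiplicative: dropping the kHasSeqPosIdx switch halves the number of causal_conv1d_fwd_kernel variants compiled per (kNThreads, kWidth, input_t, weight_t) combination. A back-of-the-envelope tally (illustrative only, using the switch names from the launcher above):

// Illustrative count of fwd-kernel variants per template-parameter combination.
constexpr int kSwitchesBefore = 2;                     // kHasSeqPosIdx, kIsVecLoad
constexpr int kSwitchesAfter  = 1;                     // kIsVecLoad only
constexpr int kVariantsBefore = 1 << kSwitchesBefore;  // 4 kernels
constexpr int kVariantsAfter  = 1 << kSwitchesAfter;   // 2 kernels
static_assert(kVariantsBefore == 2 * kVariantsAfter,
              "one fewer bool switch halves the instantiation count");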

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 19 additions & 30 deletions

@@ -311,26 +311,21 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
     constexpr int kNRows = 1;
+    constexpr bool kIsVariableB = true;
+    constexpr bool kIsVariableC = true;
+    constexpr bool kHasZ = true;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
-            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
-                BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
-                    BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
-                        using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
-                        // constexpr int kSmemSize = Ktraits::kSmemSize;
-                        constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-                        // printf("smem_size = %d\n", kSmemSize);
-                        dim3 grid(params.batch, params.dim / kNRows);
-                        auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-                        if (kSmemSize >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        }
-                        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
-                });
-            });
+        BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
+            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
+            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+            dim3 grid(params.batch, params.dim / kNRows);
+            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+            if (kSmemSize >= 48 * 1024) {
+                C10_CUDA_CHECK(cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+            }
+            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
     });
 }
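
The same arithmetic gives the bigger win here: the launcher goes from five nested switches (kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex) to two, i.e. from 2^5 = 32 down to 2^2 = 4 selective_scan_fwd_kernel instantiations per (kNThreads, kNItems, input_t, weight_t) combination. The trade-off is that is_variable_B, is_variable_C, and z_ptr are now assumed true/non-null at compile time, so callers must uphold that. A hypothetical caller-side guard (not part of this commit) might look like:

// Hypothetical checks for the flags that are now hard-coded in the launcher.
TORCH_CHECK(params.is_variable_B && params.is_variable_C,
            "selective_scan_fwd is compiled with variable B and C only");
TORCH_CHECK(params.z_ptr != nullptr,
            "selective_scan_fwd is compiled with z required");

The second change in this file merges the two dtype-dispatch macros, shown in the next hunk.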
@@ -369,27 +364,23 @@ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
 
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
-#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
+#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
     if (ITYPE == at::ScalarType::Half) { \
         using input_t = at::Half; \
+        using weight_t = at::Half; \
         __VA_ARGS__(); \
     } else if (ITYPE == at::ScalarType::BFloat16) { \
         using input_t = at::BFloat16; \
+        using weight_t = at::BFloat16; \
         __VA_ARGS__(); \
     } else if (ITYPE == at::ScalarType::Float) { \
         using input_t = float; \
+        using weight_t = float; \
         __VA_ARGS__(); \
     } else { \
         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
     }
 
-#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \
-    if (WTYPE == at::ScalarType::Float) { \
-        using weight_t = float; \
-        __VA_ARGS__(); \
-    } else { \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
 
 template<typename input_t, typename weight_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
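
With the merge, weight_t is bound together with input_t from a single dtype (the old DISPATCH_WTYPE_FLOAT pinned weights to float via A's dtype in a second nested dispatch), so the weight dtype now always follows the input dtype. Expanded by hand for half inputs (illustrative, with the lambda invocation inlined for clarity), the new dispatch reads roughly:

// Hand expansion of the merged macro for u.scalar_type() == at::ScalarType::Half.
if (u.scalar_type() == at::ScalarType::Half) {
    using input_t  = at::Half;
    using weight_t = at::Half;  // weight dtype now follows the input dtype
    selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
}

The final hunk below swaps the merged macro into the call site.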
@@ -598,10 +589,8 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)u.get_device()};
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
-        DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] {
+    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
         selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
-        });
     });
     std::vector<at::Tensor> result = {out, x.value()};
     if (has_z) { result.push_back(out_z); }
