@@ -311,26 +311,21 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
     constexpr int kNRows = 1;
+    constexpr bool kIsVariableB = true;
+    constexpr bool kIsVariableC = true;
+    constexpr bool kHasZ = true;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
-            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
-                BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
-                    BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
-                        using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
-                        // constexpr int kSmemSize = Ktraits::kSmemSize;
-                        constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-                        // printf("smem_size = %d\n", kSmemSize);
-                        dim3 grid(params.batch, params.dim / kNRows);
-                        auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-                        if (kSmemSize >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        }
-                        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
-                });
-            });
+        BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
+            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
+            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+            dim3 grid(params.batch, params.dim / kNRows);
+            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+            if (kSmemSize >= 48 * 1024) {
+                C10_CUDA_CHECK(cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+            }
+            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
     });
 }
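
For context on the hunk above: BOOL_SWITCH turns a runtime boolean into a compile-time constant by instantiating the wrapped lambda once per branch. The sketch below assumes the usual static_switch.h definition that ships with these selective-scan kernels; the exact header in this repo may differ in minor details.

#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
    [&] {                                         \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()

Each BOOL_SWITCH therefore doubles the number of selective_scan_fwd_kernel template instantiations the compiler must build. Pinning kIsVariableB, kIsVariableC, and kHasZ to constexpr true removes three switches, cutting the instantiations per (kNThreads, kNItems, input_t, weight_t) combination from 2^5 to 2^2, in exchange for dropping the constant-B/C and no-z code paths.
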
@@ -369,27 +364,23 @@ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
 
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
-#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \
+#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)        \
     if (ITYPE == at::ScalarType::Half) {                                      \
         using input_t = at::Half;                                             \
+        using weight_t = at::Half;                                            \
         __VA_ARGS__();                                                        \
     } else if (ITYPE == at::ScalarType::BFloat16) {                           \
         using input_t = at::BFloat16;                                         \
+        using weight_t = at::BFloat16;                                        \
         __VA_ARGS__();                                                        \
     } else if (ITYPE == at::ScalarType::Float) {                              \
         using input_t = float;                                                \
+        using weight_t = float;                                               \
         __VA_ARGS__();                                                        \
     } else {                                                                  \
         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
     }
 
-#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...)                                \
-    if (WTYPE == at::ScalarType::Float) {                                     \
-        using weight_t = float;                                               \
-        __VA_ARGS__();                                                        \
-    } else {                                                                  \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
 
 template<typename input_t, typename weight_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
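
Hand-expanding the merged macro at its call site (see the next hunk) shows the effect: input_t and weight_t are now both derived from the single ITYPE argument. Roughly, after preprocessing and reformatting:

if (u.scalar_type() == at::ScalarType::Half) {
    using input_t = at::Half;
    using weight_t = at::Half;  // new: the weight type follows the input type
    [&] { selective_scan_fwd_cuda<input_t, weight_t>(params, stream); }();
} else if (u.scalar_type() == at::ScalarType::BFloat16) {
    using input_t = at::BFloat16;
    using weight_t = at::BFloat16;
    [&] { selective_scan_fwd_cuda<input_t, weight_t>(params, stream); }();
} else if (u.scalar_type() == at::ScalarType::Float) {
    using input_t = float;
    using weight_t = float;
    [&] { selective_scan_fwd_cuda<input_t, weight_t>(params, stream); }();
} else {
    AT_ERROR("selective_scan_fwd", " not implemented for input type '",
             toString(u.scalar_type()), "'");
}
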
@@ -598,10 +589,8 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)u.get_device()};
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
-        DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] {
+    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
         selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
-        });
     });
     std::vector<at::Tensor> result = {out, x.value()};
     if (has_z) { result.push_back(out_z); }
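
One behavioral consequence of the two call-site lines removed above: weight_t was previously dispatched from A.scalar_type() and restricted to fp32, while the merged macro derives it from u alone and no longer inspects A here. A hypothetical guard stating the resulting invariant (not part of this commit; the actual dtype checks presumably live earlier in selective_scan_fwd):

// Hypothetical check, not in this diff: weight_t now mirrors u's dtype,
// so A is expected to match u rather than always being float as the old
// DISPATCH_WTYPE_FLOAT path required.
TORCH_CHECK(A.scalar_type() == u.scalar_type(),
            "selective_scan_fwd: A must have the same dtype as u");
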