diff --git a/CMakeLists.txt b/CMakeLists.txt
index 015ea93c53..dbeb1902dc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,7 +23,9 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")
 
 # Supported NVIDIA architectures.
 set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.9;9.0")
-if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+  list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "11.0" "12.0")
+elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
   list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "10.1" "12.0")
 endif()
 
diff --git a/csrc/flash_attn/flash_api_sparse.cpp b/csrc/flash_attn/flash_api_sparse.cpp
index 2ff90749f0..62a92d8f78 100644
--- a/csrc/flash_attn/flash_api_sparse.cpp
+++ b/csrc/flash_attn/flash_api_sparse.cpp
@@ -16,6 +16,12 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
+namespace {
+inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {
+    return at::cuda::CUDAGuard(static_cast<char>(t.get_device()));
+}
+} // namespace
+
 namespace FLASH_NAMESPACE {
 
 //
@@ -231,7 +237,7 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    auto device_guard = make_cuda_guard_from_tensor(q);
 
     auto opts = q.options();
 
@@ -435,7 +441,7 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    auto device_guard = make_cuda_guard_from_tensor(q);
 
     auto opts = q.options();
     auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 6a4f0e6ee6..0cfebb0146 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -59,6 +59,12 @@ namespace pybind11::detail {
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
+namespace {
+inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {
+    return at::cuda::CUDAGuard(static_cast<char>(t.get_device()));
+}
+} // namespace
+
 void set_params_fprop(Flash_fwd_params &params,
                       // sizes
                       const size_t b,
@@ -629,7 +635,7 @@ mha_fwd_get_scheduler_metadata(
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
+    auto device_guard = make_cuda_guard_from_tensor(seqused_k);
 
     auto opts = seqused_k.options();
     // This needs to be set after get_num_splits
@@ -884,7 +890,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    auto device_guard = make_cuda_guard_from_tensor(q);
 
     at::Tensor softmax_lse;
     if (!is_varlen_q) {
@@ -1454,7 +1460,7 @@ std::vector<at::Tensor> mha_bwd(
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    auto device_guard = make_cuda_guard_from_tensor(q);
 
     auto opts = q.options();
     // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
diff --git a/hopper/sm90_pipeline_no_cluster.hpp b/hopper/sm90_pipeline_no_cluster.hpp
index 65a3d1554b..1fb805aec1 100644
--- a/hopper/sm90_pipeline_no_cluster.hpp
+++ b/hopper/sm90_pipeline_no_cluster.hpp
@@ -39,7 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base {
     if (is_initializing_warp) {
       // Barrier FULL and EMPTY init
      constexpr int producer_arv_cnt = 1;
-      uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup;
+      uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup;
       uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;
 
       cutlass::arch::detail::initialize_barrier_array_pair_aligned(
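
Note on the new `make_cuda_guard_from_tensor` helper: `at::cuda::CUDAGuard` deletes both its copy and move constructors, so returning one from a factory function compiles only because C++17 guarantees copy elision for prvalue returns; the guard is constructed directly in the caller's `auto device_guard = ...` object. The torch-free sketch below uses a hypothetical `FakeGuard`/`make_guard` pair to illustrate that language rule; it is not code from this patch.

```cpp
// Minimal sketch: a hypothetical FakeGuard stands in for at::cuda::CUDAGuard.
// A non-copyable, non-movable RAII guard can still be returned by value
// because C++17 mandates copy elision for prvalue returns.
#include <cstdio>

struct FakeGuard {
    explicit FakeGuard(signed char device) { std::printf("switch to device %d\n", device); }
    FakeGuard(const FakeGuard&) = delete;             // non-copyable, like CUDAGuard
    FakeGuard& operator=(const FakeGuard&) = delete;
    FakeGuard(FakeGuard&&) = delete;                  // non-movable, like CUDAGuard
    FakeGuard& operator=(FakeGuard&&) = delete;
    ~FakeGuard() { std::printf("restore previous device\n"); }
};

// Mirrors the shape of make_cuda_guard_from_tensor: the narrowing cast lives
// in one place, and the returned prvalue initializes the caller's object directly.
FakeGuard make_guard(long long raw_device_index) {
    return FakeGuard(static_cast<signed char>(raw_device_index));
}

int main() {
    auto device_guard = make_guard(1);  // ok: no copy or move is ever required
    return 0;
}
```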
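
On the `sm90_pipeline_no_cluster.hpp` hunk: the old expression truncates, so a consumer thread count that is not an exact multiple of `NumThreadsPerWarpGroup` under-counts the barrier arrivals, while the `(x + n - 1) / n` form rounds up. A small standalone check of the two formulas; the 128 matches the value of `cutlass::NumThreadsPerWarpGroup`, and the 288-thread case is an invented example of a partial warpgroup.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint32_t kNumThreadsPerWarpGroup = 128;  // same value as cutlass::NumThreadsPerWarpGroup

// Round-up integer division, as used in the patched line.
constexpr uint32_t ceil_div(uint32_t x, uint32_t n) { return (x + n - 1) / n; }

int main() {
    // Whole warpgroups: truncating and rounding-up division agree.
    assert(256 / kNumThreadsPerWarpGroup == 2);
    assert(ceil_div(256, kNumThreadsPerWarpGroup) == 2);

    // Partial warpgroup (invented 288-thread consumer count): the old truncating
    // form drops it from the barrier arrival count, the new form keeps it.
    assert(288 / kNumThreadsPerWarpGroup == 2);
    assert(ceil_div(288, kNumThreadsPerWarpGroup) == 3);
    return 0;
}
```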