CMakeLists.txt (4 changes: 3 additions & 1 deletion)
@@ -23,7 +23,9 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.9;9.0")
-if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+  list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "11.0" "12.0")
+elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
   list(APPEND CUDA_SUPPORTED_ARCHS "10.0" "10.1" "12.0")
endif()

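Aside from the editor (not part of the diff): the arch strings above map to sm_100, sm_101, sm_110 and sm_120 build targets, and the same CUDA-13-versus-12.8 gate that CMake applies here can also be expressed inside CUDA sources with nvcc's predefined version macros if it is ever needed there. A minimal illustrative sketch:

// Editor's sketch, not code from this PR: nvcc predefines
// __CUDACC_VER_MAJOR__ / __CUDACC_VER_MINOR__, mirroring the
// CMAKE_CUDA_COMPILER_VERSION checks above.
#if defined(__CUDACC__)
  #if __CUDACC_VER_MAJOR__ >= 13
    // CUDA 13+ toolchain: arch list "10.0;11.0;12.0" (sm_100, sm_110, sm_120).
  #elif __CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8
    // CUDA 12.8+ toolchain: arch list "10.0;10.1;12.0" (sm_100, sm_101, sm_120).
  #endif
#endif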
csrc/flash_attn/flash_api_sparse.cpp (10 changes: 8 additions & 2 deletions)
@@ -16,6 +16,12 @@
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

+namespace {
+inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {
+    return at::cuda::CUDAGuard(static_cast<c10::DeviceIndex>(t.get_device()));
+}
+} // namespace

namespace FLASH_NAMESPACE {

//
@@ -231,7 +237,7 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+auto device_guard = make_cuda_guard_from_tensor(q);

auto opts = q.options();

@@ -435,7 +441,7 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+auto device_guard = make_cuda_guard_from_tensor(q);

auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
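Editor's note on the new helper: the old code cast the result of get_device() to char to silence a narrowing warning when constructing the CUDAGuard; the helper replaces that with an explicit static_cast to c10::DeviceIndex, the type CUDAGuard actually takes. Since c10::cuda::CUDAGuard is neither copyable nor movable, returning it by value relies on C++17 guaranteed copy elision. A minimal standalone usage sketch, assuming a libtorch build with CUDA (main() and the tensor shapes are illustrative only):

// Editor's sketch mirroring the helper above; not code from this PR.
#include <ATen/cuda/CUDAGuard.h>
#include <torch/torch.h>

namespace {
// Returning the guard by value relies on C++17 guaranteed copy elision,
// since c10::cuda::CUDAGuard is neither copyable nor movable.
inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {
    return at::cuda::CUDAGuard(static_cast<c10::DeviceIndex>(t.get_device()));
}
} // namespace

int main() {
    if (!torch::cuda::is_available()) return 0;
    // Put q on the last visible GPU so the guard's effect is observable on multi-GPU machines.
    auto dev = torch::Device(torch::kCUDA,
                             static_cast<c10::DeviceIndex>(torch::cuda::device_count() - 1));
    auto q = torch::randn({2, 4}, dev);
    auto device_guard = make_cuda_guard_from_tensor(q);
    // With the guard active, "cuda" with no explicit index resolves to q's device,
    // so this allocation (and any kernel launched in this scope) does not land on cuda:0.
    auto out = torch::empty({2, 4}, torch::device(torch::kCUDA));
    TORCH_CHECK(out.device() == q.device());
    return 0;
}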
hopper/flash_api.cpp (12 changes: 9 additions & 3 deletions)
@@ -59,6 +59,12 @@ namespace pybind11::detail {
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

+namespace {
+inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {
+    return at::cuda::CUDAGuard(static_cast<c10::DeviceIndex>(t.get_device()));
+}
+} // namespace

void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
@@ -629,7 +635,7 @@ mha_fwd_get_scheduler_metadata(

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
+auto device_guard = make_cuda_guard_from_tensor(seqused_k);

auto opts = seqused_k.options();
// This needs to be set after get_num_splits
@@ -884,7 +890,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+auto device_guard = make_cuda_guard_from_tensor(q);

at::Tensor softmax_lse;
if (!is_varlen_q) {
@@ -1454,7 +1460,7 @@ std::vector<at::Tensor> mha_bwd(

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+auto device_guard = make_cuda_guard_from_tensor(q);

auto opts = q.options();
// Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
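The hopper/flash_api.cpp sites above follow the same pattern as in flash_api_sparse.cpp. For comparison only (this is not what the PR does), PyTorch C++/CUDA extensions often sidestep the index conversion entirely by constructing an optional guard from the tensor's Device; a hedged sketch, assuming q is a CUDA tensor:

// Editor's sketch of an alternative idiom, not used by this PR:
// OptionalCUDAGuard takes an optional<Device> directly, so no index cast is needed.
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>      // at::device_of
#include <ATen/cuda/CUDAGuard.h>   // at::cuda::OptionalCUDAGuard

void run_on_tensor_device(const at::Tensor& q) {
    const at::cuda::OptionalCUDAGuard device_guard(at::device_of(q));
    // Allocations and kernel launches in this scope target q's device.
}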
hopper/sm90_pipeline_no_cluster.hpp (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base {
if (is_initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
-uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup;
+uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup;
uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;

cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
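Editor's note on the pipeline change: the consumer warpgroup count is now derived with a round-up rather than a floor division, so a num_consumers value that does not fill a whole warpgroup (NumThreadsPerWarpGroup is 128 in CUTLASS) no longer collapses to zero, which would have zeroed the barrier arrival count. A small illustrative sketch of the arithmetic (the concrete values are the editor's, not taken from any kernel config):

// Editor's sketch: floor vs. ceiling division for the consumer warpgroup count.
#include <cassert>
#include <cstdint>

constexpr uint32_t NumThreadsPerWarpGroup = 128;  // 4 warps of 32 threads

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // Consumers filling whole warpgroups: both forms agree.
    assert(256 / NumThreadsPerWarpGroup == 2);
    assert(ceil_div(256, NumThreadsPerWarpGroup) == 2);

    // A partial warpgroup, e.g. 96 consumer threads: the old floor division gives 0
    // (and hence an arrival count of 0); the rounded-up form gives 1.
    assert(96 / NumThreadsPerWarpGroup == 0);
    assert(ceil_div(96, NumThreadsPerWarpGroup) == 1);
    return 0;
}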