hopper/flash_api.cpp: 6 changes (3 additions, 3 deletions)
@@ -609,7 +609,7 @@ mha_fwd_get_scheduler_metadata(

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
+    at::cuda::CUDAGuard device_guard{static_cast<c10::DeviceIndex>(seqused_k.get_device())};

auto opts = seqused_k.options();
// This needs to be set after get_num_splits
@@ -876,7 +876,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    at::cuda::CUDAGuard device_guard{static_cast<c10::DeviceIndex>(q.get_device())};

at::Tensor softmax_lse;
if (!is_varlen_q) {
@@ -1463,7 +1463,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    at::cuda::CUDAGuard device_guard{static_cast<c10::DeviceIndex>(q.get_device())};

auto opts = q.options();
// Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
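Context for the change: at::cuda::CUDAGuard is brace-initialized with a c10::DeviceIndex, a narrow integer type, while Tensor::get_device() can return a wider integer, so the initializer needs an explicit conversion to avoid a narrowing diagnostic. Writing static_cast<c10::DeviceIndex> names the actual target type instead of hard-coding char, so it stays correct even if DeviceIndex ever changes width. A minimal sketch of the pattern, using a hypothetical helper name (not from the PR) and assuming a CUDA build of PyTorch:

// Sketch only: shows the device-guard pattern used in flash_api.cpp.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAGuard.h>

void run_on_tensor_device(const at::Tensor& t) {
    // get_device() may return a wider integer than c10::DeviceIndex;
    // static_cast makes the conversion explicit and silences the
    // narrowing warning from brace-initialization.
    at::cuda::CUDAGuard device_guard{static_cast<c10::DeviceIndex>(t.get_device())};
    // Kernels launched while the guard is in scope target t's device
    // instead of cuda:0.
}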