diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py
index 3a503981a873..5a48e56a5fbb 100644
--- a/vllm/model_executor/layers/fla/ops/utils.py
+++ b/vllm/model_executor/layers/fla/ops/utils.py
@@ -17,6 +17,7 @@
 import torch
 
+from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 
 logger = logging.getLogger(__name__)
 
@@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
 
 # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
 # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
 # Therefore, we need to check the triton backend to determine the actual GPU vendor.
-device = get_available_device() if get_available_device() != "hip" else "cuda"
-device_torch_lib = getattr(torch, device)
+device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
+device_torch_lib = getattr(torch, device, None)
 device_platform = _check_platform()
 is_amd = device_platform == "amd"
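
For context, a minimal sketch (not part of the patch) of how the new resolution behaves. It assumes vLLM is importable; get_available_device() is the helper already defined in this utils module and is stubbed here only for illustration:

import torch
from vllm.platforms import current_platform

def get_available_device() -> str:
    # Stub for illustration; the real helper in fla/ops/utils.py asks Triton
    # for the active backend ("cuda", "hip", "xpu", ...).
    return "cpu"

# NVIDIA and AMD both expose torch.cuda, so is_cuda_alike() (True for both
# CUDA and ROCm platforms) maps straight to "cuda"; other backends keep the
# name reported by Triton.
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()

# The extra default in getattr() avoids an AttributeError at import time when
# torch has no submodule for the reported backend (e.g. torch.musa absent).
device_torch_lib = getattr(torch, device, None)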