This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 08dedd5

youkaichao authored and Robert Shaw committed
[misc][cuda] use nvml to avoid accidentally cuda initialization (vllm-project#6007)
1 parent b4eec34 commit 08dedd5

13 files changed: +86 −68 lines
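The helper get_device_capability_stateless is defined in vllm/utils.py, one of the 13 changed files whose hunk is not shown on this page. As a rough, hypothetical sketch of the idea in the commit title (query device properties through NVML so the calling process never initializes a CUDA context, which matters for later fork-based multiprocessing), it might look like the following; the function name matches the import sites below, but the body is an assumption rather than the actual vllm implementation.

# Hypothetical sketch only: the real vllm/utils.py change is not shown in the
# hunks below. It illustrates reading the compute capability via NVML instead
# of torch.cuda, so no CUDA context is created in the calling process.
from typing import Tuple

import pynvml


def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
    """Return (major, minor) compute capability without initializing CUDA."""
    pynvml.nvmlInit()
    try:
        # Note: NVML indexes physical devices and ignores CUDA_VISIBLE_DEVICES.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        return (major, minor)
    finally:
        pynvml.nvmlShutdown()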

tests/kernels/test_cutlass.py
Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 
 from tests.nm_utils.utils_skip import should_skip_test_group
 from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless
 
 if should_skip_test_group(group_name="TEST_KERNELS"):
     pytest.skip("TEST_KERNELS=DISABLE, skipping kernels test group",
@@ -18,7 +19,7 @@
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
 capability = capability[0] * 10 + capability[1]

tests/quantization/utils.py
Lines changed: 2 additions & 1 deletion

@@ -1,14 +1,15 @@
 import torch
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.utils import get_device_capability_stateless
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
     # Currently, all quantization methods require Nvidia or AMD GPUs
     if not torch.cuda.is_available():
         return False
 
-    capability = torch.cuda.get_device_capability()
+    capability = get_device_capability_stateless()
     capability = capability[0] * 10 + capability[1]
     return (capability >=
             QUANTIZATION_METHODS[quant_method].get_min_capability())
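Several call sites in this diff collapse the (major, minor) tuple into a single integer before comparing it with a minimum capability. A small worked example of that pattern follows; the (8, 6) value is illustrative only, standing in for whatever the helper returns on a given GPU.

# Illustrative values only: (8, 6) stands in for the tuple returned by
# get_device_capability_stateless() on some GPU.
major, minor = (8, 6)
capability = major * 10 + minor  # 86
# Compared against a per-method minimum, e.g.
# QUANTIZATION_METHODS[quant_method].get_min_capability().
print(capability >= 80)  # True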

vllm/attention/ops/blocksparse_attention/interface.py
Lines changed: 3 additions & 3 deletions

@@ -2,13 +2,13 @@
 
 import torch
 
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
 
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and torch.cuda.get_device_capability()[0] >= 8)
+                         and get_device_capability_stateless()[0] >= 8)
 
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -235,4 +235,4 @@ def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
             v,
             cu_seqlens_k,
             cu_seqlens_q=cu_seqlens_q,
-            sm_scale=sm_scale)
+            sm_scale=sm_scale)

vllm/attention/ops/prefix_prefill.py
Lines changed: 3 additions & 1 deletion

@@ -5,6 +5,8 @@
 import triton
 import triton.language as tl
 
+from vllm.utils import get_device_capability_stateless
+
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
@@ -683,7 +685,7 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None):
 
-        cap = torch.cuda.get_device_capability()
+        cap = get_device_capability_stateless()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

vllm/distributed/device_communicators/custom_all_reduce.py
Lines changed: 5 additions & 53 deletions

@@ -11,66 +11,18 @@
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import is_in_the_same_node
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
 try:
-    import pynvml
-
-    # Simulate ImportError if custom_ar ops are not supported.
-    if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
-        raise ImportError("custom_ar", __file__)
-
+    assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
     custom_ar = True
-
-    @contextmanager
-    def _nvml():
-        try:
-            pynvml.nvmlInit()
-            yield
-        finally:
-            pynvml.nvmlShutdown()
-
-except ImportError:
-    # For AMD GPUs
+except Exception:
+    # For AMD GPUs and CPUs
     custom_ar = False
-    pynvml = None
-
-    @contextmanager
-    def _nvml():
-        try:
-            yield
-        finally:
-            pass
-
 
 logger = init_logger(__name__)
 
 
-@_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
-    so it works on real physical device ids.
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
@@ -161,7 +113,7 @@ def __init__(self,
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(physical_device_ids)
+        full_nvlink = is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"

vllm/lora/punica.py
Lines changed: 2 additions & 1 deletion

@@ -5,13 +5,14 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless
 
 
 def _check_punica_support():
     if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
         return
 
-    if torch.cuda.get_device_capability() < (8, 0):
+    if get_device_capability_stateless() < (8, 0):
         raise ImportError(
             "punica LoRA kernels require compute capability >= 8.0")
     else:

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
+from vllm.utils import get_device_capability_stateless
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -84,7 +85,7 @@ def get_config_filenames(cls) -> List[str]:
         return []
 
     def _check_gptq_and_marlin_can_run(self):
-        capability = torch.cuda.get_device_capability()
+        capability = get_device_capability_stateless()
         capability = capability[0] * 10 + capability[1]
         if capability < 80:
             raise RuntimeError("The quantization config is not supported for ",

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 2 additions & 2 deletions

@@ -10,15 +10,15 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
+from vllm.utils import get_device_capability_stateless, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
 
 def cutlass_fp8_supported() -> bool:
-    capability = torch.cuda.get_device_capability()
+    capability = get_device_capability_stateless()
     capability = capability[0] * 10 + capability[1]
 
     return ops.cutlass_scaled_mm_supports_fp8(capability)

vllm/model_executor/layers/quantization/gptq_marlin.py
Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import get_device_capability_stateless
 
 logger = init_logger(__name__)
 
@@ -165,7 +166,7 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         # If the capability of the device is too low, cannot convert.
-        major, minor = torch.cuda.get_device_capability()
+        major, minor = get_device_capability_stateless()
         device_capability = major * 10 + minor
         if device_capability < cls.get_min_capability():
             return False

vllm/model_executor/layers/quantization/utils/marlin_utils.py
Lines changed: 2 additions & 1 deletion

@@ -12,8 +12,9 @@
     marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
+from vllm.utils import get_device_capability_stateless
 
-__cuda_arch = torch.cuda.get_device_capability()
+__cuda_arch = get_device_capability_stateless()
 
 MARLIN_TILE = 16
