     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import is_in_the_same_node
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_full_nvlink

 try:
-    import pynvml
-
-    # Simulate ImportError if custom_ar ops are not supported.
-    if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
-        raise ImportError("custom_ar", __file__)
-
+    assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
     custom_ar = True
-
-    @contextmanager
-    def _nvml():
-        try:
-            pynvml.nvmlInit()
-            yield
-        finally:
-            pynvml.nvmlShutdown()
-
-except ImportError:
-    # For AMD GPUs
+except Exception:
+    # For AMD GPUs and CPUs
     custom_ar = False
-    pynvml = None
-
-    @contextmanager
-    def _nvml():
-        try:
-            yield
-        finally:
-            pass
-

 logger = init_logger(__name__)


-@_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
-    so it works on real physical device ids.
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
@@ -161,7 +113,7 @@ def __init__(self,
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(physical_device_ids)
+        full_nvlink = is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
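The diff above removes the module-local `_is_full_nvlink` helper and its `_nvml` NVML setup/teardown wrapper, and instead imports `is_full_nvlink` from `vllm.utils`. The relocated implementation is not part of this commit; the following is a minimal sketch of what it presumably looks like in `vllm/utils.py`, reconstructed from the code removed here (the `_nvml` name and the exact module layout are assumptions, not taken from the actual vllm.utils source).

# Hypothetical sketch of the relocated helper in vllm/utils.py, reconstructed
# from the _is_full_nvlink code removed in this commit.
from contextlib import contextmanager
from typing import List

import pynvml

from vllm.logger import init_logger

logger = init_logger(__name__)


@contextmanager
def _nvml():
    # Initialize NVML for the duration of the call and always shut it down.
    try:
        pynvml.nvmlInit()
        yield
    finally:
        pynvml.nvmlShutdown()


@_nvml()
def is_full_nvlink(device_ids: List[int]) -> bool:
    """
    Query whether the given GPUs are fully connected by NVLink (1 hop).
    `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, so this works on
    real physical device ids.
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
    for i, handle in enumerate(handles):
        for j, peer_handle in enumerate(handles):
            if i < j:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError as error:
                    logger.error(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.",
                        exc_info=error)
                    return False
    return True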