diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index 13916fc8c147..0395f7200fd7 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     num_elements = 8
     all_tensors = [
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     num_dimensions = 3
     tensor_size = list(range(2, num_dimensions + 2))
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     test_dict = {
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index 0bd3bf883745..1e6e7f89a528 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank, rank,
+    init_test_distributed_environment(1, world_size, rank,
                                       distributed_init_port)

     custom_ar.init_custom_ar()
@@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank, rank,
+    init_test_distributed_environment(1, world_size, rank,
                                       distributed_init_port)

     sz = 1024
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 797f18915dec..29782045130a 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -14,7 +14,9 @@ def distributed_run(fn, world_size):
     for i in range(number_of_processes):
         env = os.environ.copy()
         env['RANK'] = str(i)
+        env['LOCAL_RANK'] = str(i)
         env['WORLD_SIZE'] = str(number_of_processes)
+        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
         env['MASTER_ADDR'] = 'localhost'
         env['MASTER_PORT'] = '12345'
         p = multiprocessing.Process(target=fn, args=(env, ))
diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py
index ca37d9fe9577..a0c2921df221 100644
--- a/vllm/model_executor/parallel_utils/pynccl.py
+++ b/vllm/model_executor/parallel_utils/pynccl.py
@@ -202,11 +202,11 @@ def __init__(
         init_method=None,
         timeout=datetime.timedelta(seconds=10),
         world_size: int = -1,
-        local_rank: int = -1,
         rank: int = -1,
         store=None,
         group_name: str = "",
         pg_options=None,
+        local_rank: int = -1,
     ):
         if not dist.is_initialized():
             backend = backend or "nccl"
@@ -220,6 +220,11 @@ def __init__(
                                    store=store,
                                    group_name=group_name,
                                    pg_options=pg_options)
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        if local_rank == -1:
+            local_rank = self.rank
+        self.local_rank = local_rank
         torch.cuda.set_device(local_rank)
         if rank == 0:
             self.unique_id = ncclGetUniqueId()
diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/model_executor/parallel_utils/pynccl_utils.py
index 5b5eebbde44f..45915b49a131 100644
--- a/vllm/model_executor/parallel_utils/pynccl_utils.py
+++ b/vllm/model_executor/parallel_utils/pynccl_utils.py
@@ -35,8 +35,10 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
         pass


-def init_process_group(world_size: int, local_rank: int, rank: int,
-                       init_method: str) -> None:
+def init_process_group(world_size: int,
+                       rank: int,
+                       init_method: str,
+                       local_rank: int = -1) -> None:
     assert not is_initialized()
     global comm
     logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
diff --git a/vllm/test_utils.py b/vllm/test_utils.py
index 735cc0037ba5..94e962e12e87 100644
--- a/vllm/test_utils.py
+++ b/vllm/test_utils.py
@@ -8,9 +8,9 @@ def init_test_distributed_environment(
     pipeline_parallel_size: int,
     tensor_parallel_size: int,
-    local_rank: int,
     rank: int,
     distributed_init_port: str,
+    local_rank: int = -1,
 ) -> None:
     parallel_config = ParallelConfig(pipeline_parallel_size,
                                      tensor_parallel_size,
@@ -18,9 +18,9 @@ def init_test_distributed_environment(
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
         parallel_config,
-        local_rank,
         rank,
-        distributed_init_method=distributed_init_method)
+        distributed_init_method=distributed_init_method,
+        local_rank=local_rank)


 def multi_process_tensor_parallel(
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 4ffe78040010..48facb57de19 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -97,8 +97,9 @@ def init_device(self) -> None:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_distributed_environment(self.parallel_config, self.local_rank,
-                                     self.rank, self.distributed_init_method)
+        init_distributed_environment(self.parallel_config, self.rank,
+                                     self.distributed_init_method,
+                                     self.local_rank)
         # Set random seed.
         set_random_seed(self.model_config.seed)

@@ -249,9 +250,9 @@ def get_cache_block_size_bytes(self, block_size: int,

 def init_distributed_environment(
     parallel_config: ParallelConfig,
-    local_rank: int,
     rank: int,
     distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
     if torch.distributed.is_initialized():
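
Usage sketch (illustrative, not part of the diff): after this change, local_rank
trails the required parameters and defaults to -1, in which case NCCLCommunicator
falls back to the global rank, so single-node callers can omit it. The concrete
values below (world size, port) are hypothetical.

    # Illustrative only: assumes a single-node, 2-GPU run and a free port.
    from vllm.test_utils import init_test_distributed_environment

    # Before: local_rank was a required positional argument wedged between
    # tensor_parallel_size and rank:
    #     init_test_distributed_environment(1, 2, local_rank, rank, "29500")
    # After: local_rank is a trailing keyword argument defaulting to -1:
    init_test_distributed_environment(
        1,  # pipeline_parallel_size
        2,  # tensor_parallel_size
        rank=0,
        distributed_init_port="29500",
        # local_rank omitted: defaults to -1, then falls back to the rank.
    )

Multi-node callers that know the node-local device index should still pass
local_rank explicitly, since the rank fallback is only correct when the global
rank coincides with the GPU index on the node.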