6 changes: 3 additions & 3 deletions tests/distributed/test_comm_ops.py
@@ -24,7 +24,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
                                       distributed_init_port)
     num_elements = 8
     all_tensors = [
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
                                       distributed_init_port)
     num_dimensions = 3
     tensor_size = list(range(2, num_dimensions + 2))
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(1, tensor_parallel_size, rank, rank,
                                       distributed_init_port)
     test_dict = {
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
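These tests run all workers on a single node, so each worker's node-local rank equals its global rank; that is why the same `rank` value is now passed twice. A minimal sketch of that assumption (the wrapper name below is illustrative, not part of the diff):

```python
from vllm.test_utils import init_test_distributed_environment


def run_single_node_test_worker(tensor_parallel_size: int, rank: int,
                                distributed_init_port: str) -> None:
    # Illustrative only: on one node the GPU index and the global rank coincide.
    local_rank = rank
    init_test_distributed_environment(1, tensor_parallel_size, local_rank, rank,
                                      distributed_init_port)
```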
4 changes: 2 additions & 2 deletions tests/distributed/test_custom_all_reduce.py
@@ -23,7 +23,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(1, world_size, rank, rank,
                                       distributed_init_port)

     custom_ar.init_custom_ar()
@@ -58,7 +58,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(1, world_size, rank, rank,
                                       distributed_init_port)

     sz = 1024
2 changes: 0 additions & 2 deletions vllm/executor/ray_gpu_executor.py
@@ -188,8 +188,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             is_driver_worker=True,
         )

-        # FIXME(woosuk): We are not properly initializing pynccl when
-        # we have multiple nodes.
         self._run_workers("init_device")
         self._run_workers(
             "load_model",
18 changes: 8 additions & 10 deletions vllm/model_executor/parallel_utils/pynccl.py
@@ -202,6 +202,7 @@ def __init__(
         init_method=None,
         timeout=datetime.timedelta(seconds=10),
         world_size: int = -1,
+        local_rank: int = -1,
         rank: int = -1,
         store=None,
         group_name: str = "",
@@ -219,25 +220,22 @@ def __init__(
                                 store=store,
                                 group_name=group_name,
                                 pg_options=pg_options)
-        self.world_size = dist.get_world_size()
-        self.rank = dist.get_rank()
Comment on lines -222 to -223 (Member):

> We could still save self.world_size/self.rank/self.local_rank, for convenience.
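A minimal sketch of what that suggestion could look like (illustrative only, not part of this diff; the class name is hypothetical):

```python
# Hypothetical sketch of the reviewer's suggestion: keep the values the caller
# passed in as attributes so later methods can use self.rank / self.local_rank
# instead of re-deriving them from torch.distributed.
class NCCLCommunicatorSketch:

    def __init__(self, world_size: int, local_rank: int, rank: int) -> None:
        self.world_size = world_size
        self.local_rank = local_rank
        self.rank = rank
```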

-        torch.cuda.set_device(self.rank)
-        if self.rank == 0:
+        torch.cuda.set_device(local_rank)
+        if rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.rank)
+        tensor = torch.ByteTensor(list(
+            self.unique_id.internal)).cuda(local_rank)
         dist.broadcast(tensor, src=0)
         byte_list = tensor.cpu().tolist()
         self.unique_id = NcclUniqueId()
Comment (Member, author):

> This line seems useless, as self.unique_id has already been defined in line 227.

Reply (Member):

> Seems good to me.

         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
+        result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size,
+                                     self.unique_id, rank)
         assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}")
+        self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}")

     def all_reduce(self,
                    tensor: torch.Tensor,
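The hunk above is the core of the change: the NCCL unique id is still exchanged over the already-initialized `torch.distributed` group, but the CUDA device is now chosen by the node-local rank rather than the global rank. A simplified sketch of that pattern, assuming a `torchrun` launch so that `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` are set (this is not the vLLM code itself):

```python
import os

import torch
import torch.distributed as dist


def bind_device_and_broadcast(payload: bytes) -> bytes:
    """Pick the GPU by local rank, then share rank 0's bytes with every rank."""
    rank = int(os.environ["RANK"])              # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # With 2 nodes x 8 GPUs, global rank 9 must bind cuda:1, not cuda:9;
    # indexing GPUs by the global rank only works on a single node.
    torch.cuda.set_device(local_rank)

    # `payload` must have the same length on every rank (as the fixed-size
    # NCCL unique id does). A CPU tensor over gloo keeps the sketch simple.
    buf = torch.ByteTensor(list(payload))
    dist.broadcast(buf, src=0)
    return bytes(buf.tolist())
```

Inside `NCCLCommunicator` the broadcast travels the same way; only the final `_c_ncclCommInitRank` call consumes the global `rank` and `world_size`, while the device and stream are pinned to `cuda:{local_rank}`.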
4 changes: 3 additions & 1 deletion vllm/model_executor/parallel_utils/pynccl_utils.py
@@ -36,11 +36,13 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
     pass


-def init_process_group(world_size: int, rank: int, init_method: str) -> None:
+def init_process_group(world_size: int, local_rank: int, rank: int,
+                       init_method: str) -> None:
     assert not is_initialized()
     global comm
     comm = NCCLCommunicator(init_method=init_method,
                             world_size=world_size,
+                            local_rank=local_rank,
                             rank=rank)

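For reference, a hedged usage sketch of the updated helper; the world size, ranks, address, and port below are illustrative values, not taken from this PR:

```python
from vllm.model_executor.parallel_utils import pynccl_utils

# Second GPU on the head node of a hypothetical 2-node x 8-GPU deployment.
pynccl_utils.init_process_group(
    world_size=16,                       # all GPUs across both nodes
    local_rank=1,                        # GPU index on this node
    rank=1,                              # global rank across the cluster
    init_method="tcp://10.0.0.1:29500",  # hypothetical head-node address
)
```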
6 changes: 5 additions & 1 deletion vllm/test_utils.py
@@ -8,6 +8,7 @@
 def init_test_distributed_environment(
     pipeline_parallel_size: int,
     tensor_parallel_size: int,
+    local_rank: int,
     rank: int,
     distributed_init_port: str,
 ) -> None:
@@ -16,7 +17,10 @@ def init_test_distributed_environment(
                                      worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        parallel_config, rank, distributed_init_method=distributed_init_method)
+        parallel_config,
+        local_rank,
+        rank,
+        distributed_init_method=distributed_init_method)


 def multi_process_tensor_parallel(
7 changes: 4 additions & 3 deletions vllm/worker/worker.py
@@ -97,8 +97,8 @@ def init_device(self) -> None:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_distributed_environment(self.parallel_config, self.rank,
-                                     self.distributed_init_method)
+        init_distributed_environment(self.parallel_config, self.local_rank,
+                                     self.rank, self.distributed_init_method)
         # Set random seed.
         set_random_seed(self.model_config.seed)

@@ -249,6 +249,7 @@ def get_cache_block_size_bytes(self, block_size: int,

 def init_distributed_environment(
     parallel_config: ParallelConfig,
+    local_rank: int,
     rank: int,
     distributed_init_method: Optional[str] = None,
 ) -> None:
@@ -282,9 +283,9 @@
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        # TODO(woosuk): Support multi-node connection.
         pynccl_utils.init_process_group(
             world_size=parallel_config.world_size,
+            local_rank=local_rank,
             rank=rank,
             init_method=distributed_init_method,
         )
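Taken together, `local_rank` now flows from `Worker.init_device()` through `init_distributed_environment()` into `pynccl_utils.init_process_group()` and finally `NCCLCommunicator`, which is why the FIXME in `ray_gpu_executor.py` and the TODO about multi-node connections could be dropped. A small worked example of the bookkeeping this enables (the modulo arithmetic is illustrative; vLLM's executor assigns local ranks when it places workers, not via this formula):

```python
def describe_worker(rank: int, gpus_per_node: int) -> str:
    """Illustrative mapping from a global rank to the node and GPU it should use."""
    node = rank // gpus_per_node
    local_rank = rank % gpus_per_node
    return f"global rank {rank} -> node {node}, cuda:{local_rank}"


# 2 nodes x 4 GPUs: global rank 5 lives on node 1 and must bind cuda:1, not cuda:5.
for r in range(8):
    print(describe_worker(r, gpus_per_node=4))
```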