Skip to content

Commit ba4b752

Browse files
njhill and Mu Huai
authored and committed
[BugFix] Fix torch distributed stateless PG backend init (vllm-project#14870)
Signed-off-by: Nick Hill <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent d5e2839 commit ba4b752

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

examples/offline_inference/data_parallel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,10 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
7676
GPUs_per_dp_rank))
7777
proc.start()
7878
procs.append(proc)
79+
exit_code = 0
7980
for proc in procs:
8081
proc.join()
82+
if proc.exitcode:
83+
exit_code = proc.exitcode
84+
85+
exit(exit_code)

vllm/distributed/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,10 @@ def stateless_init_torch_distributed_process_group(
299299
# different systems (e.g. RPC) in case the store is multi-tenant.
300300
prefix_store = PrefixStore(init_method, store)
301301

302-
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
303-
304302
pg: ProcessGroup = ProcessGroup(
305303
prefix_store,
306304
group_rank,
307305
group_size,
308-
pg_options,
309306
)
310307

311308
if backend == "gloo":
@@ -327,7 +324,10 @@ def stateless_init_torch_distributed_process_group(
327324
backend_options)
328325
backend_type = ProcessGroup.BackendType.NCCL
329326
device = torch.device("cuda")
327+
else:
328+
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
330329

330+
pg._set_default_backend(backend_type)
331331
backend_class._set_sequence_number_for_group()
332332

333333
pg._register_backend(device, backend_type, backend_class)

0 commit comments

Comments (0)