 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+from typing import Optional
 
 import torch
 
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
+# When callers use `torch.distributed.all_reduce` etc. without specifying a
+# group, this group is used. It is initialized with the `backend` parameter
+# of `init_distributed_environment` below and is essentially
+# `torch.distributed.group.WORLD`. Note that this group is device-specific.
+# It is also not safe to rely on: if users call
+# `init_distributed_environment` first and then destroy the process group
+# themselves, this variable keeps a reference to the destroyed group.
+_DEVICE_WORLD_GROUP = None
+
+# During `init_distributed_environment`, we also initialize a group with the
+# `gloo` backend, to allow direct coordination between processes through
+# the CPU.
+_CPU_WORLD_GROUP = None
+
+# In summary, after calling `init_distributed_environment`, there are always
+# two world groups: a device-specific one (which is the default) and a CPU
+# one. All processes are part of both groups.
+
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
 
+def init_distributed_environment(
+    world_size: int,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
+    backend: str = "nccl",
+):
+    if not torch.distributed.is_initialized():
+        assert distributed_init_method is not None, (
+            "distributed_init_method must be provided when initializing "
+            "distributed environment")
+        # this backend is used for WORLD
+        torch.distributed.init_process_group(
+            backend=backend,
+            init_method=distributed_init_method,
+            world_size=world_size,
+            rank=rank)
+        global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
+        _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
+        ranks = list(range(torch.distributed.get_world_size()))
+        _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
+                                                       backend="gloo")
+
+
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -48,6 +94,8 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
 
     if (world_size !=
             tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -69,7 +117,7 @@ def initialize_model_parallel(
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
 
@@ -80,7 +128,7 @@ def initialize_model_parallel(
80128 "pipeline model parallel group is already initialized" )
81129 for i in range (num_pipeline_model_parallel_groups ):
82130 ranks = range (i , world_size , num_pipeline_model_parallel_groups )
83- group = torch .distributed .new_group (ranks )
131+ group = torch .distributed .new_group (ranks , backend = backend )
84132 if rank in ranks :
85133 _PIPELINE_MODEL_PARALLEL_GROUP = group
86134 _PIPELINE_GLOBAL_RANKS = ranks
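
To make the grouping above concrete, here is a small standalone sketch (not part of the diff) that reproduces the two loops for a hypothetical layout with `world_size=8`, `tensor_model_parallel_size=4`, and `pipeline_model_parallel_size=2`:

```python
# Standalone sketch of the rank partitioning performed above (hypothetical sizes).
world_size = 8
tensor_model_parallel_size = 4
pipeline_model_parallel_size = 2

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

tensor_groups = [
    list(range(i * tensor_model_parallel_size,
               (i + 1) * tensor_model_parallel_size))
    for i in range(num_tensor_model_parallel_groups)
]
pipeline_groups = [
    list(range(i, world_size, num_pipeline_model_parallel_groups))
    for i in range(num_pipeline_model_parallel_groups)
]

print(tensor_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(pipeline_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Each rank lands in exactly one tensor group and one pipeline group, and with the `backend` argument added in this change both kinds of groups are created on the same backend as the default world group unless a backend is passed explicitly.
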
@@ -89,14 +137,17 @@ def initialize_model_parallel(
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
     or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
     values if the model parallel groups are initialized.
     """
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size)
+                                  pipeline_model_parallel_size, backend)
         return
 
     assert (
@@ -117,6 +168,12 @@ def model_parallel_is_initialized():
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
+def get_cpu_world_group():
+    """Get the CPU world group."""
+    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
+    return _CPU_WORLD_GROUP
+
+
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
     assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
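
For reference, here is a minimal single-process usage sketch of the new entry points (not part of the diff). The `parallel_state` import name, the TCP init method, and the all-`gloo` setup are illustrative assumptions:

```python
# Minimal single-process sketch (assumed import path and init method).
import torch

import parallel_state  # hypothetical: import this module however your tree exposes it

# Creates the default (device) world group and the extra CPU `gloo` group.
parallel_state.init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",
    backend="gloo",  # "nccl" on GPU nodes; "gloo" keeps this sketch CPU-only
)

# Tensor/pipeline groups pick up the default group's backend unless `backend`
# is passed explicitly.
parallel_state.ensure_model_parallel_initialized(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

# CPU-side coordination goes through the `gloo` world group.
cpu_group = parallel_state.get_cpu_world_group()
t = torch.ones(1)
torch.distributed.all_reduce(t, group=cpu_group)
print(t)  # tensor([1.]) with a single rank
```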