
Commit b6cb444

Modify to all_gather takes group
1 parent fd62e70 commit b6cb444

9 files changed: 53 additions & 14 deletions

ignite/distributed/comp_models/base.py

Lines changed: 9 additions & 5 deletions
@@ -212,11 +212,13 @@ def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Uni
 
         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op))
 
-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
             raise TypeError(f"Unhandled input type {type(tensor)}")
 
-        return self._collective_op(tensor, self._do_all_gather)
+        return self._collective_op(tensor, self._do_all_gather, group=group)
 
     def broadcast(
         self, tensor: Union[torch.Tensor, float, str, None], src: int = 0, safe_mode: bool = False
@@ -268,7 +270,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         pass
 
     @abstractmethod
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
         pass
 
     @abstractmethod
@@ -336,7 +338,9 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
         return tensor
 
-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if isinstance(tensor, torch.Tensor):
             return tensor
         return cast(Union[List[float], List[str]], [tensor])
@@ -351,7 +355,7 @@ def broadcast(
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
         return tensor
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
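
The base class threads the new optional group through _collective_op down to each backend's _do_all_gather hook. A minimal sketch of that dispatch pattern, using a hypothetical, heavily simplified stand-in for ComputationModel (the real _collective_op also handles float and str payloads):

import torch
from typing import Any, Callable, List, Optional, Union


class _SketchModel:
    # Hypothetical stand-in: _collective_op here only illustrates how the
    # `group` keyword is forwarded unchanged to the backend-specific hook.
    def _collective_op(self, tensor: torch.Tensor, fn: Callable, **kwargs: Any) -> torch.Tensor:
        return fn(tensor, **kwargs)

    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
        # Serial fallback: with a single process there is nothing to gather.
        return tensor

    def all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
        return self._collective_op(tensor, self._do_all_gather, group=group)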

ignite/distributed/comp_models/horovod.py

Lines changed: 7 additions & 2 deletions
@@ -1,12 +1,13 @@
 import warnings
-from typing import Any, Callable, cast, Mapping, Optional, Tuple
+from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, Union
 
 import torch
 
 from ignite.distributed.comp_models.base import ComputationModel
 
 try:
     import horovod.torch as hvd
+    from horovod.common.process_sets import ProcessSet
 
     try:
         # old API
@@ -184,9 +185,13 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
         # output can also torch min/max_return_type: (min/max_vals, indices)
         return reduced_res[0]
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
+        if group and not isinstance(group, ProcessSet):
+            raise ValueError("group should be list of int or ProcessSet")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
+        if group is not None:
+            return hvd.allgather(tensor, process_set=group)
         return hvd.allgather(tensor)
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
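
For the Horovod backend, group must already be a ProcessSet by the time it reaches _do_all_gather. A sketch of how a caller might build one directly with Horovod, assuming Horovod 0.23+ with dynamic process sets enabled (e.g. HOROVOD_DYNAMIC_PROCESS_SETS=1); ranks [0, 1] are illustrative:

import horovod.torch as hvd
import torch

hvd.init()

# With dynamic process sets enabled, a set can be registered after init;
# otherwise, process sets are passed to hvd.init() up front.
ps = hvd.add_process_set([0, 1])

t = torch.tensor([hvd.rank()])
if hvd.rank() in (0, 1):
    # Only member ranks take part in collectives on this process set.
    gathered = hvd.allgather(t, process_set=ps)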

ignite/distributed/comp_models/native.py

Lines changed: 4 additions & 2 deletions
@@ -426,11 +426,13 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         dist.all_reduce(tensor, reduce_op)
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
+        if group is not None and not isinstance(group, dist.ProcessGroup):
+            raise ValueError("Group should be list of int or ProcessGroup")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
-        dist.all_gather(output, tensor)
+        dist.all_gather(output, tensor, group=group)
         return torch.cat(output, dim=0)
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
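
For the native torch.distributed backend, a plain list of ranks is converted upstream (in ignite.distributed.utils.all_gather, shown below) into a ProcessGroup. A minimal standalone sketch of the same gather-within-a-subgroup pattern in raw torch.distributed, assuming the default process group is already initialised (e.g. via torchrun); ranks [0, 1] are illustrative:

import torch
import torch.distributed as dist

# dist.new_group must be called collectively by every process,
# even those not in the new group.
group = dist.new_group([0, 1])

t = torch.tensor([dist.get_rank()])
if dist.get_rank() in (0, 1):
    # Only group members participate in the collective itself.
    output = [torch.zeros_like(t) for _ in range(2)]
    dist.all_gather(output, t, group=group)
    result = torch.cat(output, dim=0)  # tensor([0, 1]) on ranks 0 and 1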

ignite/distributed/comp_models/xla.py

Lines changed: 7 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, cast, Mapping, Optional, Tuple
+from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, Union
 
 import torch
 
@@ -144,12 +144,16 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         xm.all_reduce(op, [tensor])
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Union[Any, List[int]]] = None) -> torch.Tensor:
         # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
+
+        if not self._check_group_type(group):
+            raise ValueError("group should be list of int or list of list of int")
+
         group_size = self.get_world_size()
         output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
         output[self.get_rank() % group_size] = tensor
-        xm.all_reduce("sum", [output])
+        xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])
 
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
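
On XLA, xm.all_reduce takes a groups argument as a list of replica groups (a list of lists of ranks), which is why the validation above also accepts list-of-list-of-int. A sketch of the zero-fill-then-sum trick this backend uses to emulate all_gather, extracted into a free function for illustration (groups=[[0, 1]] would restrict it to ranks 0 and 1):

import torch
import torch_xla.core.xla_model as xm


def gather_via_all_reduce(tensor: torch.Tensor, world_size: int, rank: int, groups=None) -> torch.Tensor:
    # Each replica writes its slice into a zero tensor; a summed all_reduce
    # then assembles the full gather across the participating replicas.
    output = torch.zeros((world_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
    output[rank % world_size] = tensor
    xm.all_reduce("sum", [output], groups=groups)
    return output.reshape(-1, *output.shape[2:])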

ignite/distributed/utils.py

Lines changed: 7 additions & 2 deletions
@@ -339,7 +339,9 @@ def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[tor
     return _model.all_reduce(tensor, op)
 
 
-def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+def all_gather(
+    tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
+) -> Union[torch.Tensor, float, List[float], List[str]]:
     """Helper method to perform all gather operation.
 
     Args:
@@ -354,7 +356,10 @@ def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, f
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    return _model.all_gather(tensor)
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    return _model.all_gather(tensor, group=group)
 
 
 def broadcast(
tests/ignite/distributed/utils/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -161,6 +161,17 @@ def _test_distrib_all_gather(device):
         idist.all_reduce([0, 1, 2])
 
 
+def _test_distrib_all_gather_group(device):
+
+    if idist.get_world_size() > 1:
+        rank = idist.get_rank()
+        group = [0, 1]
+
+        t = torch.tensor([rank], device=idist.device())
+        res = idist.all_gather(t, group=group)
+        assert torch.equal(res, torch.tensor(group))
+
+
 def _test_distrib_broadcast(device):
 
     rank = idist.get_rank()
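
The new test gathers each rank's id within group [0, 1] and checks the result. Outside the pytest fixtures used below, a similar check could be driven with idist.Parallel; a sketch with an illustrative gloo backend and 2 processes:

import torch

import ignite.distributed as idist


def check(local_rank):
    t = torch.tensor([idist.get_rank()], device=idist.device())
    res = idist.all_gather(t, group=[0, 1])
    assert torch.equal(res.cpu(), torch.tensor([0, 1]))


if __name__ == "__main__":
    # Spawns 2 local processes and runs `check` in each of them.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(check)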

tests/ignite/distributed/utils/test_horovod.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -162,6 +163,7 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     device = "cpu" if not torch.cuda.is_available() else "cuda"
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
 
 
 @pytest.mark.distributed

tests/ignite/distributed/utils/test_native.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -228,6 +229,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
 
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.distributed
@@ -236,6 +238,7 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
 
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.distributed

tests/ignite/distributed/utils/test_xla.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
 from ignite.distributed.utils import has_xla_support
 from tests.ignite.distributed.utils import (
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -138,11 +139,13 @@ def test_idist_all_gather_xla():
 
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 def _test_idist_all_gather_xla_in_child_proc(index):
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.tpu
