Skip to content

Commit a3279a6

Browse files
puhuk and vfdev-5 authored
Modify to all_gather takes group (#2715)
* Modify to all_gather takes `group` * Update xla.py * resolve conflict * Update * Update utils.py * Update xla.py * autopep8 fix * Update xla.py * Update xla.py * Update * autopep8 fix * Update __init__.py * Update base.py * Update __init__.py * Update xla.py * Update xla.py * Update * Update * autopep8 fix * Update __init__.py * Update utils.py * Update utils.py * Update horovod.py * Update __init__.py Co-authored-by: puhuk <[email protected]> Co-authored-by: vfdev <[email protected]>
1 parent b3bef8c commit a3279a6

9 files changed

Lines changed: 77 additions & 12 deletions

File tree

ignite/distributed/comp_models/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,13 @@ def all_reduce(
213213

214214
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))
215215

216-
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
216+
def all_gather(
217+
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
218+
) -> Union[torch.Tensor, float, List[float], List[str]]:
217219
if not isinstance(tensor, (torch.Tensor, Number, str)):
218220
raise TypeError(f"Unhandled input type {type(tensor)}")
219221

220-
return self._collective_op(tensor, self._do_all_gather)
222+
return self._collective_op(tensor, self._do_all_gather, group=group)
221223

222224
def new_group(self, ranks: List[int], **kwargs: Any) -> Any:
223225
if isinstance(ranks, list) and all(isinstance(item, int) for item in ranks):
@@ -275,7 +277,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
275277
pass
276278

277279
@abstractmethod
278-
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
280+
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
279281
pass
280282

281283
@abstractmethod
@@ -349,7 +351,9 @@ def all_reduce(
349351
) -> Union[torch.Tensor, float]:
350352
return tensor
351353

352-
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
354+
def all_gather(
355+
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
356+
) -> Union[torch.Tensor, float, List[float], List[str]]:
353357
if isinstance(tensor, torch.Tensor):
354358
return tensor
355359
return cast(Union[List[float], List[str]], [tensor])
@@ -364,7 +368,7 @@ def broadcast(
364368
def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
365369
return tensor
366370

367-
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
371+
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
368372
return tensor
369373

370374
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:

ignite/distributed/comp_models/horovod.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
186186
# output can also torch min/max_return_type: (min/max_vals, indices)
187187
return reduced_res[0]
188188

189-
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
189+
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
190+
if group is not None:
191+
raise NotImplementedError("all_gather with group for horovod is not implemented")
190192
if tensor.ndimension() == 0:
191193
tensor = tensor.unsqueeze(0)
192194
return hvd.allgather(tensor)

ignite/distributed/comp_models/native.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,13 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
431431
dist.all_reduce(tensor, reduce_op)
432432
return tensor
433433

434-
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
434+
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
435+
if group is not None and not isinstance(group, dist.ProcessGroup):
436+
raise ValueError("Argument group should be list of int or ProcessGroup")
435437
if tensor.ndimension() == 0:
436438
tensor = tensor.unsqueeze(0)
437439
output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
438-
dist.all_gather(output, tensor)
440+
dist.all_gather(output, tensor, group=group)
439441
return torch.cat(output, dim=0)
440442

441443
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:

ignite/distributed/comp_models/xla.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,16 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
144144
xm.all_reduce(op, [tensor], groups=group)
145145
return tensor
146146

147-
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
147+
def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
148148
# from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
149+
150+
if group is not None and (not isinstance(group, list) or not all(isinstance(item, int) for item in group)):
151+
raise ValueError("Argument group should be list of int")
152+
149153
group_size = self.get_world_size()
150154
output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
151155
output[self.get_rank() % group_size] = tensor
152-
xm.all_reduce("sum", [output])
156+
xm.all_reduce("sum", [output], groups=group)
153157
return output.reshape(-1, *output.shape[2:])
154158

155159
def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:

ignite/distributed/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,22 +348,30 @@ def all_reduce(
348348
return _model.all_reduce(tensor, op, group=group)
349349

350350

351-
def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
351+
def all_gather(
352+
tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
353+
) -> Union[torch.Tensor, float, List[float], List[str]]:
352354
"""Helper method to perform all gather operation.
353355
354356
Args:
355357
tensor: tensor or number or str to collect across participating processes.
358+
group: list of integer or the process group for each backend. If None, the default process group will be used.
356359
357360
Returns:
358361
torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
359362
torch.Tensor of shape ``(world_size, )`` if input is a number or
360363
List of strings if input is a string
361364
365+
.. versionchanged:: 0.5.0
366+
added ``group``
362367
"""
363368
if _need_to_sync and isinstance(_model, _SerialModel):
364369
sync(temporary=True)
365370

366-
return _model.all_gather(tensor)
371+
if isinstance(group, list) and all(isinstance(item, int) for item in group):
372+
group = _model.new_group(group)
373+
374+
return _model.all_gather(tensor, group=group)
367375

368376

369377
def broadcast(

tests/ignite/distributed/utils/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,43 @@ def _test_distrib_all_gather(device):
198198
idist.all_reduce([0, 1, 2])
199199

200200

201+
def _test_distrib_all_gather_group(device):
202+
203+
if idist.get_world_size() > 1:
204+
ranks = [0, 1]
205+
rank = idist.get_rank()
206+
bnd = idist.backend()
207+
208+
t = torch.tensor([rank], device=device)
209+
group = idist.new_group(ranks)
210+
if bnd in ("horovod"):
211+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
212+
res = idist.all_gather(t, group=group)
213+
else:
214+
res = idist.all_gather(t, group=group)
215+
assert torch.equal(res, torch.tensor(ranks))
216+
217+
t = torch.tensor([rank], device=device)
218+
if bnd in ("horovod"):
219+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
220+
res = idist.all_gather(t, group=ranks)
221+
else:
222+
res = idist.all_gather(t, group=ranks)
223+
assert torch.equal(res, torch.tensor(ranks))
224+
225+
ranks = "abc"
226+
227+
if bnd in ("nccl", "gloo", "mpi"):
228+
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
229+
res = idist.all_gather(t, group="abc")
230+
elif bnd in ("xla-tpu"):
231+
with pytest.raises(ValueError, match=r"Argument group should be list of int"):
232+
res = idist.all_gather(t, group="abc")
233+
elif bnd in ("horovod"):
234+
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
235+
res = idist.all_gather(t, group="abc")
236+
237+
201238
def _test_distrib_broadcast(device):
202239

203240
rank = idist.get_rank()

tests/ignite/distributed/utils/test_horovod.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tests.ignite.distributed.utils import (
99
_test_distrib__get_max_length,
1010
_test_distrib_all_gather,
11+
_test_distrib_all_gather_group,
1112
_test_distrib_all_reduce,
1213
_test_distrib_all_reduce_group,
1314
_test_distrib_barrier,
@@ -165,6 +166,7 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
165166
device = "cpu" if not torch.cuda.is_available() else "cuda"
166167
np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
167168
gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
169+
gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
168170

169171

170172
@pytest.mark.distributed

tests/ignite/distributed/utils/test_native.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tests.ignite.distributed.utils import (
1010
_test_distrib__get_max_length,
1111
_test_distrib_all_gather,
12+
_test_distrib_all_gather_group,
1213
_test_distrib_all_reduce,
1314
_test_distrib_all_reduce_group,
1415
_test_distrib_barrier,
@@ -247,6 +248,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
247248

248249
device = idist.device()
249250
_test_distrib_all_gather(device)
251+
_test_distrib_all_gather_group(device)
250252

251253

252254
@pytest.mark.distributed
@@ -255,6 +257,7 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
255257

256258
device = idist.device()
257259
_test_distrib_all_gather(device)
260+
_test_distrib_all_gather_group(device)
258261

259262

260263
@pytest.mark.distributed

tests/ignite/distributed/utils/test_xla.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ignite.distributed.utils import has_xla_support
77
from tests.ignite.distributed.utils import (
88
_test_distrib_all_gather,
9+
_test_distrib_all_gather_group,
910
_test_distrib_all_reduce,
1011
_test_distrib_all_reduce_group,
1112
_test_distrib_barrier,
@@ -150,11 +151,13 @@ def test_idist_all_gather_xla():
150151

151152
device = idist.device()
152153
_test_distrib_all_gather(device)
154+
_test_distrib_all_gather_group(device)
153155

154156

155157
def _test_idist_all_gather_xla_in_child_proc(index):
156158
device = idist.device()
157159
_test_distrib_all_gather(device)
160+
_test_distrib_all_gather_group(device)
158161

159162

160163
@pytest.mark.tpu

0 commit comments

Comments (0)