diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index 8dc6d0cc15b0..4d4186588a91 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -213,11 +213,13 @@ def all_reduce(
         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))
 
-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
             raise TypeError(f"Unhandled input type {type(tensor)}")
 
-        return self._collective_op(tensor, self._do_all_gather)
+        return self._collective_op(tensor, self._do_all_gather, group=group)
 
     def new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         if isinstance(ranks, list) and all(isinstance(item, int) for item in ranks):
@@ -275,7 +277,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         pass
 
     @abstractmethod
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         pass
 
     @abstractmethod
@@ -349,7 +351,9 @@ def all_reduce(
     ) -> Union[torch.Tensor, float]:
         return tensor
 
-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if isinstance(tensor, torch.Tensor):
             return tensor
         return cast(Union[List[float], List[str]], [tensor])
@@ -364,7 +368,7 @@ def broadcast(
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[Any] = None) -> torch.Tensor:
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         return tensor
 
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index 556b20bd2f3e..35c45bff7dda 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -186,7 +186,9 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
         # output can also torch min/max_return_type: (min/max_vals, indices)
         return reduced_res[0]
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+        if group is not None:
+            raise NotImplementedError("all_gather with group for horovod is not implemented")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         return hvd.allgather(tensor)
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index b390ac3f5594..e520d820aefe 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -431,11 +431,13 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
             dist.all_reduce(tensor, reduce_op)
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+        if group is not None and not isinstance(group, dist.ProcessGroup):
+            raise ValueError("Argument group should be list of int or ProcessGroup")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
-        dist.all_gather(output, tensor)
+        dist.all_gather(output, tensor, group=group)
         return torch.cat(output, dim=0)
 
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index 0d174f738bd8..fa0d40192e14 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -144,12 +144,16 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         xm.all_reduce(op, [tensor], groups=group)
         return tensor
 
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
+
+        if group is not None and (not isinstance(group, list) or not all(isinstance(item, int) for item in group)):
+            raise ValueError("Argument group should be list of int")
+
         group_size = self.get_world_size()
         output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
         output[self.get_rank() % group_size] = tensor
-        xm.all_reduce("sum", [output])
+        xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])
 
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index d908ecb071b1..636de5a41ac3 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -348,22 +348,30 @@ def all_reduce(
     return _model.all_reduce(tensor, op, group=group)
 
 
-def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+def all_gather(
+    tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
+) -> Union[torch.Tensor, float, List[float], List[str]]:
     """Helper method to perform all gather operation.
 
     Args:
         tensor: tensor or number or str to collect across participating processes.
+        group: list of integers or the process group for each backend. If None, the default process group will be used.
 
     Returns:
         torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
         torch.Tensor of shape ``(world_size, )`` if input is a number or
         List of strings if input is a string
 
+    .. versionchanged:: 0.5.0
+        added ``group``
     """
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    return _model.all_gather(tensor)
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    return _model.all_gather(tensor, group=group)
 
 
 def broadcast(
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 0007d03e8395..3492e3e995e4 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -198,6 +198,43 @@ def _test_distrib_all_gather(device):
         idist.all_reduce([0, 1, 2])
 
 
+def _test_distrib_all_gather_group(device):
+
+    if idist.get_world_size() > 1:
+        ranks = [0, 1]
+        rank = idist.get_rank()
+        bnd = idist.backend()
+
+        t = torch.tensor([rank], device=device)
+        group = idist.new_group(ranks)
+        if bnd in ("horovod",):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=group)
+        else:
+            res = idist.all_gather(t, group=group)
+            assert torch.equal(res, torch.tensor(ranks))
+
+        t = torch.tensor([rank], device=device)
+        if bnd in ("horovod",):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            assert torch.equal(res, torch.tensor(ranks))
+
+        ranks = "abc"
+
+        if bnd in ("nccl", "gloo", "mpi"):
+            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+                res = idist.all_gather(t, group="abc")
+        elif bnd in ("xla-tpu",):
+            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+                res = idist.all_gather(t, group="abc")
+        elif bnd in ("horovod",):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group="abc")
+
+
 def _test_distrib_broadcast(device):
 
     rank = idist.get_rank()
diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py
index 4fa4495c091b..fa6c77f81cc1 100644
--- a/tests/ignite/distributed/utils/test_horovod.py
+++ b/tests/ignite/distributed/utils/test_horovod.py
@@ -8,6 +8,7 @@ from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_all_reduce_group,
     _test_distrib_barrier,
@@ -165,6 +166,7 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     device = "cpu" if not torch.cuda.is_available() else "cuda"
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
 
 
 @pytest.mark.distributed
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index ab4873a2822e..b1d885da4e40 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -9,6 +9,7 @@ from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_all_reduce_group,
     _test_distrib_barrier,
@@ -247,6 +248,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.distributed
@@ -255,6 +257,7 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.distributed
diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py
index 565e300bd1da..281e1ba50d81 100644
--- a/tests/ignite/distributed/utils/test_xla.py
+++ b/tests/ignite/distributed/utils/test_xla.py
@@ -6,6 +6,7 @@ from ignite.distributed.utils import has_xla_support
 from tests.ignite.distributed.utils import (
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_all_reduce_group,
     _test_distrib_barrier,
@@ -150,11 +151,13 @@ def test_idist_all_gather_xla():
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 def _test_idist_all_gather_xla_in_child_proc(index):
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)
 
 
 @pytest.mark.tpu
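
For reference, a minimal usage sketch of the group argument that this patch adds to idist.all_gather; it is not part of the diff. It assumes an already initialized distributed run with at least two processes (for example launched with torchrun --nproc_per_node=2), and the helper name gather_example and the rank subset [0, 1] are illustrative choices, not something the patch mandates.

import torch

import ignite.distributed as idist


def gather_example():
    # Tensor holding this process's rank, placed on the backend's device.
    rank = idist.get_rank()
    t = torch.tensor([rank], device=idist.device())

    # Unchanged behaviour: gather the tensor across all participating processes.
    full = idist.all_gather(t)

    # New with this patch: restrict the gather to a subset of ranks. The subset
    # may be given as a list of ints (converted internally through new_group)
    # or as a backend-specific process group object. Note that the horovod
    # backend raises NotImplementedError when a group is passed.
    subset = idist.all_gather(t, group=[0, 1])

    return full, subset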