Merged
Changes from 24 of 29 commits:
b6cb444  Modify to all_gather takes `group` (puhuk, Sep 16, 2022)
d92c5fb  Update xla.py (puhuk, Sep 23, 2022)
be8ec87  resolve confilct (puhuk, Sep 23, 2022)
b6d743e  Update (puhuk, Sep 23, 2022)
aaaf0f5  Update utils.py (puhuk, Sep 23, 2022)
c075931  Merge remote-tracking branch 'origin/master' into all_gather_group (puhuk, Sep 24, 2022)
b6351d2  Update xla.py (puhuk, Sep 24, 2022)
e510ea2  autopep8 fix (puhuk, Sep 24, 2022)
ce8f6fe  Update xla.py (puhuk, Sep 25, 2022)
8be6266  Merge branch 'all_gather_group' of github.com:puhuk/ignite into all_g… (puhuk, Sep 25, 2022)
4dbb387  Update xla.py (puhuk, Sep 25, 2022)
75d7878  Update (puhuk, Sep 26, 2022)
dea8fbf  autopep8 fix (puhuk, Sep 26, 2022)
eeefec0  Update __init__.py (puhuk, Sep 26, 2022)
a7f8311  Merge branch 'all_gather_group' of github.com:puhuk/ignite into all_g… (puhuk, Sep 26, 2022)
14c2e84  Update base.py (puhuk, Sep 26, 2022)
3c23fd5  Update __init__.py (puhuk, Sep 26, 2022)
01c2b05  Update xla.py (puhuk, Sep 26, 2022)
d457272  Update xla.py (puhuk, Sep 26, 2022)
d2dd466  Update (puhuk, Sep 26, 2022)
c9cadeb  Update (puhuk, Sep 28, 2022)
aee2764  autopep8 fix (puhuk, Sep 28, 2022)
07a7786  Update __init__.py (puhuk, Sep 28, 2022)
4c58b98  Merge branch 'all_gather_group' of github.com:puhuk/ignite into all_g… (puhuk, Sep 28, 2022)
bc92abe  Update utils.py (puhuk, Sep 28, 2022)
2e173d6  Update utils.py (puhuk, Sep 29, 2022)
93e22bf  Update horovod.py (puhuk, Oct 1, 2022)
a3838ba  Update __init__.py (puhuk, Oct 1, 2022)
f917f32  Merge branch 'master' into all_gather_group (vfdev-5, Oct 1, 2022)
ignite/distributed/comp_models/base.py (9 additions, 5 deletions)

@@ -211,11 +211,13 @@ def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Uni

         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op))

-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
             raise TypeError(f"Unhandled input type {type(tensor)}")

-        return self._collective_op(tensor, self._do_all_gather)
+        return self._collective_op(tensor, self._do_all_gather, group=group)

     def new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         if isinstance(ranks, list) and all(isinstance(item, int) for item in ranks):
@@ -273,7 +275,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         pass

     @abstractmethod
-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         pass

     @abstractmethod
@@ -345,7 +347,9 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
         return tensor

-    def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+    def all_gather(
+        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+    ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if isinstance(tensor, torch.Tensor):
             return tensor
         return cast(Union[List[float], List[str]], [tensor])
@@ -360,7 +364,7 @@ def broadcast(
     def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         return tensor

-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         return tensor

     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
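For readers unfamiliar with the serial fallback touched above, here is a standalone sketch (not the ignite classes themselves; `serial_all_gather` is an illustrative name) of what the `_SerialModel` path does once `group` is accepted: in a non-distributed run the argument is simply ignored and the input comes back unchanged for tensors, or wrapped in a one-element list for numbers and strings.

```python
# Standalone illustration of the serial (single-process) fallback shown above;
# `serial_all_gather` is a hypothetical helper, not part of ignite's API.
from typing import Any, List, Optional, Union

import torch


def serial_all_gather(
    tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
) -> Union[torch.Tensor, float, List[float], List[str]]:
    # `group` is accepted for interface compatibility but has no effect here.
    if isinstance(tensor, torch.Tensor):
        return tensor
    return [tensor]


assert torch.equal(serial_all_gather(torch.tensor([1.0]), group=[0, 1]), torch.tensor([1.0]))
assert serial_all_gather(3.14) == [3.14]
```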
ignite/distributed/comp_models/horovod.py (3 additions, 1 deletion)

@@ -185,7 +185,9 @@ def _do_manual_all_reduce(self, tensor: torch.Tensor, op: Any) -> torch.Tensor:
         # output can also torch min/max_return_type: (min/max_vals, indices)
         return reduced_res[0]

-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+        if group is not None:
+            raise NotImplementedError("all_reduce with group for horovod is not implemented")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         return hvd.allgather(tensor)
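Since the Horovod backend rejects `group`, the only supported path remains a gather over all ranks. Below is a hedged sketch of that ungrouped path, assuming Horovod is installed; the launch command is an assumption, and `hvd.allgather` concatenates tensors along dim 0 across all workers.

```python
# Minimal sketch of the ungrouped Horovod path used above; run with e.g.
# `horovodrun -np 2 python this_script.py` (launch command is an assumption).
import torch
import horovod.torch as hvd

hvd.init()
t = torch.tensor([hvd.rank()])
gathered = hvd.allgather(t)  # concatenates along dim 0 -> tensor([0, 1, ...])
print(hvd.rank(), gathered)
```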
ignite/distributed/comp_models/native.py (4 additions, 2 deletions)

@@ -426,11 +426,13 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         dist.all_reduce(tensor, reduce_op)
         return tensor

-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
+        if group is not None and not isinstance(group, dist.ProcessGroup):
+            raise ValueError("Argument group should be list of int or ProcessGroup")
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
-        dist.all_gather(output, tensor)
+        dist.all_gather(output, tensor, group=group)
         return torch.cat(output, dim=0)

     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
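For context, a minimal torch.distributed sketch (independent of ignite; `gather_over_subgroup` is a hypothetical helper) of gathering over a process subgroup, which is what the native backend above delegates to via `dist.all_gather(..., group=...)`:

```python
# Sketch assuming the default process group is already initialised
# (e.g. the script was launched with torchrun).
from typing import List

import torch
import torch.distributed as dist


def gather_over_subgroup(value: torch.Tensor, ranks: List[int]) -> torch.Tensor:
    group = dist.new_group(ranks=ranks)  # collective: every process must call this
    if dist.get_rank() not in ranks:
        return value  # ranks outside the group do not participate
    output = [torch.zeros_like(value) for _ in range(len(ranks))]
    dist.all_gather(output, value, group=group)
    return torch.cat(output, dim=0)
```

Sizing the output list to `len(ranks)` matters in general torch.distributed usage, because `dist.all_gather` expects one slot per member of the group it runs on.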
ignite/distributed/comp_models/xla.py (6 additions, 2 deletions)

@@ -144,12 +144,16 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
         xm.all_reduce(op, [tensor])
         return tensor

-    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
+    def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
+
+        if group is not None and (not isinstance(group, list) or not all(isinstance(item, int) for item in group)):
+            raise ValueError("Argument group should be list of int")
+
         group_size = self.get_world_size()
         output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
         output[self.get_rank() % group_size] = tensor
-        xm.all_reduce("sum", [output])
+        xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])

     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
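The XLA branch emulates all_gather with a zero-padded buffer and a summed all_reduce. A plain-torch illustration of why that works (no XLA required; the two buffers stand in for two ranks):

```python
# Each "rank" writes its tensor into its own slot of a zero buffer; summing the
# buffers elementwise (what a sum all_reduce does across ranks) reproduces the
# concatenated result on every rank.
import torch

world_size = 2
t0, t1 = torch.tensor([10]), torch.tensor([11])  # per-rank tensors

buf0 = torch.zeros((world_size,) + t0.shape, dtype=t0.dtype)
buf1 = torch.zeros((world_size,) + t1.shape, dtype=t1.dtype)
buf0[0] = t0  # rank 0 fills slot 0
buf1[1] = t1  # rank 1 fills slot 1

gathered = (buf0 + buf1).reshape(-1)
assert torch.equal(gathered, torch.tensor([10, 11]))
```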
ignite/distributed/utils.py (7 additions, 2 deletions)

@@ -340,7 +340,9 @@ def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[tor
     return _model.all_reduce(tensor, op)


-def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
+def all_gather(
+    tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
+) -> Union[torch.Tensor, float, List[float], List[str]]:
     """Helper method to perform all gather operation.

     Args:
@@ -355,7 +357,10 @@ def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, f
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)

-    return _model.all_gather(tensor)
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    return _model.all_gather(tensor, group=group)


 def broadcast(
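A hedged usage sketch of the public helper after this change: `group` can be a plain list of ranks (converted internally with `new_group`) or a backend-specific group object. It assumes an already-initialised distributed context (for example under `idist.Parallel`); `gather_rank_ids` is an illustrative name, not part of ignite.

```python
import torch

import ignite.distributed as idist


def gather_rank_ids() -> torch.Tensor:
    # Gather this process's rank, but only across ranks 0 and 1.
    t = torch.tensor([idist.get_rank()], device=idist.device())
    return idist.all_gather(t, group=[0, 1])
```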
tests/ignite/distributed/utils/__init__.py (37 additions, 0 deletions)

@@ -161,6 +161,43 @@ def _test_distrib_all_gather(device):
         idist.all_reduce([0, 1, 2])


+def _test_distrib_all_gather_group(device):
+
+    if idist.get_world_size() > 1:
+        ranks = [0, 1]
+        rank = idist.get_rank()
+        bnd = idist.backend()
+
+        t = torch.tensor([rank], device=device)
+        group = idist.new_group(ranks)
+        if bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=group)
+        else:
+            res = idist.all_gather(t, group=group)
+            assert torch.equal(res, torch.tensor(ranks))
+
+        t = torch.tensor([rank], device=device)
+        if bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            assert torch.equal(res, torch.tensor(ranks))
+
+        ranks = "abc"
+
+        if bnd in ("nccl", "gloo", "mpi"):
+            with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
+                res = idist.all_gather(t, group="abc")
+        elif bnd in ("xla-tpu"):
+            with pytest.raises(ValueError, match=r"Argument group should be list of int"):
+                res = idist.all_gather(t, group="abc")
+        elif bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_reduce with group for horovod is not implemented"):
+                res = idist.all_gather(t, group="abc")
+
+
 def _test_distrib_broadcast(device):

     rank = idist.get_rank()
tests/ignite/distributed/utils/test_horovod.py (2 additions, 0 deletions)

@@ -8,6 +8,7 @@
 from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -163,6 +164,7 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     device = "cpu" if not torch.cuda.is_available() else "cuda"
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)


 @pytest.mark.distributed
tests/ignite/distributed/utils/test_native.py (3 additions, 0 deletions)

@@ -9,6 +9,7 @@
 from tests.ignite.distributed.utils import (
     _test_distrib__get_max_length,
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -244,6 +245,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):

     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)


 @pytest.mark.distributed
@@ -252,6 +254,7 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):

     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)


 @pytest.mark.distributed
tests/ignite/distributed/utils/test_xla.py (3 additions, 0 deletions)

@@ -6,6 +6,7 @@
 from ignite.distributed.utils import has_xla_support
 from tests.ignite.distributed.utils import (
     _test_distrib_all_gather,
+    _test_distrib_all_gather_group,
     _test_distrib_all_reduce,
     _test_distrib_barrier,
     _test_distrib_broadcast,
@@ -147,11 +148,13 @@ def test_idist_all_gather_xla():

     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)


 def _test_idist_all_gather_xla_in_child_proc(index):
     device = idist.device()
     _test_distrib_all_gather(device)
+    _test_distrib_all_gather_group(device)


 @pytest.mark.tpu