34 changes: 33 additions & 1 deletion ignite/distributed/utils.py
@@ -356,7 +356,39 @@ def all_reduce(
 def all_gather_tensors_with_shapes(
     tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
 ) -> List[torch.Tensor]:
-    """Gather tensors with different shapes but with the same number of dimensions from across processes."""
"""Gather tensors of possibly different shapes but with the same number of dimensions from across processes.

This function gets the shapes of participating tensors as input so you should know them beforehand. If your tensors
are of different number of dimensions or you don't know their shapes beforehand, you could use
`torch.distributed.all_gather_object()`, otherwise this method is quite faster.

Examples:

.. code-block:: python

import ignite.distributed as idist

rank = idist.get_rank()
ws = idist.get_world_size()
tensor = torch.randn(rank+1, rank+2)
tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)], )

# To exclude rank zero:

tensors = all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(1, ws)], list(range(1, ws)))
if rank == 0:
assert tensors == tensor
else:
assert (tensors[rank-1] == tensor).all()

Args:
tensor: tensor to collect across participating processes.
shapes: A sequence containing the shape of participating processes' `tensor`s.
group: list of integer or the process group for each backend. If None, the default process group will be used.

Returns:
List[torch.Tensor]
"""
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
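A note on what such a helper typically does under the hood: a common way to gather variably-shaped tensors is to pad every local tensor up to the per-dimension maximum of the declared shapes, run an ordinary `all_gather`, and then slice each gathered block back to the shape its rank reported. The sketch below (hypothetical names, plain `torch.distributed` calls, not the PR's implementation) illustrates that pad/gather/slice idea, assuming an already-initialized default process group with one entry in `shapes` per rank, in rank order.

```python
# Illustrative sketch only, not the PR's code: pad to a common shape,
# all_gather, then slice back to each rank's declared shape.
from typing import List, Sequence

import torch
import torch.distributed as dist
import torch.nn.functional as F


def gather_with_shapes_sketch(tensor: torch.Tensor, shapes: Sequence[Sequence[int]]) -> List[torch.Tensor]:
    # Per-dimension maximum over all participating shapes.
    max_shape = [max(shape[d] for shape in shapes) for d in range(tensor.dim())]

    # F.pad takes (left, right) padding pairs starting from the last dimension.
    pad: List[int] = []
    for d in reversed(range(tensor.dim())):
        pad.extend([0, max_shape[d] - tensor.shape[d]])
    padded = F.pad(tensor, pad)

    # Every rank now contributes a tensor of the same (padded) shape.
    gathered = [torch.empty_like(padded) for _ in shapes]
    dist.all_gather(gathered, padded)

    # Slice each gathered tensor back to the shape its rank reported.
    return [g[tuple(slice(0, d) for d in shape)] for g, shape in zip(gathered, shapes)]
```

Knowing the shapes up front is what lets this path avoid the pickling round-trip of `all_gather_object`, which is where the speed difference mentioned in the docstring comes from.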
9 changes: 6 additions & 3 deletions tests/ignite/distributed/utils/__init__.py
@@ -292,6 +292,7 @@ def _test_distrib_all_gather_group(device):
 
 
 def _test_idist_all_gather_tensors_with_different_shapes(device):
+    torch.manual_seed(41)
     rank = idist.get_rank()
     ws = idist.get_world_size()
     reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
@@ -311,8 +312,10 @@ def _test_idist_all_gather_tensors_with_different_shapes(device):
 
 
 def _test_idist_all_gather_tensors_with_different_shapes_group(device):
+    torch.manual_seed(41)
+
     rank = idist.get_rank()
-    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    ranks = list(range(1, idist.get_world_size()))
     ws = idist.get_world_size()
     bnd = idist.backend()
     if rank in ranks:
@@ -326,9 +329,9 @@ def _test_idist_all_gather_tensors_with_different_shapes_group(device):
         rank_tensor = torch.tensor([rank], device=device)
     if bnd in ("horovod"):
         with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
-            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
     else:
-        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
+        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
     for r in range(ws):
         if r in ranks:
             r_tensor = reference[
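These tests need one process per rank. A hypothetical local runner (not part of the PR) could exercise the group variant with ignite's `idist.spawn` helper and the CPU-only `gloo` backend, assuming the repository root is on `PYTHONPATH` so the test module is importable:

```python
# Hypothetical local runner for the group test above; not part of the PR.
import ignite.distributed as idist

from tests.ignite.distributed.utils import _test_idist_all_gather_tensors_with_different_shapes_group


def _run(local_rank: int, device: str) -> None:
    # idist.spawn passes the local rank first; the test itself only needs a device.
    _test_idist_all_gather_tensors_with_different_shapes_group(device)


if __name__ == "__main__":
    # Three processes make ranks = list(range(1, ws)) == [1, 2], a non-trivial subgroup.
    idist.spawn("gloo", _run, args=("cpu",), nproc_per_node=3)
```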