Helper function all_gather_tensors_with_shapes() #3281
```diff
@@ -1,9 +1,11 @@
+import itertools
 import socket
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
+from torch import distributed as dist

 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -43,6 +45,7 @@
     "one_rank_only",
     "new_group",
     "one_rank_first",
+    "all_gather_tensors_with_shapes",
 ]


 _model = _SerialModel()
@@ -350,6 +353,60 @@ def all_reduce(
     return _model.all_reduce(tensor, op, group=group)


+def all_gather_tensors_with_shapes(
+    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
+) -> List[torch.Tensor]:
+    """Helper method to gather tensors of possibly different shapes but with the same number of dimensions
+    across processes.
+
+    This function takes the shapes of the participating tensors as input, so you need to know them beforehand.
+    If your tensors have different numbers of dimensions or you don't know their shapes beforehand, you can use
+    ``torch.distributed.all_gather_object``; otherwise this method is considerably faster.
+
+    Examples:
+        .. code-block:: python
+
+            import ignite.distributed as idist
+
+            rank = idist.get_rank()
+            ws = idist.get_world_size()
+            tensor = torch.randn(rank+1, rank+2)
+            tensors = idist.all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)])
+
+    Args:
+        tensor: tensor to collect across participating processes.
+        shapes: A sequence containing the shape of each participating process' ``tensor``.
+        group: list of integers or the process group for each backend. If None, the default process group will be used.
+
+    Returns:
+        List[torch.Tensor]
+    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+        return [tensor]
+
+    max_shape = torch.tensor(shapes).amax(dim=0)
```
Review comments on this hunk:

Collaborator: I wonder whether we could actually get tensor shapes using
Author: Yes we can. Do you want it in this PR?
Collaborator: Up to you, if you would like it in another PR, OK to me as well
Collaborator: Let's make it in another PR, I'll merge this one as CI is green
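A hypothetical follow-up in the direction of this thread (not part of this PR) could infer the shapes with a preliminary gather of each rank's shape tensor and then delegate to the new helper; the diff itself continues below. A minimal sketch, assuming an already initialized process group with CPU tensors (e.g. gloo) and that all ranks' tensors have the same number of dimensions; the wrapper name is made up for illustration:

```python
from typing import List

import torch
import torch.distributed as dist

import ignite.distributed as idist


def all_gather_with_inferred_shapes(tensor: torch.Tensor) -> List[torch.Tensor]:
    # Hypothetical wrapper: gather each rank's shape first, then reuse the helper.
    ws = idist.get_world_size()
    local_shape = torch.tensor(tensor.shape, dtype=torch.long)
    # Every shape tensor has the same size (ndim,), so a plain all_gather works.
    shape_list = [torch.empty_like(local_shape) for _ in range(ws)]
    dist.all_gather(shape_list, local_shape)
    shapes = [s.tolist() for s in shape_list]
    return idist.all_gather_tensors_with_shapes(tensor, shapes)
```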
The diff then continues with the padding and slicing logic:

```diff
+    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
+    padded_tensor = torch.nn.functional.pad(
+        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
+    )
+    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+    return [
+        all_padded_tensors[
+            [
+                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
+                for dim, dim_size in enumerate(shape)
+            ]
+        ]
+        for rank, shape in enumerate(shapes)
+    ]
+
+
 def all_gather(
     tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
 ) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
```
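To make the pad-and-slice trick in the new function concrete, here is a single-process sketch with a made-up world size of 2; `torch.cat` stands in for what an all_gather of the equally padded tensors returns (concatenation along dim 0):

```python
import itertools

import torch

# Pretend the world size is 2 and these are the per-rank shapes and tensors.
shapes = [[1, 3], [2, 4]]
tensors = [torch.ones(*shape) for shape in shapes]

# Common padded shape across ranks, here [2, 4].
max_shape = torch.tensor(shapes).amax(dim=0).tolist()

padded = []
for t in tensors:
    pad_sizes = [m - s for m, s in zip(max_shape, t.shape)]
    # F.pad expects (left, right) pairs starting from the last dim; pad on the right only.
    pad = tuple(itertools.chain.from_iterable((0, p) for p in reversed(pad_sizes)))
    padded.append(torch.nn.functional.pad(t, pad))

# all_gather of same-shape tensors concatenates along dim 0 -> shape [ws * max_shape[0], max_shape[1]].
all_padded = torch.cat(padded, dim=0)

# Slice each rank's block back out and drop the padding again.
recovered = [
    all_padded[rank * max_shape[0] : rank * max_shape[0] + shape[0], : shape[1]]
    for rank, shape in enumerate(shapes)
]
assert [list(t.shape) for t in recovered] == shapes
```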
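An end-to-end usage sketch, assuming an ignite build that already contains this helper; it spawns two gloo processes on CPU and checks that every process recovers each rank's original tensor:

```python
import torch

import ignite.distributed as idist


def _worker(local_rank: int) -> None:
    rank = idist.get_rank()
    ws = idist.get_world_size()
    # Each rank holds a differently shaped 2-D tensor filled with its rank id.
    tensor = torch.full((rank + 1, rank + 2), float(rank))
    shapes = [[r + 1, r + 2] for r in range(ws)]
    gathered = idist.all_gather_tensors_with_shapes(tensor, shapes)
    assert [list(t.shape) for t in gathered] == shapes
    assert all(torch.all(t == r) for r, t in enumerate(gathered))


if __name__ == "__main__":
    # Two CPU processes with the gloo backend.
    idist.spawn("gloo", _worker, args=(), nproc_per_node=2)
```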