Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ignite/distributed/comp_models/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,16 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
group_size = group.size()
elif isinstance(group, list):
group_size = len(group)
group = self._do_new_group(group)
else:
raise ValueError("Argument group should be list of int or ProcessGroup")
if tensor.ndimension() == 0:
tensor = tensor.unsqueeze(0)
output = [torch.zeros_like(tensor) for _ in range(group_size)]
dist.all_gather(output, tensor, group=group)
if group is not None:
dist.all_gather(output, tensor, group=group)
else:
dist.all_gather(output, tensor)
return torch.cat(output, dim=0)

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
Expand Down