Skip to content

Commit a6210fd

Browse files
tsunghsienleepytorchmergebot
authored andcommitted
[c10d] Enhance get_process_group_ranks() to accept group=None (pytorch#154902)
Summary: This diff enhances the `get_process_group_ranks()` function to accept `group=None` as an optional argument. This allows the function to return all ranks associated with the default process group if no group is specified. Test Plan: contbuild & OSS CI Rollback Plan: Differential Revision: D75817800 Pull Request resolved: pytorch#154902 Approved by: https://github.com/wz337
1 parent bd3c329 commit a6210fd

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torch/distributed/distributed_c10d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,17 +1074,18 @@ def _get_global_rank(group, rank) -> int:
10741074
return get_global_rank(group, rank)
10751075

10761076

1077-
def get_process_group_ranks(group: ProcessGroup) -> list[int]:
1077+
def get_process_group_ranks(group: Optional[ProcessGroup]) -> list[int]:
10781078
"""
10791079
Get all ranks associated with ``group``.
10801080
10811081
Args:
1082-
group (ProcessGroup): ProcessGroup to get all ranks from.
1082+
group (Optional[ProcessGroup]): ProcessGroup to get all ranks from.
1083+
If None, the default process group will be used.
10831084
10841085
Returns:
10851086
List of global ranks ordered by group rank.
10861087
"""
1087-
return list(_world.pg_group_ranks[group].keys())
1088+
return list(_world.pg_group_ranks[group or _get_default_group()].keys())
10881089

10891090

10901091
def _get_group_size(group) -> int:
@@ -5447,7 +5448,7 @@ def new_subgroups(
54475448
)
54485449

54495450
# TODO: Use itertools.batched(get_process_group_ranks(group=group), group_size) instead when Python 3.12 is supported.
5450-
ranks = get_process_group_ranks(group=group or _get_default_group())
5451+
ranks = get_process_group_ranks(group=group)
54515452
ranks_per_subgroup_list = [
54525453
ranks[i : i + group_size] for i in range(0, len(ranks), group_size)
54535454
]

0 commit comments

Comments
 (0)