@@ -38,9 +38,8 @@ def _init_device_mesh_stub():
3838else :
3939 from torch ._C ._distributed_c10d import Backend as C10dBackend
4040 from torch .distributed .distributed_c10d import (
41- _find_pg_by_ranks_and_tag ,
4241 _get_default_group ,
43- _get_group_tag ,
42+ _resolve_process_group ,
4443 get_backend ,
4544 get_process_group_ranks ,
4645 get_rank ,
@@ -103,7 +102,7 @@ def create_sub_mesh(
103102 mesh_tensor = device_mesh .mesh
104103 # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims.
105104 slice_dim_idx = []
106- slice_dim_group_info = []
105+ slice_dim_group_name = []
107106 # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
108107 # flattened mesh tensor.
109108 num_dims_flatten = 0
@@ -121,15 +120,15 @@ def create_sub_mesh(
121120 # then the final slice_dim_idx should be [0, 1, 2].
122121 slice_dim_idx .append (mesh_dim_indices [0 ] - num_dims_flatten )
123122 num_dims_flatten += len (mesh_dim_indices ) - 1
124- slice_dim_group_info .append (
123+ slice_dim_group_name .append (
125124 self .root_to_flatten_mapping [device_mesh ][
126125 mesh_dim_name
127- ]._dim_group_infos [0 ]
126+ ]._dim_group_names [0 ]
128127 )
129128 else :
130129 slice_dim_idx .append (mesh_dim_indices [0 ] - num_dims_flatten )
131- slice_dim_group_info .append (
132- device_mesh ._dim_group_infos [mesh_dim_indices [0 ]]
130+ slice_dim_group_name .append (
131+ device_mesh ._dim_group_names [mesh_dim_indices [0 ]]
133132 )
134133
135134 # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
@@ -155,7 +154,7 @@ def create_sub_mesh(
155154 if cur_rank in mesh_nd :
156155 res_submesh = submesh
157156
158- res_submesh ._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined]
157+ res_submesh ._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined]
159158 self .child_to_root_mapping [res_submesh ] = device_mesh
160159
161160 return res_submesh
@@ -360,8 +359,8 @@ def _get_all_submeshes(
360359 mesh_dim_names = (mesh_dim_name ,),
361360 _init_backend = False ,
362361 )
363- submesh ._dim_group_infos = (
364- [device_mesh ._dim_group_infos [mesh_dim ]]
362+ submesh ._dim_group_names = (
363+ [device_mesh ._dim_group_names [mesh_dim ]]
365364 if cur_rank in mesh_1d
366365 else []
367366 )
@@ -496,13 +495,10 @@ def _get_or_create_default_group(self):
496495 return _get_default_group ()
497496
498497 def _init_process_groups (self ):
499- # tag/ranks/ group_name associated with each mesh dimension, each
498+ # group_name associated with each mesh dimension, each
500499 # mesh dimension should have one sub-group per rank
501500 #
502- # TODO(yifu): remove tag and ranks once we fully migrate to native
503- # functional collectives. See details in:
504- # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
505- dim_group_infos : list [tuple [str , list [int ], str ]] = []
501+ dim_group_names : list [str ] = []
506502 default_group = _get_default_group ()
507503
508504 if self .mesh .ndim == 1 and self .mesh .numel () == get_world_size ():
@@ -519,13 +515,7 @@ def _init_process_groups(self):
519515 and get_backend (default_group ) == "gloo"
520516 else default_group
521517 )
522- dim_group_infos .append (
523- (
524- _get_group_tag (dim_group ),
525- ranks ,
526- dim_group .group_name ,
527- )
528- )
518+ dim_group_names .append (dim_group .group_name )
529519 else :
530520 # create sub pgs base on the mesh argument specified
531521 for dim in range (self .mesh .ndim ):
@@ -579,10 +569,9 @@ def _init_process_groups(self):
579569 has_split_group = True
580570
581571 # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
582- # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
583- # the current rank is in the subgroup.
572+ # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
584573 # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
585- # along with appending information to the `dim_group_infos ` list whenever necessary.
574+ # along with appending information to the `dim_group_names ` list whenever necessary.
586575 for dim_mesh in pg_ranks_by_dim :
587576 subgroup_ranks = dim_mesh .tolist ()
588577
@@ -599,19 +588,13 @@ def _init_process_groups(self):
599588
600589 # only add to dim_groups if the current rank in the subgroup
601590 if self .get_rank () in subgroup_ranks :
602- if len (dim_group_infos ) > dim :
591+ if len (dim_group_names ) > dim :
603592 raise RuntimeError (
604593 f"Each device mesh dimension should get only one process group, but got { self .get_rank ()} "
605594 f"in { subgroup_ranks } !"
606595 )
607- dim_group_infos .append (
608- (
609- _get_group_tag (not_none (dim_group )),
610- subgroup_ranks ,
611- dim_group .group_name ,
612- )
613- )
614- self ._dim_group_infos = dim_group_infos
596+ dim_group_names .append (dim_group .group_name )
597+ self ._dim_group_names = dim_group_names
615598
616599 def __enter__ (self ) -> "DeviceMesh" :
617600 # set this mesh as the current mesh in mesh env
@@ -745,7 +728,7 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
745728 Returns:
746729 A :class:`ProcessGroup` object.
747730 """
748- if not hasattr (self , "_dim_group_infos " ):
731+ if not hasattr (self , "_dim_group_names " ):
749732 raise RuntimeError ("DeviceMesh process groups not initialized!" )
750733
751734 if self .mesh .ndim > 1 and mesh_dim is None :
@@ -758,28 +741,25 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
758741
759742 # Quick return if the current device_mesh is a 1D mesh.
760743 if self .mesh .ndim == 1 and mesh_dim is None :
761- return not_none (
762- _find_pg_by_ranks_and_tag (* self ._dim_group_infos [0 ][:2 ]) # type: ignore[index]
763- )
744+ return not_none (_resolve_process_group (self ._dim_group_names [0 ]))
764745
765746 root_mesh = _mesh_resources .get_root_mesh (self )
766747 root_to_flatten_mapping = _mesh_resources .root_to_flatten_mapping .get (
767748 root_mesh , None
768749 )
769750 if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping .keys ():
770- dim_group_infos = root_to_flatten_mapping [
751+ dim_group_name = root_to_flatten_mapping [
771752 mesh_dim # type: ignore[index]
772- ]._dim_group_infos [ 0 ][: 2 ]
773- return not_none (_find_pg_by_ranks_and_tag ( * dim_group_infos ))
753+ ]._dim_group_names [ 0 ]
754+ return not_none (_resolve_process_group ( dim_group_name ))
774755 else :
775756 mesh_dim = (
776757 _mesh_resources .get_mesh_dim_by_name (self , mesh_dim )
777758 if isinstance (mesh_dim , str )
778759 else mesh_dim
779760 )
780- return not_none (
781- _find_pg_by_ranks_and_tag (* self ._dim_group_infos [mesh_dim ][:2 ]) # type: ignore[index]
782- )
761+ assert isinstance (mesh_dim , int )
762+ return not_none (_resolve_process_group (self ._dim_group_names [mesh_dim ]))
783763
784764 def get_all_groups (self ) -> list [ProcessGroup ]:
785765 """
@@ -852,9 +832,7 @@ def from_group(
852832 mesh_dim_names = mesh_dim_names ,
853833 _init_backend = False ,
854834 )
855- device_mesh ._dim_group_infos = [
856- (_get_group_tag (group ), group_ranks , group .group_name )
857- ]
835+ device_mesh ._dim_group_names = [group .group_name ]
858836 return device_mesh
859837
860838 # nD scenario
@@ -880,14 +858,7 @@ def from_group(
880858 device_mesh = DeviceMesh (
881859 device_type , mesh , mesh_dim_names = mesh_dim_names , _init_backend = False
882860 )
883- device_mesh ._dim_group_infos = [
884- (
885- _get_group_tag (group ),
886- get_process_group_ranks (group ),
887- group .group_name ,
888- )
889- for group in groups
890- ]
861+ device_mesh ._dim_group_names = [group .group_name for group in groups ]
891862 return device_mesh
892863
893864 def size (self , mesh_dim : Optional [int ] = None ) -> int :
0 commit comments