@@ -928,6 +928,24 @@ def get_pp_group() -> GroupCoordinator:
     return _PP


+_CP: Optional[GroupCoordinator] = None
+
+
+def get_cp_group() -> GroupCoordinator:
+    assert _CP is not None, ("context parallel group is not initialized")
+    return _CP
+
+
+def get_context_model_parallel_world_size():
+    """Return world size for the context model parallel group."""
+    return get_cp_group().world_size
+
+
+def get_context_model_parallel_rank():
+    """Return my rank for the context model parallel group."""
+    return get_cp_group().rank_in_group
+
+
 @deprecated("`get_pipeline_model_parallel_group` has been replaced with "
             "`get_pp_group` and may be removed in v0.12. Please use "
             "`get_pp_group` instead.")
@@ -1034,6 +1052,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    context_model_parallel_size: int = 1,
     backend: Optional[str] = None,
 ) -> None:
     """
@@ -1082,7 +1101,7 @@ def initialize_model_parallel(
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        -1, data_parallel_size, pipeline_model_parallel_size,
+        -1, data_parallel_size, pipeline_model_parallel_size, context_model_parallel_size,
         tensor_model_parallel_size)  # noqa
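+    # Illustrative example of the new layout: the grid is ordered
+    # (..., DP, PP, CP, TP) with TP varying fastest, so with DP=1, PP=1,
+    # CP=2, TP=2 on 4 ranks the TP groups are [0, 1], [2, 3] and the
+    # CP groups are [0, 2], [1, 3].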

     # Build the tensor model-parallel groups.
@@ -1102,7 +1121,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(2, 3).reshape(
+    group_ranks = all_ranks.transpose(2, 4).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
@@ -1113,7 +1132,7 @@ def initialize_model_parallel(
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
     group_ranks = all_ranks.transpose(1,
-                                      3).reshape(-1,
+                                      4).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(group_ranks,
@@ -1124,23 +1143,34 @@ def initialize_model_parallel(
     global _EP
     assert _EP is None, ("expert parallel group is already initialized")
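+    # With the added CP dimension, each expert-parallel group spans the
+    # DP x CP x TP ranks of one pipeline stage.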
     group_ranks = all_ranks.transpose(1, 2).reshape(
-        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+        -1, data_parallel_size * tensor_model_parallel_size * context_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="ep")

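+    # Build the context model-parallel groups.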
+    global _CP
+    assert _CP is None, ("context parallel group is already initialized")
+    group_ranks = all_ranks.transpose(3, 4).reshape(
+        -1, context_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _CP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="cp")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, CP rank %s", rank, world_size,
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
-        _EP.rank_in_group)
+        _EP.rank_in_group, _CP.rank_in_group)


 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    context_model_parallel_size: int,
     backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
@@ -1151,7 +1181,7 @@ def ensure_model_parallel_initialized(
         get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size, backend)
+                                  pipeline_model_parallel_size, context_model_parallel_size, backend)
         return

     assert (
@@ -1164,6 +1194,11 @@ def ensure_model_parallel_initialized(
11641194 "pipeline parallel group already initialized, but of unexpected size. "
11651195 f"got: { pp_world_size = } vs. "
11661196 f"wanted: { pipeline_model_parallel_size = } " )
1197+ cp_world_size = get_cp_group ().world_size
1198+ assert (cp_world_size == context_model_parallel_size ), (
1199+ "context parallel group already initialized, but of unexpected size: "
1200+ f"{ cp_world_size = } vs. "
1201+ f"{ context_model_parallel_size = } " )
11671202
11681203
11691204def prepare_communication_buffer_for_model (model : torch .nn .Module ):
@@ -1256,6 +1291,11 @@ def destroy_model_parallel():
         _EP.destroy()
     _EP = None

+    global _CP
+    if _CP:
+        _CP.destroy()
+    _CP = None
+

 def destroy_distributed_environment():
     global _WORLD, _NODE_COUNT