@@ -240,7 +240,7 @@ def set_user_overrides(recipe: ConfigContainer, kwargs: Dict[str, Any]) -> None:
240240 recipe .model .pipeline_model_parallel_size = kwargs .get ("pipeline_model_parallel_size" )
241241 if kwargs .get ("context_parallel_size" ) is not None :
242242 recipe .model .context_parallel_size = kwargs .get ("context_parallel_size" )
243- if kwargs .get ("virtual_pipeline_model_parallel_size" ) is not None :
243+ if kwargs .get ("virtual_pipeline_model_parallel_size" ) != - 1 :
244244 recipe .model .virtual_pipeline_model_parallel_size = kwargs .get ("virtual_pipeline_model_parallel_size" )
245245 if kwargs .get ("expert_model_parallel_size" ) is not None :
246246 recipe .model .expert_model_parallel_size = kwargs .get ("expert_model_parallel_size" )
@@ -269,6 +269,23 @@ def set_user_overrides(recipe: ConfigContainer, kwargs: Dict[str, Any]) -> None:
269269 if hasattr (recipe , "comm_overlap" ) and isinstance (recipe .comm_overlap , CommOverlapConfig ):
270270 recipe .comm_overlap .overlap_param_gather_with_optimizer_step = True
271271
def set_deepseek_v3_layout(recipe: ConfigContainer) -> None:
    """Assign an explicit pipeline-parallel layer layout for DeepSeek V3.

    Chooses a hand-tuned placement of the embedding, decoder, MTP, and loss
    layers across pipeline stages based on the recipe's pipeline parallel
    size (``pp``) and virtual pipeline parallel size (``vp``), then writes it
    to ``recipe.model.pipeline_model_parallel_layout``.

    Args:
        recipe: Config container whose ``model`` sub-config supplies
            ``pipeline_model_parallel_size``,
            ``virtual_pipeline_model_parallel_size``, and optionally
            ``mtp_num_layers``, and receives
            ``pipeline_model_parallel_layout``.

    Raises:
        ValueError: If no layout is defined for the recipe's (pp, vp)
            combination.
    """
    pp = recipe.model.pipeline_model_parallel_size
    # A vp of None means "no virtual pipelining", which is equivalent to 1.
    vp = recipe.model.virtual_pipeline_model_parallel_size or 1
    # Default to 1 MTP layer when the attribute is absent; `or 0` maps an
    # explicit None to zero MTP layers.
    # NOTE(review): confirm the fallback of 1 (rather than 0) is intended
    # when `mtp_num_layers` is missing from the model config.
    mtp_layers = getattr(recipe.model, "mtp_num_layers", 1) or 0
    # Final virtual stage ends with the MTP layers (if any) plus the loss.
    last_stage = ["mtp"] * mtp_layers + ["loss"]

    # Keyed by (pp, vp). Entries sharing the same total virtual stage count
    # (pp * vp) use the same layout: 8 stages for (8,1)/(4,2), 16 stages for
    # (16,1)/(8,2)/(4,4).
    layout_map = {
        (1, 1): None,  # single stage: no explicit layout needed
        (4, 1): [
            ["embedding"] + ["decoder"] * 16,
            ["decoder"] * 16,
            ["decoder"] * 16,
            ["decoder"] * 13 + last_stage,
        ],
        (8, 1): [["embedding"] + ["decoder"] * 8] + [["decoder"] * 8] * 6 + [["decoder"] * 5 + last_stage],
        (4, 2): [["embedding"] + ["decoder"] * 8] + [["decoder"] * 8] * 6 + [["decoder"] * 5 + last_stage],
        (16, 1): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_stage],
        (8, 2): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_stage],
        (4, 4): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_stage],
    }

    # Fail with a clear message instead of a bare KeyError on the subscript.
    if (pp, vp) not in layout_map:
        raise ValueError(
            f"No DeepSeek V3 pipeline layout defined for "
            f"pipeline_model_parallel_size={pp}, "
            f"virtual_pipeline_model_parallel_size={vp}. "
            f"Supported (pp, vp) combinations: {sorted(layout_map)}"
        )
    recipe.model.pipeline_model_parallel_layout = layout_map[(pp, vp)]
272289
273290def get_model_recipe_with_user_overrides (** kwargs ) -> ConfigContainer :
274291 """Get the model recipe with user overrides."""
@@ -284,6 +301,8 @@ def get_model_recipe_with_user_overrides(**kwargs) -> ConfigContainer:
284301 recipe = get_model_recipe (model_name , model_size , gpu , compute_dtype , domain , task )
285302 set_common_perf_overrides (recipe )
286303 set_user_overrides (recipe , kwargs )
304+ if model_name == "deepseek" and model_size == "v3" :
305+ set_deepseek_v3_layout (recipe )
287306
288307 # Scale global batch size based on the number of GPUs IF GBS is not specified by the use 0 r
289308 workload_base_config = get_workload_base_config (model_name , model_size , gpu , compute_dtype , domain , task )
0 commit comments