Skip to content

Commit 9271a39

Browse files
committed
Support scaling runs for DeepSeek-V3 (#1501)
Signed-off-by: Sanju C Sudhakaran <[email protected]>
1 parent fd21c01 commit 9271a39

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

scripts/performance/argument_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,10 @@ def parse_cli_args():
317317
parser.add_argument(
318318
"-vp",
319319
"--virtual_pipeline_model_parallel_size",
320-
type=int,
320+
type=lambda x: None if x == "None" else int(x),
321321
help="Number of virtual blocks per pipeline model parallel rank is the virtual model parallel size.",
322322
required=False,
323-
default=None,
323+
default=-1,
324324
)
325325
parser.add_argument(
326326
"-ep",

scripts/performance/setup_experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def main(
137137
executor.container_mounts.extend([f"{megatron_ckpt_dir}:/mnt/megatron_ckpt"])
138138
logger.info(f"Custom mounts: {executor.container_mounts}")
139139

140+
vp_size = vp_size if vp_size != -1 else None
140141
exp_name = (
141142
f"{task}_{model_name}_{model_size}_{compute_dtype}"
142143
f"_gpus{num_gpus}_tp{tp_size}_pp{pp_size}_cp{cp_size}"

scripts/performance/utils/helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def set_user_overrides(recipe: ConfigContainer, kwargs: Dict[str, Any]) -> None:
240240
recipe.model.pipeline_model_parallel_size = kwargs.get("pipeline_model_parallel_size")
241241
if kwargs.get("context_parallel_size") is not None:
242242
recipe.model.context_parallel_size = kwargs.get("context_parallel_size")
243-
if kwargs.get("virtual_pipeline_model_parallel_size") is not None:
243+
if kwargs.get("virtual_pipeline_model_parallel_size") != -1:
244244
recipe.model.virtual_pipeline_model_parallel_size = kwargs.get("virtual_pipeline_model_parallel_size")
245245
if kwargs.get("expert_model_parallel_size") is not None:
246246
recipe.model.expert_model_parallel_size = kwargs.get("expert_model_parallel_size")
@@ -269,6 +269,23 @@ def set_user_overrides(recipe: ConfigContainer, kwargs: Dict[str, Any]) -> None:
269269
if hasattr(recipe, "comm_overlap") and isinstance(recipe.comm_overlap, CommOverlapConfig):
270270
recipe.comm_overlap.overlap_param_gather_with_optimizer_step = True
271271

272+
def set_deepseek_v3_layout(recipe: ConfigContainer) -> None:
273+
"""Set the DeepSeek V3 layout."""
274+
pp = recipe.model.pipeline_model_parallel_size
275+
vp = recipe.model.virtual_pipeline_model_parallel_size or 1
276+
mtp_layers = getattr(recipe.model, "mtp_num_layers", 1) or 0
277+
last_layer = ["mtp"] * mtp_layers + ["loss"]
278+
279+
layout_map = {
280+
(1, 1): None,
281+
(4, 1): [["embedding"] + ["decoder"] * 16, ["decoder"] * 16, ["decoder"] * 16, ["decoder"] * 13 + last_layer],
282+
(8, 1): [["embedding"] + ["decoder"] * 8] + [["decoder"] * 8] * 6 + [["decoder"] * 5 + last_layer],
283+
(4, 2): [["embedding"] + ["decoder"] * 8] + [["decoder"] * 8] * 6 + [["decoder"] * 5 + last_layer],
284+
(16, 1): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_layer],
285+
(8, 2): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_layer],
286+
(4, 4): [["embedding"] + ["decoder"] * 4] + [["decoder"] * 4] * 14 + [["decoder"] + last_layer],
287+
}
288+
recipe.model.pipeline_model_parallel_layout = layout_map[(pp, vp)]
272289

273290
def get_model_recipe_with_user_overrides(**kwargs) -> ConfigContainer:
274291
"""Get the model recipe with user overrides."""
@@ -284,6 +301,8 @@ def get_model_recipe_with_user_overrides(**kwargs) -> ConfigContainer:
284301
recipe = get_model_recipe(model_name, model_size, gpu, compute_dtype, domain, task)
285302
set_common_perf_overrides(recipe)
286303
set_user_overrides(recipe, kwargs)
304+
if model_name == "deepseek" and model_size == "v3":
305+
set_deepseek_v3_layout(recipe)
287306

288307
# Scale global batch size based on the number of GPUs IF GBS is not specified by the use 0 r
289308
workload_base_config = get_workload_base_config(model_name, model_size, gpu, compute_dtype, domain, task)

0 commit comments

Comments (0)