
Commit 4a65d60

revise per discussion round 2, add SAC and optional padding
1 parent 1802b49 commit 4a65d60


16 files changed: +724 -725 lines changed


tests/integration_tests/models.py

Lines changed: 0 additions & 66 deletions
@@ -64,37 +64,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
             "deepseek_v3_pp+fsdp+tp+ep+etp",
             ngpu=8,
         ),
-        # Integration Test Cases for DeepSeek V3 with DeepEP
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name deepseek_v3",
-                    "--parallelism.data_parallel_shard_degree 4",
-                    "--parallelism.expert_parallel_degree 2",
-                    "--parallelism.moe_comm_backend deep_ep",
-                ],
-            ],
-            "DeepSeek V3 FSDP+EP+DeepEP",
-            "deepseek_v3_fsdp+ep+deepep",
-            ngpu=4,
-        ),
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name deepseek_v3",
-                    "--parallelism.pipeline_parallel_degree 2",
-                    "--parallelism.pipeline_parallel_schedule Interleaved1F1B",
-                    "--parallelism.data_parallel_shard_degree 2",
-                    "--parallelism.tensor_parallel_degree 2",
-                    "--parallelism.expert_parallel_degree 4",
-                    "--parallelism.expert_tensor_parallel_degree 1",
-                    "--parallelism.moe_comm_backend deep_ep",
-                ],
-            ],
-            "DeepSeek V3 PP+FSDP+TP+EP+DeepEP",
-            "deepseek_v3_pp+fsdp+tp+ep+deepep",
-            ngpu=8,
-        ),
         # Integration Test Cases for Qwen3 dense and MoE model
         OverrideDefinitions(
             [
@@ -123,23 +92,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
             "qwen3_fsdp+tp+ep+etp",
             ngpu=4,
         ),
-        # Integration Test Cases for Qwen3 with DeepEP
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name qwen3",
-                    "--model.flavor debugmodel_moe",
-                    "--parallelism.data_parallel_shard_degree 2",
-                    "--parallelism.tensor_parallel_degree 2",
-                    "--parallelism.expert_parallel_degree 2",
-                    "--parallelism.expert_tensor_parallel_degree 2",
-                    "--parallelism.moe_comm_backend deep_ep",
-                ],
-            ],
-            "Qwen3 FSDP+TP+EP+ETP+DeepEP",
-            "qwen3_fsdp+tp+ep+etp+deepep",
-            ngpu=4,
-        ),
         # Integration Test Cases for Llama 4
         OverrideDefinitions(
             [
@@ -158,24 +110,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
             "llama4_pp+fsdp+tp+ep+compile",
             ngpu=8,
         ),
-        # Integration Test Cases for Llama 4 with DeepEP
-        OverrideDefinitions(
-            [
-                [
-                    "--model.name llama4",
-                    "--parallelism.pipeline_parallel_degree 2",
-                    "--parallelism.pipeline_parallel_schedule Interleaved1F1B",
-                    "--parallelism.data_parallel_shard_degree 2",
-                    "--parallelism.tensor_parallel_degree 2",
-                    "--parallelism.expert_parallel_degree 4",
-                    "--parallelism.expert_tensor_parallel_degree 1",
-                    "--parallelism.moe_comm_backend deep_ep",
-                ],
-            ],
-            "Llama 4 PP+FSDP+TP+EP+DeepEP",
-            "llama4_pp+fsdp+tp+ep+deepep",
-            ngpu=8,
-        ),
     ]
 
     return model_tests
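The removed cases all used the old `--parallelism.moe_comm_backend deep_ep` override, which this commit renames to `expert_parallel_comm_backend` with the value `deepep` (see `job_config.py` below). As a hypothetical sketch only, an equivalent test entry under the renamed flag would look like the following; the entry shape is copied from the deleted "DeepSeek V3 FSDP+EP+DeepEP" case, and only the flag is updated:

# Hypothetical re-add under the renamed flag; not part of this commit.
OverrideDefinitions(
    [
        [
            "--model.name deepseek_v3",
            "--parallelism.data_parallel_shard_degree 4",
            "--parallelism.expert_parallel_degree 2",
            "--parallelism.expert_parallel_comm_backend deepep",
        ],
    ],
    "DeepSeek V3 FSDP+EP+DeepEP",
    "deepseek_v3_fsdp+ep+deepep",
    ngpu=4,
)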

torchtitan/config/job_config.py

Lines changed: 12 additions & 3 deletions
@@ -416,17 +416,26 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
-    moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
+    expert_parallel_comm_backend: Literal["standard", "deepep"] = "standard"
     """
-    MoE expert-parallel communication backend. No effect for non-MoE models or when ep = 1.
+    Expert-parallel communication backend. No effect for non-MoE models or when ep = 1.
 
     - "standard": Uses PyTorch all-to-all collectives (default)
-    - "deep_ep": Uses DeepEP custom kernels for more efficient communication
+    - "deepep": Uses DeepEP custom kernels for more efficient communication
 
     DeepEP requires installation:
     https://github.com/deepseek-ai/DeepEP.
     """
 
+    deepep_use_alignment_padding: bool = False
+    """
+    Whether to use alignment padding for DeepEP token dispatch.
+    Only applies when expert_parallel_comm_backend="deepep".
+
+    Recommended for large models (671B+) where the padding overhead is
+    amortized over more compute. May cause slowdown for smaller models.
+    """
+
 
 @dataclass
 class Checkpoint:
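For intuition on the new `deepep_use_alignment_padding` knob: alignment padding typically rounds each expert's token count up to a fixed multiple so that the per-expert (grouped) GEMMs run on aligned tile shapes; the padded slots are wasted compute, which is why the docstring recommends it only for large models. A minimal sketch of that rounding, with an assumed helper name and alignment value (the actual DeepEP logic is not shown in this commit):

# Illustrative sketch only -- pad_tokens_per_expert and alignment=128 are
# assumptions, not the repo's implementation. Rounds per-expert token
# counts up to a multiple of `alignment` for aligned GEMM tile shapes.
def pad_tokens_per_expert(tokens_per_expert: list[int], alignment: int = 128) -> list[int]:
    return [(n + alignment - 1) // alignment * alignment for n in tokens_per_expert]

# e.g. pad_tokens_per_expert([5, 130, 0]) -> [128, 256, 0]; the gap between
# real and padded counts is the overhead the docstring warns about.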

torchtitan/distributed/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -13,15 +13,13 @@
 from torch.distributed.tensor.placement_types import Placement
 
 from torchtitan.distributed.parallel_dims import ParallelDims
-from torchtitan.distributed.expert_parallel import ExpertParallelDeepEP
-from torchtitan.distributed.deepep import MoEFlexTokenDispatcher
+from torchtitan.distributed.expert_parallel import DeepEPExpertParallel
 
 
 __all__ = [
     "ParallelDims",
     "NoParallel",
-    "MoEFlexTokenDispatcher",
-    "ExpertParallelDeepEP",
+    "DeepEPExpertParallel",
 ]
 
 
torchtitan/distributed/deepep/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -6,9 +6,14 @@
 
 """DeepEP distributed communication primitives for MoE."""
 
-from .flex_dispatcher import MoEFlexTokenDispatcher
+from .deepep import (
+    dispatch_tokens,
+    combine_tokens,
+    DispatchState,
+)
 
 __all__ = [
-    "MoEFlexTokenDispatcher",
+    "dispatch_tokens",
+    "combine_tokens",
+    "DispatchState",
 ]
-
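The class-based `MoEFlexTokenDispatcher` gives way to a functional dispatch/combine API, with an explicit `DispatchState` carrying whatever the combine step needs to reverse the dispatch. The signatures live in `deepep.py`, which this hunk doesn't show, so the round trip below is an assumed call pattern, not the real interface:

# Assumed call pattern -- the argument lists are guesses; only the names
# dispatch_tokens / combine_tokens / DispatchState come from this commit.
import torch
from torchtitan.distributed.deepep import combine_tokens, dispatch_tokens

def moe_forward(tokens: torch.Tensor, topk_ids: torch.Tensor, experts) -> torch.Tensor:
    # All-to-all each token to the rank(s) owning its selected experts;
    # the returned state records the permutation so it can be undone.
    dispatched, state = dispatch_tokens(tokens, topk_ids)
    expert_out = experts(dispatched)          # run the local experts
    return combine_tokens(expert_out, state)  # reverse the dispatch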
