
Commit 4e3fda2

address comments
1 parent 0dddde8 commit 4e3fda2

7 files changed

Lines changed: 29 additions & 47 deletions


torchtitan/distributed/utils.py

Lines changed: 6 additions & 15 deletions
@@ -341,21 +341,12 @@ def _get_distributed_backend(enable_cpu_backend):
     prefix = comm_config.save_traces_file_prefix
     os.makedirs(dump_dir, exist_ok=True)
     _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")
-
-    # _ranks argument is only available in newer PyTorch versions
-    init_kwargs = {
-        "backend": _get_distributed_backend(enable_cpu_backend),
-        "timeout": timedelta(seconds=comm_config.init_timeout_seconds),
-    }
-    # Try with _ranks first (newer PyTorch), fall back without it
-    try:
-        torch.distributed.init_process_group(
-            **init_kwargs,
-            _ranks=ranks if ranks is not None else [],
-        )
-    except TypeError:
-        # Older PyTorch doesn't support _ranks
-        torch.distributed.init_process_group(**init_kwargs)
+
+    torch.distributed.init_process_group(
+        backend=_get_distributed_backend(enable_cpu_backend),
+        timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        _ranks=ranks if ranks is not None else [],
+    )
 
     return torch.distributed.get_world_size()
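
For reference, the removed fallback can still be useful as a standalone shim if a build has to run on PyTorch versions that lack the private `_ranks` keyword. A minimal sketch under that assumption (the helper name and its arguments are hypothetical; the call pattern is taken from the removed lines):

from datetime import timedelta

import torch.distributed as dist


def init_process_group_compat(backend: str, timeout_seconds: int, ranks=None) -> None:
    # Sketch of the keyword-probing pattern this hunk deletes.
    kwargs = {"backend": backend, "timeout": timedelta(seconds=timeout_seconds)}
    try:
        # Newer PyTorch accepts the private `_ranks` keyword.
        dist.init_process_group(**kwargs, _ranks=ranks if ranks is not None else [])
    except TypeError:
        # Older PyTorch rejects the unknown keyword; fall back to the plain call.
        dist.init_process_group(**kwargs)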

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 7 additions & 14 deletions
@@ -44,21 +44,8 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch._higher_order_ops.inductor_compiled_code,
 }
-# Add optional ops if available (requires newer PyTorch)
-try:
-    _op_sac_save_list.add(torch._higher_order_ops.inductor_compiled_code)
-except AttributeError:
-    pass
-
-# Add DeepEP custom ops to SAC save list
-try:
-    import torchtitan.distributed.deepep.deepep  # noqa: F401
-    _op_sac_save_list.add(torch.ops.deepep.dispatch.default)
-    _op_sac_save_list.add(torch.ops.deepep.combine.default)
-except (ImportError, AttributeError):
-    pass
-
 
 # Adapted from llama4/infra/parallelize.py
 def parallelize_deepseekv3(

@@ -115,11 +102,17 @@ def parallelize_deepseekv3(
         job_config.parallelism.expert_parallel_comm_backend == "deepep"
         and not parallel_dims.ep_enabled
     ):
+        use_deepep = False
         logger.warning(
            "expert_parallel_comm_backend='deepep' has no effect when EP=1. "
            "Using standard communication."
         )
 
+    if use_deepep:
+        import torchtitan.distributed.deepep.deepep  # noqa: F401
+        _op_sac_save_list.add(torch.ops.deepep.dispatch.default)
+        _op_sac_save_list.add(torch.ops.deepep.combine.default)
+
     # DeepEP + ETP is not supported yet
     if use_deepep and parallel_dims.etp_enabled:
         raise NotImplementedError(
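
`_op_sac_save_list` is consumed by selective activation checkpointing (SAC). A minimal sketch of how such a save list typically drives PyTorch's `create_selective_checkpoint_contexts` (the helper below is illustrative only, not the actual torchtitan wiring):

from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts


def make_sac_context_fn(save_list):
    # Ops in the save list keep their activations; everything else is recomputed
    # in backward. The returned callable is passed as `context_fn=` to
    # torch.utils.checkpoint.checkpoint.
    def policy_fn(ctx, op, *args, **kwargs):
        if op in save_list:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return lambda: create_selective_checkpoint_contexts(policy_fn)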

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
 
         # Configure expert parallel communication backend from config (defaults to "standard")
-        self.expert_parallel_comm_backend = job_config.parallelism.expert_parallel_comm_backend
+        self.moe_impl = job_config.parallelism.expert_parallel_comm_backend
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 2 additions & 4 deletions
@@ -19,7 +19,7 @@
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
 )
-from torchtitan.models.moe import FeedForward, MoE
+from torchtitan.models.moe import FeedForward, MoE, build_moe
 from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
 
@@ -350,13 +350,11 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
 
         self.moe_enabled = layer_id >= model_args.n_dense_layers
         if self.moe_enabled:
-            # Use build_moe factory to support different communication backends
-            from torchtitan.models.moe import build_moe
             self.moe = build_moe(
                 args=model_args.moe_args,
                 dim=model_args.dim,
                 hidden_dim=model_args.moe_inter_dim,
-                communication_backend=model_args.expert_parallel_comm_backend,
+                moe_impl=model_args.moe_impl,
             )
         else:
             self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim)

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 8 additions & 8 deletions
@@ -55,14 +55,6 @@
     torch._higher_order_ops.inductor_compiled_code,
 }
 
-# Add DeepEP custom ops to SAC save list
-try:
-    import torchtitan.distributed.deepep.deepep  # noqa: F401
-    _op_sac_save_list.add(torch.ops.deepep.dispatch.default)
-    _op_sac_save_list.add(torch.ops.deepep.combine.default)
-except (ImportError, AttributeError):
-    pass
-
 
 def parallelize_llama(
     model: nn.Module,

@@ -117,11 +109,19 @@ def parallelize_llama(
         job_config.parallelism.expert_parallel_comm_backend == "deepep"
         and not parallel_dims.ep_enabled
     ):
+        use_deepep = False
         logger.warning(
            "expert_parallel_comm_backend='deepep' has no effect when EP=1. "
            "Using standard communication."
         )
 
+    if use_deepep:
+        # Import deepep module to register custom ops before accessing them
+        import torchtitan.distributed.deepep  # noqa: F401 - registers torch.ops.deepep
+        _op_sac_save_list.add(torch.ops.deepep.get_dispatch_layout.default)
+        _op_sac_save_list.add(torch.ops.deepep.dispatch.default)
+        _op_sac_save_list.add(torch.ops.deepep.combine.default)
+
     # DeepEP + ETP is not supported yet
     if use_deepep and parallel_dims.etp_enabled:
         raise NotImplementedError(
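
The new `if use_deepep:` block depends on import order: `torch.ops.deepep.*` only resolves once the deepep module's op registration has run, which is why the import comes first. A standalone illustration of that mechanism with a hypothetical op namespace (not DeepEP itself; assumes a PyTorch recent enough to provide `torch.library.custom_op`):

import torch
from torch.library import custom_op


# Registering an op populates the torch.ops.<namespace> packet as a side effect;
# in the diff above that side effect happens inside the deepep import.
@custom_op("demo_ep::dispatch", mutates_args=())
def demo_dispatch(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


# Only after registration does this overload lookup succeed.
_save_list = {torch.ops.demo_ep.dispatch.default}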

torchtitan/models/moe/moe.py

Lines changed: 4 additions & 4 deletions
@@ -506,16 +506,16 @@ def init_weights(
     )
 
 
-def build_moe(args: MoEArgs, dim: int, hidden_dim: int, communication_backend: str = "standard") -> nn.Module:
+def build_moe(args: MoEArgs, dim: int, hidden_dim: int, moe_impl: str = "standard") -> nn.Module:
     """Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP).
 
     If 'deepep' is requested but DeepEP is not installed, falls back to standard with a warning.
     """
-    if communication_backend == "deepep":
+    if moe_impl == "deepep":
         try:
-            from .moe_deepep import MoEWithDeepEP
+            from .moe_deepep import DeepEPMoE
             logger.info(f"DeepEP MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}")
-            return MoEWithDeepEP(moe_args=args, dim=dim, hidden_dim=hidden_dim)
+            return DeepEPMoE(moe_args=args, dim=dim, hidden_dim=hidden_dim)
         except ImportError as e:
             logger.warning(
                 f"DeepEP requested but not available: {e}. "

torchtitan/models/moe/moe_deepep.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from .moe import MoE, MoEArgs
 
 
-class MoEWithDeepEP(MoE):
+class DeepEPMoE(MoE):
     """
     Mixture of Experts with DeepEP communication.
