Skip to content

Conversation

@elfiegg
Copy link

@elfiegg elfiegg commented Dec 4, 2025

Initial version to integrate DeepEP to torchtitan - ensures functionality and is compatible with torch.compile and SAC.
Currently construct DeepEP MoE layer and its EP parallelism via a user controlled config.

Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for internal supports!

@meta-cla
Copy link

meta-cla bot commented Dec 4, 2025

Hi @elfiegg!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@elfiegg elfiegg force-pushed the loss_bug branch 7 times, most recently from a5875e5 to 6999d1e Compare December 4, 2025 05:13
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for contributing!

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what is the baseline?

Btw, I think to fully utilize the power of DeepEP, we also need to have node-limited routing, which the current torchtitan DSv3 model doesn't have.

@shuhuayu let's add it? we can refer to HF or deepseek original impl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g. build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure that's a great idea! - once I confirm this works for larger models and improves perf

Regarding integrating directly to main - do we need to manage DeepEP dependency at all or we leave it to the users to install?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I prefer

we leave it to the users to install

instead of bundling it by default. We can explicitly mention this in try-catch when we do the import.

from torch.distributed.tensor import distribute_module


class DeepEPExpertParallel(ParallelStyle):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used anywhere? I'm guessing that this is not running e2e with torchtitan train.py which is still WIP.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can also handle the token communication in here -- the design would look cleaner and more aligned with existing interface.

This is our implementation and our deepep interface is similar to Nvidia's version.

https://github.com/AMD-AGI/torchtitan-amd/blob/59fe226a46d76a553d4e591411a464390475be02/torchtitan/distributed/expert_parallel.py#L441

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pr! I think we should support node-limited routing to make multi-node setup faster.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the node-limited routing here: #2111. Perhaps it helps make deepep faster in multi-node setups.

@elfiegg
Copy link
Author

elfiegg commented Dec 4, 2025

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what is the baseline?

Actually I have rerun this last night and the perf caught up - the lagging perf was gone once I enabled FSDP for MoE layer (which I disabled for debugging purpose). Running below command, I got 13% MFU for both baseline and DeepEP version

torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$NPROC_PER_NODE \
    --rdzv_id=deepseek_16b_multinode \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    -m torchtitan.train \
   --parallelism.expert_parallel_degree 16 \
    --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Baseline I referred to the config here: ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
And DeepEP version is to override --model.name=deepep.deepseek_v3

@yuankaichen-amd
Copy link

Thanks for posting the work!

We had a successful and performant DeepEP integration at: AMD-AGI@59fe226

We borrowed some design from Megatron-LM and we can use it here too.

I don't see big differences between our DeepEP interface and yours. Let's work together on this. Feel free to reach out to me or Tianyu for future discussion and collaboration.

from torch.distributed.tensor import distribute_module


class DeepEPExpertParallel(ParallelStyle):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can also handle the token communication in here -- the design would look cleaner and more aligned with existing interface.

This is our implementation and our deepep interface is similar to Nvidia's version.

https://github.com/AMD-AGI/torchtitan-amd/blob/59fe226a46d76a553d4e591411a464390475be02/torchtitan/distributed/expert_parallel.py#L441


if self.score_before_experts:
recv_x_sorted = (recv_x_sorted.to(torch.float32) * valid_weights_sorted.unsqueeze(-1)).to(recv_x_sorted.dtype)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code before experts.forward should go to DeepEPExpertParallel as input_fn, same for the token unpermute after experts as output_fn

Also consider using something like _indices_to_multihot_kernel (https://github.com/NVIDIA/Megatron-LM/blob/f5344166732f45bb0dd825dc875288ea97b15b47/megatron/core/fusions/fused_indices_converter.py#L32C5-L32C32) to preprocess received DeepEP data.

You are using a lot of index-selecting here which I suspect would incur significant CPU overhead (and lock/wait among CPU threads)

@elfiegg
Copy link
Author

elfiegg commented Dec 5, 2025

Thanks all for the valuable advice! - I'm currently occupied by a deadline but I will take a closer look and join the discussion tomorrow

@elfiegg
Copy link
Author

elfiegg commented Dec 5, 2025

I scanned through the comments, and here is a summary:

  1. We prefer wrapping DeepEP dispatch and combine logic into ExpertParallel module for clear injections
  2. We prefer a a factory method to build MoE module based on a configurable string - depending on user's choice and/or container environment
  3. We prefer lowering the fusible torch naive ops like indexing to written triton kernels for efficiency
  4. We prefer integrating directly to non-experimental codebase

If this looks good to everyone, I'll start revising the PR
cc @tianyu-l @yuankaichen-amd @shuhuayu

@tianyu-l
Copy link
Contributor

tianyu-l commented Dec 5, 2025

@elfiegg sounds good overall

  1. We prefer lowering the fusible torch naive ops like indexing to written triton kernels for efficiency

I think this may be done in a followup PR, assuming that the tradeoff can be justified by benchmarking results. cc @yuankaichen-amd WDYT?

@yuankaichen-amd
Copy link

@elfiegg sounds good overall

  1. We prefer lowering the fusible torch naive ops like indexing to written triton kernels for efficiency

I think this may be done in a followup PR, assuming that the tradeoff can be justified by benchmarking results. cc @yuankaichen-amd WDYT?

Either way works -- it should be a low-hanging fruit. The triton kernel is available in both Megatron and my integration PR.

@yuankaichen-amd
Copy link

I scanned through the comments, and here is a summary:

  1. We prefer wrapping DeepEP dispatch and combine logic into ExpertParallel module for clear injections
  2. We prefer a a factory method to build MoE module based on a configurable string - depending on user's choice and/or container environment
  3. We prefer lowering the fusible torch naive ops like indexing to written triton kernels for efficiency
  4. We prefer integrating directly to non-experimental codebase

If this looks good to everyone, I'll start revising the PR cc @tianyu-l @yuankaichen-amd @shuhuayu

I strongly recommend wrapping DeepEP related ops into a standalone class:
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py#L52

There are six methods:
pre_dispatch / dispatch / post_dispatch
pre_combine / combine / post_combine

It will set a clear boundary between DeepEP and Torchtitan's MoE module or ExpertParallel wrapper. Also you can get a free ride for many things that Nvidia has already implemented in Megatron.

@elfiegg elfiegg force-pushed the loss_bug branch 2 times, most recently from 0a9815b to e0d4fcf Compare December 6, 2025 19:03
Comment on lines 168 to 170
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
x, selected_experts_indices, top_scores
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This version looks much cleaner.

After reading the code, I wonder what's the best way to organize code. Brainstorming with some (immature) ideas:

  1. Use if use_deepep in the code for the region of difference. @shuhuayu IIUC you were having this idea?
  2. Abstract token_dispatching + routed_experts computation into its own classes, so that the MoE class can be shared.
  3. Moving dispatch_preprocess and dispatch_postprocess also inside ExpertParallel hooks. The challenges seems that ExpertParallel classes are not getting all the inputs we need.
  4. Unify dispatch_preprocess with the TokenReorderer concept needed for the non-deepep impl. The challenge is similar to 3 in that the interfaces do not really align.

@yuankaichen-amd would love to hear your thoughts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can revise the existing MoE modules by adding a if use_deepep branch, or finding a way to inherit this MoE module if possible.

Copy link

@yuankaichen-amd yuankaichen-amd Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we compare the existing MoE against the DeepEP MoE:

self.router(x, self.expert_bias)  

self.reorderer(top_scores, selected_experts_indices)

#### many lines omitted

# shape (bs*slen*top_k, dim)

routed_output = self.experts(routed_input, num_tokens_per_expert)

===============================================

self.router(x, self.expert_bias)

x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
       x, selected_experts_indices, top_scores
)

routed_output = self.experts(x_prep, num_tokens_per_expert)

For the old MoE class, what happens between router and experts is also a kind of preprocess.

I think we can have an "AlltoallTokenDispatcher" which modularizes these operations. So in the combined MoE implementation, we will have:

self.router(x, self.expert_bias)

some_token_dispatcher.preprocess(x, selected_experts_indices, top_scores)

routed_output = self.experts(routed_input, ...)

With this, we can even combine the ExpertParallel with DeepExerptParallel, where _token_dispatch and _token_combine will be interfaced with a token dispatcher directly.

Copy link
Contributor

@tianyu-l tianyu-l Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yuankaichen-amd
Sounds nice in general, I have some concerns about the following:

  • Now token_dispatcher is a submodule of MoE but we only call token_dispatcher.preprocess in model code, and delay the actual dispatch / combine into hooks, which doesn't sound natural.
  • The benefit of using hooks was that single-device code is still correct, and we can apply EP on top of single-device code. If we do the DeepEP path which has a different preprocess method, would the single-device code EP=1 + DeepEP enabled still be "correct"?

Based on the above points, should we move towards dispatch / combine being actually called "in the model code", any be of no-op when EP is not enabled?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got your concerns. By "dispatch / combine being actually called "in the model code"", are you suggesting that with this token_dispatcher design, we should also retire the ExpertParallel wrapper?

Copy link
Contributor

@shuhuayu shuhuayu Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx for the clarification! Actually now I am thinking that we can inherit from current MoE to create a new DeepEPMoE since they share most code, and use the current build_moe function to build it if DeepEP is used. So basically, we use separate MoEs (MoE and DeepEPMoE, and also separate ExpertParallels (ExpertParallel and DeepEPExpertParallel) to integrate DeepEP into TorchTitan so the main interface is preserved. We can create a folder under distributed named expert_parallel and put the main class of DeepEPExpertparall into existing expert_parallel.py and other supporting classes and functions into another deep_ep.py, both .py files are under the new expert_parallel folder. For the preprocessing and postprocessing functions for token_dispatch or token_combine (in DeepEPExpertParallel), we can add private methods into the same classes if necessary. WDYT?

Copy link
Author

@elfiegg elfiegg Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be similar to the current implementation - either way works for me, abstracting common API or inheriting the current modules. One benefit we get out of box by not separating ExpertParallel is ETP seems to work out of box, also as @yuankaichen-amd pointed out, comm-comp overlap might also work along.
No objection to build_moe and DeeoEPMoE - either way we need to condition the logic at some level. Separating them might be beneficial if more comm libraries come up and complicate MoE class

Copy link
Contributor

@tianyu-l tianyu-l Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comm-comp overlap might also work along.

@yuankaichen-amd could you share more details about the overlapping and how having the same EP class would help? I discussed with @shuhuayu offline and we are not sure what's the benefit of separating into six methods, namely [dispatch, combine] x [preprocess, process, postprocess].

My take is that ExpertParallel class is implicitly a token dispatcher, it's just applied in a wrapper from outside the model.
Explicitly creating another token dispatcher (for expert parallel comms), and let the implicit one (ExpertParallel) access the explicit token dispatcher class indirectly via has_attr call sounds not very straightforward.

At this moment, I'm leaning towards having two MoE classes

  • letting the DeepEPMoE inheriting the base MoE one and only overwrites the forward function
  • Having another DeepEPExpertParallel class which inherits a BaseExpertParallel protocol class.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a discussion with @tianyu-l and @shuhuayu offline. I agree that
(1) having two separate MoE classes may work best for now (@elfiegg's original design);
(2) token_dispatcher now looks unnecessary (sorry for suggesting it in the first place) and we can directly invoke DeepEP related methods in the DeepEPExpertParallel;
(3) DeepEP needs some additional inputs, we can do this by some additional attributes in DeepEPMoE module, or add a manager subclass if needed. Since @elfiegg has already had DeepEp manager implemented, I'd suggest let's leave it as it is for now. I will also need to take a closer look at AMD's DeepEP's different modes. Let's review this design later.

@elfiegg what do you think?

@tianyu-l @shuhuayu Please add if I missed anything.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM - will send changes by EoD


# Setup dispatcher metadata (routing information) for hooks to use
# The hooks will call token_dispatch/token_combine which need this metadata
x_prep, probs_prep = self.deepep_dispatcher.dispatch_preprocess(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would preprocess and postprocess do if we don't use EP, e.g. single-device computation -- would it be no-op?

Copy link
Author

@elfiegg elfiegg Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to fall back to standard impl if (no EP && use DeepEP)

from torchtitan.tools.logging import logger


class MoEWithDeepEP(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we inherit existing MoE to reuse most code except for MoEFlexTokenDispatcher init and the forward call?

dim=model_args.dim,
hidden_dim=model_args.moe_inter_dim,
communication_backend=model_args.moe_comm_backend,
score_before_experts=model_args.moe_args.score_before_experts,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is already available in first arg

Note that this is still an experimental feature.
"""

moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe follow the convention and call it

Suggested change
moe_comm_backend: Literal["standard", "deep_ep"] = "standard"
expert_parallel_comm_backend: Literal["standard", "deep_ep"] = "standard"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

# Select parallelism style based on use_deepep flag
if use_deepep:
from torchtitan.distributed import ExpertParallelDeepEP
from torchtitan.tools.logging import logger as parallelism_logger
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right now it's the same the logger in this class

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 15 to 16
"MoEWithDeepEP",
"MoEFlexTokenDispatcher",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for now maybe let's not expose them here, at the cost of using HAS_DEEPEP everywhere

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG, done

from deep_ep.utils import EventOverlap, EventHandle
HAS_DEEPEP = True
except ImportError:
HAS_DEEPEP = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of testing HAS_DEEPEP at multiple locations, I wonder if we can just error out here and let the callsites be careful about importing this file.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Now only build_moe util in moe.py deal with HAS_DEEPEP - and fall back if HAS_DEEPEP is false

)
maybe_enable_async_tp(job_config, world_mesh["tp"])

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does DeepEP work with TP? If not we should error out when are enabled together.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

from torchtitan.distributed import ExpertParallelDeepEP
from torchtitan.tools.logging import logger as parallelism_logger
experts_plan = ExpertParallelDeepEP()
parallelism_logger.info(f" Applying DeepEP to MoE layer")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add a general logger as "Applied Expert Parallel with xxx comm backend"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

)


class ExpertParallelDeepEP(ExpertParallel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
class ExpertParallelDeepEP(ExpertParallel):
class DeepEPExpertParallel(ExpertParallel):

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

ep_group = device_mesh.get_group()
routed_input, routed_prob = mod.deepep_dispatcher.token_dispatch(routed_input, ep_group)
routed_input, num_tokens_per_expert, routed_prob = mod.deepep_dispatcher.dispatch_postprocess(routed_input, None)
return routed_input, num_tokens_per_expert
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a place where padding is done so that each expert is always getting a multiple of 8 / 16 tokens (required by torch._group_mm), similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L129-L136

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the alignment requirements on M dim isn't about functionality (There should be functionality alignment requirement on contracting dim K though, but DeepSeek moe intermediate size is for sure multiples of 16bytes) - And I found padding doesn't improve performance either, so added a configurable pad_to_alignment to let user choose

logger.info(f"Allocated fallback RDMA buffer: {num_rdma_bytes} bytes")

low_latency_mode = is_multinode or group.size() > 8
buffer = Buffer(group=group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=low_latency_mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this Buffer for? It seems it's not proportional to the num of tokens.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides input tokens and output tokens, DeepEP needs to initialize a symmetric buffers (same address on all ranks) for chunked RDMA/NVLink comm; Since the output size is unknown until CPU sync, it takes pre-configured bytes

Comment on lines 142 to 144
num_recv_tokens_per_expert_tensor = torch.tensor(
num_recv_tokens_per_expert_list, dtype=torch.int64, device='cpu'
).to(recv_x.device, non_blocking=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious what this is doing? I know in "standard" impl we are doing D2H sync, but this seems H2D sync?

Copy link
Author

@elfiegg elfiegg Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great checking! c:
It's not moving any mem between host and device; it attempts to convert a python list to a tensor on CPU side, which to(recv_x.device) isn't necessary at all!

previous_event = _create_event_if_async(async_finish)

recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, after_event = \
buffer.dispatch(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how activation checkpointing is done. We need to save the forward comm result so that backward doesn't do the comms again.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created torch custom op for DeepEP so that we can work with SAC. Without custom op (manual caching) SAC would track tensors and assert the total number of created tensors ain't aligned to registry

return routed_input, num_tokens_per_expert

@staticmethod
def _partition_fn(name, mod, device_mesh):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't have to repeat this function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a BaseExpertParallel as abstract class, thus preserving _partition and _apply funcs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might be repeated work in #1721

routed_output = mod.deepep_dispatcher.token_combine(routed_output, ep_group)
return routed_output

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly, we don't have to have this function?

@elfiegg elfiegg changed the title Integrate DeepEP to experimental torchtitan Integrate DeepEP to torchtitan Dec 8, 2025
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this if we implemented it outside of experiments?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

Comment on lines 101 to 105
# Allow use_flex_attn to be set from config
if hasattr(job_config.model, 'use_flex_attn') and job_config.model.use_flex_attn is not None:
self.use_flex_attn = job_config.model.use_flex_attn
logger.info(f"Setting use_flex_attn={self.use_flex_attn} from config")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the the current version support flex attention? Is this block still needed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

@elfiegg
Copy link
Author

elfiegg commented Dec 9, 2025

Thanks for the feedback! A summary of our discussion:

  1. Design direction (API unification):
    move dispatch pre/post processing into ExpertParallel hooks, and extract the standard MoE token_reorder logic into a preprocess stage
    Let both DeepEP and the standard MoE initialize a common Dispatcher (subclassing an abstract AllToAllDispatcher). -
    This further isolates parallelism logic from expert computation and unifies the APIs across standard and library-specific paths

  2. Build-time selection (DeepEP vs standard):
    Introduce a single top-level HAS_DEEPEP switch to control module selection in build_moe, and enforce consistent failure behavior in the underlying implementations when the configuration is unsupported

  3. TP + EP restriction
    Not aware about any restriction, thought ETP would orthogonally handle EP and TP mesh(so DeepEP all-to-all still applies):

    class ExpertTensorParallel(ExpertParallel):
    def _token_dispatch(self, mod, inputs, device_mesh):
    routed_input, num_tokens_per_expert = inputs
    # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
    # The grad_placements on inputs is set to Partial so that necessary
    # reductions are performed during backward.
    routed_input = DTensor.from_local(
    routed_input, device_mesh["tp"], (Replicate(),)
    ).to_local(grad_placements=(Partial(),))
    inputs = (routed_input, num_tokens_per_expert)
    # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
    return super()._token_dispatch(mod, inputs, device_mesh["ep"])

    After we modify the design to wrap token dispatching logic in ExpertParallel, I assume this should automatically work?

  4. Additional suggestions
    (a) Add a more general logging/logger
    (b) Pad to multiple of 8 tokens for BF16 due to the CUTLASS 16 Bytes alignment requirement
    (c) Add SAC support for saving forward communication results for backward
    (d) Inherit standard MoE and reduce code duplication
    (d) Perform general code cleanup across the module

@yuankaichen-amd
Copy link

Thanks for the feedback! A summary of our discussion:

  1. Design direction (API unification):
    move dispatch pre/post processing into ExpertParallel hooks, and extract the standard MoE token_reorder logic into a preprocess stage
    Let both DeepEP and the standard MoE initialize a common Dispatcher (subclassing an abstract AllToAllDispatcher). -
    This further isolates parallelism logic from expert computation and unifies the APIs across standard and library-specific paths

  2. Build-time selection (DeepEP vs standard):
    Introduce a single top-level HAS_DEEPEP switch to control module selection in build_moe, and enforce consistent failure behavior in the underlying implementations when the configuration is unsupported

  3. TP + EP restriction
    Not aware about any restriction, thought ETP would orthogonally handle EP and TP mesh(so DeepEP all-to-all still applies):

    class ExpertTensorParallel(ExpertParallel):
    def _token_dispatch(self, mod, inputs, device_mesh):
    routed_input, num_tokens_per_expert = inputs
    # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
    # The grad_placements on inputs is set to Partial so that necessary
    # reductions are performed during backward.
    routed_input = DTensor.from_local(
    routed_input, device_mesh["tp"], (Replicate(),)
    ).to_local(grad_placements=(Partial(),))
    inputs = (routed_input, num_tokens_per_expert)
    # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
    return super()._token_dispatch(mod, inputs, device_mesh["ep"])

    After we modify the design to wrap token dispatching logic in ExpertParallel, I assume this should automatically work?

  4. Additional suggestions
    (a) Add a more general logging/logger framework
    (b) Pad to multiple of 8 tokens for BF16 due to the CUTLASS 16 Bytes alignment requirement
    (c) Add SAC support for saving forward communication results for backward
    (d) Inherit standard MoE and reduce code duplication
    (d) Perform general code cleanup across the module

Thanks for the feedback! A summary of our discussion:

  1. Design direction (API unification):
    move dispatch pre/post processing into ExpertParallel hooks, and extract the standard MoE token_reorder logic into a preprocess stage
    Let both DeepEP and the standard MoE initialize a common Dispatcher (subclassing an abstract AllToAllDispatcher). -
    This further isolates parallelism logic from expert computation and unifies the APIs across standard and library-specific paths

  2. Build-time selection (DeepEP vs standard):
    Introduce a single top-level HAS_DEEPEP switch to control module selection in build_moe, and enforce consistent failure behavior in the underlying implementations when the configuration is unsupported

  3. TP + EP restriction
    Not aware about any restriction, thought ETP would orthogonally handle EP and TP mesh(so DeepEP all-to-all still applies):

    class ExpertTensorParallel(ExpertParallel):
    def _token_dispatch(self, mod, inputs, device_mesh):
    routed_input, num_tokens_per_expert = inputs
    # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors.
    # The grad_placements on inputs is set to Partial so that necessary
    # reductions are performed during backward.
    routed_input = DTensor.from_local(
    routed_input, device_mesh["tp"], (Replicate(),)
    ).to_local(grad_placements=(Partial(),))
    inputs = (routed_input, num_tokens_per_expert)
    # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
    return super()._token_dispatch(mod, inputs, device_mesh["ep"])

    After we modify the design to wrap token dispatching logic in ExpertParallel, I assume this should automatically work?

  4. Additional suggestions
    (a) Add a more general logging/logger framework
    (b) Pad to multiple of 8 tokens for BF16 due to the CUTLASS 16 Bytes alignment requirement
    (c) Add SAC support for saving forward communication results for backward
    (d) Inherit standard MoE and reduce code duplication
    (d) Perform general code cleanup across the module

Thanks for the summary! It looks good to me in general. Just one suggestion -- I think we could easily have a unified MoE module once we have (1). So build_moe may retire if this works.

@elfiegg
Copy link
Author

elfiegg commented Dec 9, 2025

Thanks for the summary! It looks good to me in general. Just one suggestion -- I think we could easily have a unified MoE module once we have (1). So build_moe may retire if this works.

Agree, we can then construct the right Dispatcher based on the env - and be very compatible with the current flow

@elfiegg elfiegg force-pushed the loss_bug branch 2 times, most recently from a305ed7 to 804ccb3 Compare December 11, 2025 08:29
@yuankaichen-amd
Copy link

Thanks @elfiegg. It looks good to me in general, I only have a minor question -- why don't we also make the Deepep handle as part of the State but use a separate cache?

Let's merge this PR and I will follow up on the performance optimization and AMD compatibility.

@elfiegg
Copy link
Author

elfiegg commented Dec 11, 2025

Thanks @elfiegg. It looks good to me in general, I only have a minor question -- why don't we also make the Deepep handle as part of the State but use a separate cache?

Let's merge this PR and I will follow up on the performance optimization and AMD compatibility.

Yeah I was bothered by this last night too - so torch.library doesn't allow returning arbitrary Python object like deepep's Handle. And SAC doesn't track/save primitives like int/float so I had to return a cache_id that's wrapped in a tensor, and use that as cache identifier to retrieve handle

@elfiegg elfiegg marked this pull request as ready for review December 11, 2025 21:37
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 11, 2025
@meta-cla
Copy link

meta-cla bot commented Dec 11, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

https://github.com/deepseek-ai/DeepEP.
"""

deepep_use_alignment_padding: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is optional? IIUC this is must in order to use torch._grouped_mm.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set to optional because torch._grouped_gemm uses cutlass grouped gemm underlying - and cutlass's grouped gemm 16 Bytes alignment requirement is on contiguous dimension, let's say A[M,K] (1, 0) - it'd be on contiguous dimension K:

https://github.com/NVIDIA/cutlass/blob/d3a5492381a457e59a1fd27d97bb87c7ca95ee6e/include/cutlass/gemm/device/gemm_array.h#L344-L363

Technically in our case, moe_intermediate_size=2048 and hidden_size=7168 are already multiples of 8 elements. I have run without padding and it worked. Still leave it an option here in case other underlying library might require padding on M dim.

__all__ = [
"ParallelDims",
"NoParallel",
"DeepEPExpertParallel",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not expose this here for now

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeepEPExpertParallel needs to be used in parallelize.py files, are you suggesting we don't integrate to models for now?

timeout=timedelta(seconds=comm_config.init_timeout_seconds),
_ranks=ranks if ranks is not None else [],
)
# _ranks argument is only available in newer PyTorch versions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchtitan is supposed to be used with latest pytorch, so let's revert this change

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I have reverted them. But with the torch nightly build, I got an error:

[rank0]:[rank0]:     self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
[rank0]:[rank0]: RuntimeError: set_stride is not allowed on a Tensor created from .data or .detach().
[rank0]:[rank0]: If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
[rank0]:[rank0]: without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
[rank0]:[rank0]: For example, change:
[rank0]:[rank0]:     x.data.set_(y)
[rank0]:[rank0]: to:
[rank0]:[rank0]:     with torch.no_grad():
[rank0]:[rank0]:         x.set_(y)

And using stable release 2.9.1 torch version, everything is fine. Do you have any suggestions?

torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}
# Add optional ops if available (requires newer PyTorch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

try:
import torchtitan.distributed.deepep.deepep # noqa: F401
_op_sac_save_list.add(torch.ops.deepep.dispatch.default)
_op_sac_save_list.add(torch.ops.deepep.combine.default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add it later in the function where this is called? I think we should consistently use the config to toggle these fields.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

tokens: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
positions: torch.Tensor | None = None,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this field for?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hit a bug where return_output arg is passed in but deepseekV3 can't handle...(it defaults to False for PP anyways though)

args=model_args.moe_args,
dim=model_args.dim,
hidden_dim=model_args.moe_inter_dim,
communication_backend=model_args.expert_parallel_comm_backend,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
communication_backend=model_args.expert_parallel_comm_backend,
moe_impl=model_args.moe_impl,

from .moe import MoE, MoEArgs


class MoEWithDeepEP(MoE):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
class MoEWithDeepEP(MoE):
class DeepEPMoE(MoE):

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

and not parallel_dims.ep_enabled
):
logger.warning(
"expert_parallel_comm_backend='deepep' has no effect when EP=1. "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it just fail, because we will be missing the reordering step

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added use_deepep=False to fall back to standard MoE

return routed_input, num_tokens_per_expert

@staticmethod
def _partition_fn(name, mod, device_mesh):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might be repeated work in #1721

@elfiegg elfiegg force-pushed the loss_bug branch 2 times, most recently from ddbdc79 to 6f66f39 Compare December 12, 2025 21:29
Copy link
Author

@elfiegg elfiegg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One issue I had when testing this with latest torch is - it fails due to:

[rank7]:[rank7]:   File "/lustre/fsw/portfolios/sw/projects/sw_aidot/users/elfieg/torchtitan/torchtitan/distributed/deepep/deepep.py", line 368, in dispatch_tokens
[rank7]:[rank7]:     buffer.get_dispatch_layout(topk_idx=selected_experts_indices, num_experts=num_experts)
[rank7]:[rank7]:   File "/lustre/fsw/portfolios/sw/projects/sw_aidot/users/elfieg/DeepEP/deep_ep/buffer.py", line 317, in get_dispatch_layout
[rank7]:[rank7]:     self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
[rank7]:[rank7]: RuntimeError: set_stride is not allowed on a Tensor created from .data or .detach().
[rank7]:[rank7]: If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
[rank7]:[rank7]: without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
[rank7]:[rank7]: For example, change:
[rank7]:[rank7]:     x.data.set_(y)
[rank7]:[rank7]: to:
[rank7]:[rank7]:     with torch.no_grad():
[rank7]:[rank7]:         x.set_(y)

Is this known to us? It looks like a torch bug, I can provide full log / repro if it helps.

n_limited_groups: int = 1

# Expert parallel communication backend (set from config)
expert_parallel_comm_backend: str = "standard" # "standard" or "deepep"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, are we suggesting to refactor expert_parallel_comm_backend to moe_impl in job_config.parallelism?

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
"Failed to use grouped_mm, which is only supported on SM90 or later",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted

self.moe = MoE(
model_args.moe_args,
# Use build_moe factory to support different communication backends
from torchtitan.models.moe import build_moe
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

and not parallel_dims.ep_enabled
):
logger.warning(
"expert_parallel_comm_backend='deepep' has no effect when EP=1. "
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added use_deepep=False to fall back to standard MoE

from .moe import MoE, MoEArgs


class MoEWithDeepEP(MoE):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants