Integrate DeepEP to torchtitan #2107
base: main
Changes from all commits: cadd426, 447d881, 775b5be, 1802b49, 4a65d60, 1c137d6, db455a3, 0dddde8, 4e3fda2
@@ -13,9 +13,14 @@
 from torch.distributed.tensor.placement_types import Placement

 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.distributed.expert_parallel import DeepEPExpertParallel

-__all__ = ["ParallelDims", "NoParallel"]
+__all__ = [
+    "ParallelDims",
+    "NoParallel",
+    "DeepEPExpertParallel",
Contributor: let's not expose this here for now

Author: DeepEPExpertParallel needs to be used in …
+]

 # NOTE: This is to achieve replicate computation on the gate module in the MoE router.
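For context on the thread above: the `__init__.py` change only affects which import path is available at the package level, and how the class is then applied to the model is not shown in this hunk. A quick contrast of the two paths:

```python
# With the __init__.py change above, the class is importable at package level:
from torchtitan.distributed import DeepEPExpertParallel

# The deeper module path works with or without that re-export:
from torchtitan.distributed.expert_parallel import DeepEPExpertParallel
```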
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""DeepEP distributed communication primitives for MoE."""
+
+from .deepep import (
+    dispatch_tokens,
+    combine_tokens,
+    DispatchState,
+)
+
+__all__ = [
+    "dispatch_tokens",
+    "combine_tokens",
+    "DispatchState",
+]
Contributor: Why is this optional? IIUC it is a must in order to use torch._grouped_mm.

Author: Set to optional because torch._grouped_gemm uses CUTLASS grouped GEMM underneath, and CUTLASS's grouped-GEMM 16-byte alignment requirement is on the contiguous dimension; for A[M, K] with layout (1, 0), that is the contiguous dimension K:
https://github.com/NVIDIA/cutlass/blob/d3a5492381a457e59a1fd27d97bb87c7ca95ee6e/include/cutlass/gemm/device/gemm_array.h#L344-L363
Technically in our case, moe_intermediate_size=2048 and hidden_size=7168 are already multiples of 8 elements. I have run without padding and it worked. I still leave it as an option here in case another underlying library requires padding on the M dim.

Contributor: IIUC it is the moe_intermediate_size or hidden_size dimension. I could be wrong. @ngimel could you advise?
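To make the arithmetic in this thread concrete: 16 bytes on the contiguous (K) dimension is 8 elements at 2 bytes per bf16 element, which the quoted sizes already satisfy. A small sketch of that check; the helper name is hypothetical, not code from this PR.

```python
# Hypothetical helper spelling out the alignment reasoning above: CUTLASS
# grouped GEMM wants 16-byte alignment on the contiguous (K) dimension,
# i.e. K must be a multiple of 8 for 2-byte bf16 elements.
import torch

ALIGN_BYTES = 16


def k_dim_is_aligned(k: int, dtype: torch.dtype = torch.bfloat16) -> bool:
    elem_bytes = torch.finfo(dtype).bits // 8
    return (k * elem_bytes) % ALIGN_BYTES == 0


# The sizes quoted in the thread are already multiples of 8 elements,
# which is why padding on the M dimension can stay optional:
assert k_dim_is_aligned(7168)   # hidden_size
assert k_dim_is_aligned(2048)   # moe_intermediate_size
```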