perf(pipeline): implement auto-partition algorithm #2113
Open

TXacs wants to merge 5 commits into pytorch:main from McmillanTAC:autopartition
Changes shown are from 2 of the 5 commits:

- 6a06ed7 perf(pipeline): implement auto-partition algorithm (TXacs)
- 1f8b2f4 Format to fix and add license (TXacs)
- 4d90aea Optimize FLOPs calculation for partition layer (TXacs)
- 9e276bf Modification based on torchtitan's transformer-only recomputation con… (TXacs)
- a41af40 Feat: migrate AutoPipe to pure Python & integrate FlopCounterMode (TXacs)

New file added by the PR (55 lines):
# Auto-Partition in torchtitan

## Overview

This folder provides an automatic partitioning method that accounts for the computation cost of the embedding layers. It computes the floating-point operations (FLOPs) of the embedding layers, builds an array holding the FLOPs of both the transformer layers and the embedding layers, and then applies a heuristic algorithm to find a balanced pipeline partition.
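
As orientation, here is a minimal sketch of one common way such a balance-seeking heuristic can work: binary-search the bottleneck stage cost over the per-layer FLOPs array, then split greedily. All names are illustrative, and this is not necessarily the algorithm the PR's `autopipe` implements.

```python
# Sketch: split a per-layer FLOPs array into `num_stages` contiguous
# stages, minimizing the most expensive stage. Illustrative only.


def _fits(costs: list[float], num_stages: int, limit: float) -> bool:
    """True if `costs` splits into <= num_stages chunks, each <= limit."""
    stages, current = 1, 0.0
    for c in costs:
        if c > limit:
            return False
        if current + c > limit:
            stages += 1
            current = c
        else:
            current += c
    return stages <= num_stages


def balanced_partition(costs: list[float], num_stages: int) -> list[int]:
    """Return the indices where each new pipeline stage begins.

    `costs` would hold embedding FLOPs at the ends and per-transformer-layer
    FLOPs in between, as the overview describes.
    """
    lo, hi = max(costs), sum(costs)
    while hi - lo > 1e-6 * hi:  # binary search on the bottleneck stage cost
        mid = (lo + hi) / 2
        if _fits(costs, num_stages, mid):
            hi = mid
        else:
            lo = mid
    # Greedily emit split indices under the found bottleneck limit.
    splits, current = [], 0.0
    for i, c in enumerate(costs):
        if current + c > hi:
            splits.append(i)
            current = c
        else:
            current += c
    return splits
```

For example, `balanced_partition([4, 1, 1, 1, 1, 4], 2)` returns `[3]`, splitting the array into two stages of cost 6 each.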

## Quick Start

### Compile

First, compile `autopipe.cpp`:

```bash
pip install pybind11
cd ./torchtitan/experiments/autopartition/infra/cpp
mkdir build
cd build
cmake ..
make
mv *.so ../../
```
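
To sanity-check the build, one could try importing the extension from the directory the `.so` was moved to. The module name `autopipe` is an assumption here, inferred from the source file name (pybind11 module names conventionally match the `PYBIND11_MODULE` declaration):

```python
# Hypothetical smoke test; assumes the compiled module is named `autopipe`
# and that Python is started from the directory containing the moved .so.
import autopipe

# List the exported symbols to confirm the extension loaded.
print([name for name in dir(autopipe) if not name.startswith("_")])
```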

The following command uses Llama 3 as an example:

```bash
CONFIG_FILE="./torchtitan/experiments/autopartition/train_configs/debug_model.toml" ./run_train.sh
```

## Performance

Hardware: 4x RTX 3090 (24 GB each); pipeline parallelism degree 4.

### Llama 3 configuration comparison

| hidden size | layers | autopipe TPS | default TPS | Speedup |
| ----------- | ------ | ------------ | ----------- | ------- |
| dim=256     | 6      | 31,094       | 29,549      | +5.2%   |
| dim=256     | 12     | 21,803       | 21,923      | -0.5%   |
| dim=2048    | 12     | 3,348        | 2,616       | +28.0%  |
| dim=4096    | 12     | 981          | 761         | +28.9%  |

### DeepSeek-V3 (without MoE) configuration comparison

| hidden size | layers | autopipe TPS | default TPS | Speedup |
| ----------- | ------ | ------------ | ----------- | ------- |
| dim=256     | 6      | 13,373       | 13,059      | +2.4%   |
| dim=256     | 12     | 7,714        | 6,859       | +12.5%  |
| dim=2048    | 12     | 4,331        | 3,810       | +13.7%  |
| dim=4096    | 12     | 2,888        | 2,561       | +12.8%  |
| dim=4096    | 16     | 2,207        | 2,008       | +9.9%   |
| dim=8192    | 16     | 4,331        | 3,935       | +10.1%  |

### Known Issues

- **MoE not supported**: Auto-Partition needs per-layer FLOPs, but the current profiler (taken from DeepSpeed) does not support computing FLOPs for MoE layers.
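
The later commit a41af40 ("integrate FlopCounterMode") suggests one remedy: PyTorch's built-in `FlopCounterMode` can attribute FLOPs to each submodule it traces. The sketch below shows the general pattern; the toy model and input shapes are placeholders, not code from this PR:

```python
# Sketch: collecting per-module FLOPs with torch.utils.flop_counter.
# The model here is a stand-in; a real run would trace the actual
# transformer so each layer's cost can feed the partition heuristic.
import torch
from torch.utils.flop_counter import FlopCounterMode

model = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    torch.nn.Linear(1024, 256),
)
x = torch.randn(8, 256)

counter = FlopCounterMode(display=False)
with counter:
    model(x)

# get_flop_counts() maps module names to {op: flop_count}; summing each
# inner dict yields the per-layer cost array a partition heuristic needs.
for name, per_op in counter.get_flop_counts().items():
    print(name, sum(per_op.values()))
```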

Another new file added by the PR (14 lines):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = [
    "get_deepseek_v3_train_spec",
    "get_llama3_train_spec",
]


from .deepseek_v3_tain_spec import get_deepseek_v3_train_spec
from .llama3_tain_spec import get_llama3_train_spec
```

torchtitan/experiments/autopartition/deepseek_v3/args.py (121 additions, 0 deletions):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass, field

from torch import nn

from torchtitan.config import JobConfig
from torchtitan.models.moe import MoEArgs
from torchtitan.models.utils import get_moe_model_nparams_and_flops
from torchtitan.protocols.model import BaseModelArgs
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepSeekV3ModelArgs(BaseModelArgs):
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        norm_eps (float): Epsilon value used for RMSNorm.
        moe_args (MoEArgs): MoE configuration.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        use_flex_attn (bool): Whether to use FlexAttention.
        attn_mask_type (str): Type of attention mask.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
    """

    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    norm_eps: float = 1e-5  # eps used for RMSNorm

    # MoE
    moe_args: MoEArgs = field(default_factory=MoEArgs)
    # TODO: node-limited routing is not supported yet
    n_expert_groups: int = 1
    n_limited_groups: int = 1

    # Multi-Head Latent Attention (MLA)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"

    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.0

    def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
        seq_len = job_config.training.seq_len
        if seq_len > self.max_seq_len:
            logger.warning(
                f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}."
            )
        self.max_seq_len = seq_len

        if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
            logger.warning(
                "Failed to use grouped mm, which is only supported on SM90 or later",
            )
            self.moe_args.use_grouped_mm = False

        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
            raise NotImplementedError(
                "CP support for FlexAttention is still in progress."
            )

        self.moe_args._debug_force_load_balance = (
            job_config.debug.moe_force_load_balance
        )

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        return get_moe_model_nparams_and_flops(
            self,
            model,
            self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim,
            seq_len,
        )
```
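
As a quick illustration of how these args might be used (the import path follows the file location above; the small dimensions are placeholders, not values from the PR):

```python
# Hypothetical usage sketch for DeepSeekV3ModelArgs; illustrative only.
from torchtitan.experiments.autopartition.deepseek_v3.args import DeepSeekV3ModelArgs

args = DeepSeekV3ModelArgs(dim=256, n_layers=6, n_heads=8, vocab_size=32000)

# Once a model is built from these args, parameter and FLOP counts come from:
#   nparams, flops = args.get_nparams_and_flops(model, seq_len=2048)
```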

Review comment:

> Is this comment outdated? I do not see any .cpp files in the PR anymore.