diff --git a/torchtitan/experiments/autopartition/README.md b/torchtitan/experiments/autopartition/README.md
new file mode 100644
index 0000000000..6de77cdd28
--- /dev/null
+++ b/torchtitan/experiments/autopartition/README.md
@@ -0,0 +1,55 @@
+# Auto-Partition in torchtitan
+
+## Overview
+
+This folder provides an automatic pipeline partitioning method that accounts for the computation cost of the embedding layers.
+It computes the floating-point operations (FLOPs) of the embedding layers and builds an array containing the FLOPs of both the transformer layers and the embedding layers. A heuristic algorithm then searches this array for a balanced pipeline partition.
+
+## Quick Start
+
+### Compile
+
+First, compile `autopipe.cpp`:
+```bash
+pip install pybind11
+cd ./torchtitan/experiments/autopartition/infra/cpp
+mkdir build
+cd build
+cmake ..
+make
+mv *.so ../../
+```
+
+Then launch training. The following command uses Llama 3 as an example:
+
+```bash
+CONFIG_FILE="./torchtitan/experiments/autopartition/train_configs/debug_model.toml" ./run_train.sh
+```
+
+## Performance
+
+Hardware configuration: 4x RTX 3090 (24 GB), pipeline parallelism degree 4.
+
+### Llama 3 configuration comparison
+| hidden size | layers | autopipe TPS | default TPS | Speedup |
+| ----------- | ------ | ------------ | ----------- | ------- |
+| dim=256  | 6  | 31,094 | 29,549 | +5.2%  |
+| dim=256  | 12 | 21,803 | 21,923 | -0.5%  |
+| dim=2048 | 12 | 3,348  | 2,616  | +28.0% |
+| dim=4096 | 12 | 981    | 761    | +28.9% |
+
+### DeepSeek-V3 (without MoE) configuration comparison
+
+| hidden size | layers | autopipe TPS | default TPS | Speedup |
+| ----------- | ------ | ------------ | ----------- | ------- |
+| dim=256  | 6  | 13,373 | 13,059 | +2.4%  |
+| dim=256  | 12 | 7,714  | 6,859  | +12.5% |
+| dim=2048 | 12 | 4,331  | 3,810  | +13.7% |
+| dim=4096 | 12 | 2,888  | 2,561  | +12.8% |
+| dim=4096 | 16 | 2,207  | 2,008  | +9.9%  |
+| dim=8192 | 16 | 4,331  | 3,935  | +10.1% |
+
+
+### Known Issues
+
+- **MoE not supported** - Auto-Partition needs the FLOPs of each layer, but the current profiler from DeepSpeed does not support computing FLOPs for MoE layers.
diff --git a/torchtitan/experiments/autopartition/__init__.py b/torchtitan/experiments/autopartition/__init__.py
new file mode 100644
index 0000000000..2102ec1b38
--- /dev/null
+++ b/torchtitan/experiments/autopartition/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+__all__ = [
+    "get_deepseek_v3_train_spec",
+    "get_llama3_train_spec",
+]
+
+
+from .deepseek_v3_tain_spec import get_deepseek_v3_train_spec
+from .llama3_tain_spec import get_llama3_train_spec
diff --git a/torchtitan/experiments/autopartition/deepseek_v3/args.py b/torchtitan/experiments/autopartition/deepseek_v3/args.py
new file mode 100644
index 0000000000..48d4b5ece1
--- /dev/null
+++ b/torchtitan/experiments/autopartition/deepseek_v3/args.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+ + +from dataclasses import dataclass, field + +from torch import nn + +from torchtitan.config import JobConfig +from torchtitan.models.moe import MoEArgs +from torchtitan.models.utils import get_moe_model_nparams_and_flops +from torchtitan.protocols.model import BaseModelArgs +from torchtitan.tools.logging import logger +from torchtitan.tools.utils import has_cuda_capability + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class DeepSeekV3ModelArgs(BaseModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + norm_eps (float): Epsilon value used for RMSNorm. + moe_args (MoEArgs): MoE configuration. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + use_flex_attn (bool): Whether to use FlexAttention. + attn_mask_type (str): Type of attention mask. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + norm_eps: float = 1e-5 # eps used for RMSNorm + + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) + # TODO: node-limited routing is not supported yet + n_expert_groups: int = 1 + n_limited_groups: int = 1 + + # Multi-Head Latent Attention (MLA) + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.moe_args.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." 
+ ) + + self.moe_args._debug_force_load_balance = ( + job_config.debug.moe_force_load_balance + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_moe_model_nparams_and_flops( + self, + model, + self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim, + seq_len, + ) diff --git a/torchtitan/experiments/autopartition/deepseek_v3/model.py b/torchtitan/experiments/autopartition/deepseek_v3/model.py new file mode 100644 index 0000000000..3cf56eb1b2 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3/model.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch import nn + +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.models.moe import FeedForward, MoE +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepSeekV3ModelArgs + + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. 
+ """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > args.original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, args.original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] × [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. 
+ """ + + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__() + self.dim = model_args.dim + self.n_heads = model_args.n_heads + self.q_lora_rank = model_args.q_lora_rank + self.kv_lora_rank = model_args.kv_lora_rank + self.qk_nope_head_dim = model_args.qk_nope_head_dim + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim + self.v_head_dim = model_args.v_head_dim + + if self.q_lora_rank == 0: + self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False) + else: + self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) + self.wq_b = nn.Linear( + self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False + ) + self.wkv_a = nn.Linear( + self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False) + self.softmax_scale = self.qk_head_dim**-0.5 + + if model_args.max_seq_len > model_args.original_seq_len: + mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + else: + q = self.wq_a(x) + q = self.wq_b(self.q_norm(q)) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of q and kv as TP may have sharded them after + # the above linear ops. 
+ q = q.view(bsz, seqlen, -1, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) + + # Key-value projection + kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pe = apply_rotary_emb( + k_pe.unsqueeze(2), freqs_cis + ) # (bsz, seqlen, 1, qk_rope_head_dim) + + kv = self.wkv_b( + self.kv_norm(kv) + ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + ) # (bsz, seqlen, n_heads, qk_head_dim) + + q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + else: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) + + # Reshape and project output + output = output.transpose( + 1, 2 + ).contiguous() # (bsz, seqlen, n_heads, v_head_dim) + output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) + return self.wo(output) # (bsz, seqlen, dim) + + def init_weights(self, init_std: float): + linear_list = [ + self.wkv_a, + self.wkv_b, + ] + if self.q_lora_rank > 0: + linear_list.extend([self.wq_a, self.wq_b]) + else: + linear_list.append(self.wq) + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + self.kv_norm.reset_parameters() + if self.q_lora_rank > 0: + self.q_norm.reset_parameters() + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): + + super().__init__() + self.attention = Attention(model_args) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + self.moe_enabled = layer_id >= model_args.n_dense_layers + if self.moe_enabled: + self.moe = MoE( + model_args.moe_args, + dim=model_args.dim, + hidden_dim=model_args.moe_inter_dim, + ) + else: + self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
+ """ + x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + if self.moe_enabled: + x = x + self.moe(self.ffn_norm(x)) + else: + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std, buffer_device) + else: + self.feed_forward.init_weights(self.weight_init_std) + + +class DeepSeekV3Model(nn.Module, ModelProtocol): + """ + DeepSeek-V3 Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = nn.RMSNorm(model_args.dim) + self.output = nn.Linear( + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + + h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis, attention_masks) + h = self.norm(h) if self.norm is not None else h + output = self.output(h) if self.output is not None else h + return output diff --git a/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py b/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py new file mode 100644 index 0000000000..fd4ec30284 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import re +from typing import Any + +import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader + +from torch.distributed.tensor import DTensor +from torchtitan.models.utils import MoEStateDictAdapter + +from .args import DeepSeekV3ModelArgs + + +class DeepSeekV3StateDictAdapter(MoEStateDictAdapter): + """ + StateDictAdapter for DeepSeekV3 model. + """ + + def __init__( + self, + model_args: DeepSeekV3ModelArgs, + hf_assets_path: str | None, + ): + super().__init__(model_args, hf_assets_path) + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention Module + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attention.wkv_a.weight", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attention.kv_norm.weight", + "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attention.wkv_b.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + # MLP Module + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Transformer Layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE Module + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2", + "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight", + "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + # Adjustments for from_hf_map based on model architecture + if model_args.q_lora_rank != 0: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", + } + ) + else: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + } + ) + + def 
get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """ + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. + """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. Convert between the HF shape and the torchtitan shape. + 2. Split the GroupedExperts' weight into separate expert's weight. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + hf_state_dict = {} + + for key, value in state_dict.items(): + if "moe.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_abstract_key = to_hf_map[abstract_key] + + # Store the GroupedExperts Weight metadata for from_hf() + if isinstance(value, DTensor): + self.grouped_expert_weight_placements[ + abstract_key + ] = value.placements + self.grouped_expert_weight_shape[abstract_key] = value.shape + + # Split GroupedExperts weight to local individual expert weights + local_expert_fqn = self._get_local_experts_weights( + new_abstract_key, + abstract_key, + layer_num, + value, + ) + hf_state_dict.update(local_expert_fqn) + + else: + # keep this path for offline conversion + split_values = self._split_experts_weights( + value, self.model_args.moe_args.num_experts + ) + + for expert_num in range(0, self.model_args.moe_args.num_experts): + new_key = new_abstract_key.format(layer_num, expert_num) + hf_state_dict[new_key] = split_values[expert_num].squeeze() + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + new_key = new_key.format(layer_num) + hf_state_dict[new_key] = value + + else: + new_key = to_hf_map[key] + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. + 2. Convert between the HF shape and the torchtitan shape. + 3. Concat separate expert's weight into GroupedExperts' weight. + """ + + state_dict = {} + expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} + + for key, value in hf_state_dict.items(): + if "mlp.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=2) + layer_num, expert_num = re.findall(r"\d+", key) + titan_abstract_key = self.from_hf_map[abstract_key] + new_key = titan_abstract_key.format(layer_num) + + # Store the expert's weight in expert_weights_by_layer for concatenating later. 
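+                # The concatenation helpers below return None until all experts for this
+                # layer have been collected, so stacked_value is only written once complete.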
+ if layer_num not in expert_weights_by_layer: + expert_weights_by_layer[layer_num] = {} + if titan_abstract_key not in expert_weights_by_layer[layer_num]: + expert_weights_by_layer[layer_num][titan_abstract_key] = {} + expert_weights_by_layer[layer_num][titan_abstract_key][ + int(expert_num) + ] = value + + if isinstance(value, DTensor): + stacked_value = self._concatenate_expert_weights_dtensor( + expert_weights_by_layer, + titan_abstract_key, + layer_num, + value.device_mesh, + ) + else: # keep this path to be compatible with offline conversion + stacked_value = self._concatenate_expert_weights( + expert_weights_by_layer, + titan_abstract_key, + layer_num, + self.model_args.moe_args.num_experts, + ) + + if stacked_value is not None: + state_dict[new_key] = stacked_value + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + new_key = new_key.format(layer_num) + state_dict[new_key] = value + + else: + new_key = self.from_hf_map[key] + state_dict[new_key] = value + + return state_dict diff --git a/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py b/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py new file mode 100644 index 0000000000..a11af94be4 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.moe import MoEArgs +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize_deepseek_v3 import parallelize_deepseekv3 +from .infra.pipeline_parallel import pipeline_llm +from .deepseek_v3.args import DeepSeekV3ModelArgs +from .deepseek_v3.model import DeepSeekV3Model +from .deepseek_v3.state_dict_adapter import DeepSeekV3StateDictAdapter + +__all__ = [ + "parallelize_deepseekv3", + "DeepSeekV3ModelArgs", + "DeepSeekV3Model", + "deepseekv3_args", +] + + +deepseekv3_args = { + "debugmodel": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=4096, + inter_dim=1024, + moe_inter_dim=256, + n_layers=12, + n_dense_layers=12, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "debugmodel_flex_attn": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=6, + n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "16B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + 
n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "236B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + moe_args=MoEArgs( + num_experts=160, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + route_scale=16.0, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=3, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "671B": DeepSeekV3ModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + moe_args=MoEArgs( + num_experts=256, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.5, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=4, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", + ), +} + + +def get_deepseek_v3_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=deepseekv3_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autopartition/infra/autopipe.py b/torchtitan/experiments/autopartition/infra/autopipe.py new file mode 100644 index 0000000000..4e0b94edf1 --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/autopipe.py @@ -0,0 +1,355 @@ +# autopipe.py +import functools +from collections import deque +from typing import Dict, List, Tuple + +COMM_OVERHEAD = 0 +MAX_INT64 = 2**63 - 1 +MAX_INT32 = 2**31 - 1 +INF = 10**18 + + +def _prefix_sum_and_dp( + model: List[int], + num_stages: int, + block_time: List[List[int]], + prefix_sum: List[int], + dp: List[List[int]], +): + """C++ calculate_prefix_sum_and_dp""" + num_blocks = len(model) + max_parts = min(num_blocks, num_stages) + + # sum prefix + prefix_sum.clear() + prefix_sum.append(0) + for b in model: + t = block_time[0][b] + block_time[1][b] + prefix_sum.append(prefix_sum[-1] + t) + + # DP sheet + dp.clear() + for _ in range(num_blocks + 1): + dp.append([MAX_INT64] * (max_parts + 1)) + dp[0][0] = 0 + + for blocks in range(1, num_blocks + 1): + max_p = min(blocks, max_parts) + for parts in range(1, max_p + 1): + best = MAX_INT64 + for prev in range(blocks): + cur = max(dp[prev][parts - 1], prefix_sum[blocks] - prefix_sum[prev]) + best = min(best, cur) + if best == 0: + break + dp[blocks][parts] = best + + +def _reconstruct( + model: List[int], + prefix_sum: List[int], + dp: List[List[int]], + rem_blocks: int, + rem_parts: int, + out: List[List[int]], +): + """C++ reconstruct_partitions""" + if rem_blocks == 0 and rem_parts == 0: + return + if rem_blocks <= 0 or rem_parts <= 0 or rem_blocks < rem_parts: + raise 
RuntimeError("Error during partition reconstruction") + + prev_end = 0 + while prev_end < rem_blocks: + lhs = dp[prev_end][rem_parts - 1] + rhs = prefix_sum[rem_blocks] - prefix_sum[prev_end] + if dp[rem_blocks][rem_parts] == max(lhs, rhs): + break + prev_end += 1 + + chunk = [model[i] for i in range(prev_end, rem_blocks)] + out.append(chunk) + _reconstruct(model, prefix_sum, dp, prev_end, rem_parts - 1, out) + + +def _block_partition_algo( + model: List[int], num_stages: int, block_time: List[List[int]] +) -> List[List[int]]: + """C++ block_partition_algorithm""" + prefix_sum: List[int] = [] + dp: List[List[int]] = [] + _prefix_sum_and_dp(model, num_stages, block_time, prefix_sum, dp) + + parts: List[List[int]] = [] + _reconstruct(model, prefix_sum, dp, len(model), num_stages, parts) + parts.reverse() + return parts + + +# ---------- Time Cal ---------- +def _calc_stage_times( + partition: List[List[int]], + block_time: List[List[int]], + fwd: List[int], + bwd: List[int], + last_mb: List[int], +): + """C++ calculate_stage_times""" + num_stages = len(partition) + num_micro = num_stages * 2 + for i in range(num_stages): + last_mb[i] = num_micro - num_stages + i + + for i in range(1, num_stages + 1): + fwd_sum = sum(block_time[0][b] for b in partition[i - 1]) + bwd_sum = sum(block_time[1][b] for b in partition[i - 1]) + fwd[i] = fwd_sum + bwd[i] = bwd_sum + + +def _steady_phase( + last_mb: List[int], fwd: List[int], bwd: List[int] +) -> Tuple[int, int]: + """C++ calculate_steady_phase""" + num_stages = len(last_mb) + num_micro = num_stages * 2 + + dp = [[[0, 0] for _ in range(num_micro)] for __ in range(num_stages + 2)] + + # init + init_bwd = 0 + for s in range(num_stages): + init_bwd += fwd[s + 1] + if s != num_stages - 1: + init_bwd += COMM_OVERHEAD + for s in range(num_stages - 1, -1, -1): + dp[s + 1][0][0] = MAX_INT64 + dp[s + 1][0][1] = init_bwd + init_bwd += bwd[s + 1] + COMM_OVERHEAD + + for mb in range(1, num_micro): + # forward + for s in range(num_stages): + if mb <= last_mb[s]: + val = max(dp[s][mb - 1][0] + fwd[s], dp[s + 1][mb - 1][1] + bwd[s + 1]) + if s != 0: + val += COMM_OVERHEAD + dp[s + 1][mb][0] = val + # backward + for s in range(num_stages - 1, -1, -1): + if mb <= last_mb[s]: + val = max(dp[s + 2][mb][1] + bwd[s + 2], dp[s + 1][mb][0] + fwd[s + 1]) + if s != num_stages - 1: + val += COMM_OVERHEAD + dp[s + 1][mb][1] = val + + # find critical path + critical = num_stages - 1 + while critical >= 0: + ok = True + for mb in range(1, last_mb[critical] + 1): + fcomm = COMM_OVERHEAD if critical != 0 else 0 + bcomm = COMM_OVERHEAD if critical != num_stages - 1 else 0 + if ( + dp[critical + 1][mb][0] + != dp[critical + 1][mb - 1][1] + bwd[critical + 1] + fcomm + ): + ok = False + break + if ( + dp[critical + 1][mb][1] + != dp[critical + 1][mb][0] + fwd[critical + 1] + bcomm + ): + ok = False + break + if ok: + break + critical -= 1 + + if critical < 0: + # Backstop: The stage that finishes the last micro-batch is the critical stage. 
+        _, best_time = 0, -1
+        for s in range(num_stages):
+            t = dp[s + 1][last_mb[s]][1]  # Finished time of the last backward
+            if t > best_time:
+                best_time, critical = t, s
+    return dp[critical + 1][last_mb[critical]][0], critical
+
+
+def _cooldown(
+    num_stages: int, critical: int, last_fwd_start: int, fwd: List[int], bwd: List[int]
+) -> int:
+    """C++ calculate_cooldown_phase"""
+    sz = num_stages - critical
+    if sz <= 0:
+        return last_fwd_start
+
+    dp = [[0] * sz for _ in range(sz)]
+    bwd_start = last_fwd_start
+    for i in range(sz):
+        bwd_start += fwd[critical + 1 + i]
+        if critical + i != num_stages - 1:
+            bwd_start += COMM_OVERHEAD
+        dp[i][sz - 1 - i] = bwd_start
+
+    for col in range(sz - 2, -1, -1):
+        for row in range(sz - col - 2, -1, -1):
+            o1 = dp[row][col + 1] + bwd[critical + 1 + row] + COMM_OVERHEAD
+            o2 = dp[row + 1][col] + bwd[critical + 1 + row + 1] + COMM_OVERHEAD
+            dp[row][col] = max(o1, o2)
+            if row > 0:
+                dp[row][col] = max(dp[row][col], dp[row - 1][col + 1])
+    return dp[0][0]
+
+
+def _training_time(
+    partition: List[List[int]], block_time: List[List[int]]
+) -> Tuple[int, int]:
+    """C++ calculate_training_time"""
+    num_stages = len(partition)
+    last_mb = [0] * num_stages
+    fwd = [0] * (num_stages + 2)
+    bwd = [0] * (num_stages + 2)
+
+    # Compute per-stage forward/backward times
+    for i in range(num_stages):
+        last_mb[i] = num_stages * 2 - num_stages + i
+        fwd[i + 1] = sum(block_time[0][b] for b in partition[i])
+        bwd[i + 1] = sum(block_time[1][b] for b in partition[i])
+
+    steady_time, critical = _steady_phase(last_mb, fwd, bwd)
+    if steady_time == MAX_INT64:
+        raise RuntimeError("Failed to calculate steady phase")
+
+    last_bwd_start = _cooldown(num_stages, critical, steady_time, fwd, bwd)
+    flush = last_bwd_start
+    for stage in range(critical, 0, -1):
+        flush += bwd[stage + 1] + COMM_OVERHEAD
+    flush += bwd[1]
+    return flush, critical
+
+
+# ---------- Optimal partition search ----------
+def _find_best(
+    block_time: List[List[int]],
+    num_stages: int,
+    init_partition: List[List[int]],
+    prefix_sum: List[int],
+    dp: List[List[int]],
+) -> Dict:
+    """
+    C++ find_best_partition
+    return: {"partition": [[...], ...], "cost": int, "critical_stage": int}
+    """
+
+    # Hash func: C++ VectorHash
+    @functools.lru_cache(maxsize=None)
+    def _hash(p):
+        h = 0
+        for inner in p:
+            for v in inner:
+                h ^= (v + 0x9E3779B9) + (h << 6) + (h >> 2)
+        return h
+
+    # C++ VectorEqual
+    def _eq(a, b):
+        return len(a) == len(b) and all(
+            len(ai) == len(bi) and all(av == bv for av, bv in zip(ai, bi))
+            for ai, bi in zip(a, b)
+        )
+
+    visited = set()
+    queue = deque([init_partition])
+    visited.add(_hash(tuple(tuple(r) for r in init_partition)))
+
+    # Best result
+    best = {"partition": init_partition, "cost": INF, "critical_stage": MAX_INT32}
+
+    while queue:
+        cur = queue.popleft()
+
+        # Evaluate the training time of the current partition.
+        last_mb = [0] * num_stages
+        fwd = [0] * (num_stages + 2)
+        bwd = [0] * (num_stages + 2)
+        _calc_stage_times(cur, block_time, fwd, bwd, last_mb)
+        cost, critical = _training_time(cur, block_time)
+
+        # update best
+        if cost < best["cost"]:
+            best = {"partition": cur, "cost": cost, "critical_stage": critical}
+
+        # If the critical path is not in the first segment,
+        # try re-partitioning all the blocks before the critical stage.
+        if critical > 0:
+            # Collect all blocks before (and including the first block of) the critical stage.
+            blocks_before = []
+            for stage in range(critical):
+                blocks_before.extend(cur[stage])
+            blocks_before.append(cur[critical][0])
+
+            # Redo the critical stage partitioning for these blocks with the same DP.
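+            # Moving boundary blocks off the critical stage can lower its load and,
+            # with it, the end-to-end pipeline time of the new candidate partition.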
+            model_before = blocks_before
+            new_parts: List[List[int]] = []
+            _reconstruct(
+                model_before, prefix_sum, dp, len(model_before), critical, new_parts
+            )
+            new_parts.reverse()
+            blocks_before.pop()
+
+            full_new = new_parts
+            # Put the remaining blocks of the critical segment back.
+            if len(full_new) <= critical:
+                full_new.append([])
+            full_new[critical].extend(cur[critical][1:])
+
+            for stage in range(critical + 1, len(cur)):
+                full_new.append(cur[stage])
+
+            # Hash-based deduplication
+            key = _hash(tuple(tuple(r) for r in full_new))
+            if key not in visited:
+                visited.add(key)
+                queue.append(full_new)
+
+    return best
+
+
+# ---------- main ----------
+def pipeline(
+    forward_times: List[int], backward_times: List[int], num_stages: int
+) -> List[int]:
+    """Partition blocks into num_stages stages; returns the first block index of each stage."""
+
+    if not forward_times or not backward_times:
+        raise ValueError("Input vectors cannot be empty")
+    if len(forward_times) != len(backward_times):
+        raise ValueError("Forward and backward vectors must have same size")
+    if num_stages <= 0 or num_stages > len(forward_times):
+        raise ValueError("Invalid number of pipeline stages")
+
+    block_time = [forward_times, backward_times]
+    model = list(range(len(forward_times)))
+
+    init_partition = _block_partition_algo(model, num_stages, block_time)
+    prefix_sum: List[int] = []
+    dp: List[List[int]] = []
+    _prefix_sum_and_dp(model, num_stages, block_time, prefix_sum, dp)
+
+    best = _find_best(block_time, num_stages, init_partition, prefix_sum, dp)
+
+    result = [stage[0] for stage in best["partition"]]
+    return result
+
+
+if __name__ == "__main__":
+    import traceback
+
+    try:
+        fwd_flops = [10, 20, 30, 15, 25]  # five blocks
+        bwd_flops = [10, 20, 30, 15, 25]
+        for stages in 1, 2, 3, 4, 5:
+            test_out = pipeline(fwd_flops, bwd_flops, stages)
+            print(f"stages={stages}, result={test_out}, len={len(test_out)}")
+
+    except Exception as e:
+        traceback.print_exc()
diff --git a/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py b/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py
new file mode 100644
index 0000000000..0793820ffd
--- /dev/null
+++ b/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
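+
+# This file applies the PT-D parallelisms (except pipeline parallelism) and various
+# training techniques (e.g. activation checkpointing and torch.compile) to the
+# DeepSeek-V3 model used in the auto-partition experiment.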
+ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama3.infra.parallelize import apply_ddp +from torchtitan.models.llama4.infra.parallelize import ( + apply_compile, + apply_fsdp, + apply_moe_ep_tp, +) +from torchtitan.tools.logging import logger + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +# Adapted from llama4/infra/parallelize.py +def parallelize_deepseekv3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
+ """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for deepseekv3" + ) + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + use_flex_attn=use_flex_attn, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled + else None + ), + etp_enabled=parallel_dims.etp_enabled, + ) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, + ) + + if model_compile_enabled: + apply_compile(model, job_config.compile) + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if parallel_dims.ep_enabled + else None + ), + gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + dp_mesh = world_mesh + apply_ddp( + model, + dp_mesh, 
+ enable_compile=model_compile_enabled, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + use_flex_attn: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + if use_flex_attn: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) + else: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), Replicate(), None), + desired_input_layouts=(Replicate(), Replicate(), None), + ), + # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor + # so that the intermedidate results k is generated as a DTensor and its gradient is + # correctly handled by the autograd engine. 
+ "attention.wkv_a": NoParallel(use_local_output=False), + "attention.wkv_b": colwise_parallel(use_local_output=False), + "attention.kv_norm": NoParallel(use_local_output=False), + # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors + "attention.inner_attention": attention_kernel_plan, + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + if transformer_block.attention.q_lora_rank == 0: + layer_plan.update( + { + "attention.wq": colwise_parallel( + use_local_output=False + ), # This is only used when q_lora_rank==0 + } + ) + else: + layer_plan.update( + { + "attention.wq_a": NoParallel(use_local_output=False), + "attention.wq_b": colwise_parallel(use_local_output=False), + "attention.q_norm": NoParallel(use_local_output=False), + } + ) + + if not transformer_block.moe_enabled: + layer_plan.update( + { + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + ) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) diff --git a/torchtitan/experiments/autopartition/infra/parallelize_llama.py b/torchtitan/experiments/autopartition/infra/parallelize_llama.py new file mode 100644 index 0000000000..86ac3a6dfe --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/parallelize_llama.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +import torch +import torch.nn as nn +from torch.distributed._composable.replicate import replicate + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import Compile as CompileConfig +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.tools.logging import logger + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. 
+ torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +def parallelize_llama( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, + ) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if model_compile_enabled: + apply_compile(model, job_config.compile) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=model_compile_enabled, + ) + + return model + + +def apply_tp( + model: 
nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_compile(model: nn.Module, compile_config: CompileConfig): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile( + transformer_block, backend=compile_config.backend, fullgraph=True + ) + model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. 
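+
+    A minimal illustrative call, mirroring how ``parallelize_llama`` in this file
+    invokes it (the dtypes shown are just examples, not defaults):
+
+        apply_fsdp(
+            model,
+            parallel_dims.world_mesh[("dp_shard_cp",)],
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.float32,
+            pp_enabled=parallel_dims.pp_enabled,
+        )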
+ + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + fully_shard(model, **fsdp_config) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, +): + if enable_compile: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/experiments/autopartition/infra/pipeline_parallel.py b/torchtitan/experiments/autopartition/infra/pipeline_parallel.py new file mode 100644 index 0000000000..fb59b2d822 --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/pipeline_parallel.py @@ -0,0 +1,647 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
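+
+# High-level flow of the auto-partition pipeline split implemented below
+# (summary only; the balancing heuristic itself lives in the compiled autopipe extension):
+#   1. pipeline_llm flattens the per-stage module FQNs and wraps each module in an
+#      nn.Sequential via _build_module_for_profile so it can be profiled in isolation.
+#   2. layerwise_flops measures per-module forward/backward FLOPs with FlopCounterMode.
+#   3. autopipe_partition feeds both FLOP lists to the autopipe ``pipeline()`` heuristic,
+#      which returns stage boundary indices.
+#   4. pipeline_llm slices the flattened FQN list with those boundaries and builds one
+#      PipelineStage per chunk via pipeline_module_split.
+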
+import copy +import math +import os +from typing import Callable + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, + PipelineScheduleMulti, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) +from torch.utils.flop_counter import FlopCounterMode +from torchtitan.components.loss import LossFunction, rescale_accumulated_loss +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac + +from torchtitan.experiments.autopartition.infra.autopipe import pipeline +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger + +# from torchtitan/tests/unit_tests/test_activation_checkpoint.py +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + +__all__ = [ + "pipeline_llm", + "build_pipeline_schedule", + "generate_llm_fqn_per_model_part", + "pipeline_module_split", +] + +def layerwise_flops(model, x, backward=True): + """Return forward and backward FLOPs (float) for each layer of the model.""" + fwd_mflops, bwd_mflops = [], [] + + for layer_idx, layer in enumerate(model): + # forward + with FlopCounterMode(display=False) as mode: + if isinstance(x, torch.Tensor): + x_new = layer(x) + else: + x_new = layer(*x) + fwd_mflops.append(round(mode.get_total_flops() / 1e6)) + + # backward + if backward and layer_idx != 0: + + def layer_scalar(y): + return ( + y.sum() if isinstance(y, torch.Tensor) else sum(o.sum() for o in y) + ) + + y = x_new.clone() + with FlopCounterMode(display=False) as mode: + layer_scalar(y).backward() + + bwd_mflops.append(round(mode.get_total_flops() / 1e6)) + else: + bwd_mflops.append(0) + + x = x_new.detach().requires_grad_(True) + + return fwd_mflops, bwd_mflops + + +def autopipe_partition(model: nn.Module, num_stages: int, job_config: JobConfig): + """Partition layers based on automatic pipeline profiling. + + This method profiles each layer's computational cost (FLOPS) and + distributes layers to balance computation across stages. + + Args: + model (nn.Module): The neural network model to be partitioned. + num_stages (int): Number of pipeline stages to partition the model into. + job_config (JobConfig): The job configuration. + + Returns: + List of integers representing the number of layers assigned to each stage. 
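+
+    Note (illustrative): ``pipeline_llm`` below uses the returned list as cumulative
+    stage boundaries rather than per-stage layer counts. For example, with 8 profiled
+    modules and ``num_stages=4``, a hypothetical result of ``[0, 3, 5, 7, 8]`` assigns
+    modules ``[0:3]``, ``[3:5]``, ``[5:7]`` and ``[7:8]`` to the four stages; the final
+    entry is ``len(model)``, appended before returning.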
+ """ + + # Prepare input for profiling + tokenizer = build_hf_tokenizer(job_config) + + # build dataloader + dataloader = build_text_dataloader( + dp_world_size=1, + dp_rank=0, + tokenizer=tokenizer, + job_config=job_config, + ) + iterator = iter(dataloader) + inputs = list(next(iterator)[0].values()) + + mflops_fwd, mflops_bwd = layerwise_flops(model, inputs) + + logger.info(f"Autopipe partitioning with mflops: {mflops_fwd}, {mflops_bwd}") + + # Partition layers by forward and backward flops + parts = pipeline( + mflops_fwd, + mflops_bwd, + num_stages, + ) + parts.append(len(model)) # Add the total number of layers + return parts + + +def _build_module_for_profile(model, flatten_module_names): + # txd: merge autopipe + module_names_for_profile = [[item] for item in flatten_module_names] + + def _build_sequential_module( + module_names: list[str], + ) -> tuple[PipelineStage, nn.Module]: + + # Create a set of modules to keep for faster lookup + # modules_to_keep = set(module_names) + module_seq = nn.Sequential() + for mtk in module_names: + whole_model = copy.deepcopy(model) + modules_to_keep = set(mtk) + for module_name, module_value in whole_model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(whole_model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(whole_model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(whole_model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(whole_model, module_name, None) + module_seq.append(copy.deepcopy(whole_model)) + return module_seq + + seq_module = _build_sequential_module(module_names_for_profile) + + return seq_module + + +def pipeline_llm( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_args: BaseModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.world_mesh["pp"] + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = 
job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + + # use auto_partition + flatten_module_names = [ + item for sublist in module_names_per_stage for item in sublist + ] + + copied_model = copy.deepcopy(model) + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + apply_ac( + copied_model, + job_config.activation_checkpoint, + model_compile_enabled=job_config.compile.enable, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + ) + seq_modules = _build_module_for_profile(copied_model, flatten_module_names) + + parts = autopipe_partition(seq_modules, parallel_dims.pp, job_config) + module_names_per_stage = [ + flatten_module_names[parts[i] : parts[i + 1]] for i in range(parallel_dims.pp) + ] + del copied_model, seq_modules + + for i, stage_ms in enumerate(module_names_per_stage): + logger.debug(f"Stage {i}: {stage_ms}") + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. 
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage + + +def build_pipeline_schedule( + job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable +) -> _PipelineSchedule: + """Builds a pipeline schedule for the given job configuration and stages. + + Args: + job_config (JobConfig): The job configuration. + stages (list[PipelineStage]): The stages to be scheduled. + loss_fn (Callable): The loss function. + + Returns: + _PipelineSchedule: The pipeline schedule for the given stages. + """ + pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv + + # Validate that pp_schedule_csv is a valid path + if pp_schedule_csv: + if not os.path.isfile(pp_schedule_csv): + raise FileNotFoundError( + f"The specified path {pp_schedule_csv} does not exist or is not a file." + ) + schedule_class = _PipelineScheduleRuntime + else: + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + + looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) + microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size + batch_size = job_config.training.local_batch_size + # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training + if batch_size % microbatch_size != 0: + raise ValueError( + f"Batch size {job_config.training.local_batch_size} must be divisible by microbatch_size {microbatch_size}. " + "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size." + ) + n_microbatches = batch_size // microbatch_size + # We expect that the number of local stages (`len(stages)`) is the same across all ranks + num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages) + if n_microbatches < num_total_stages: + logger.warning( + f"Number of microbatches ({n_microbatches}) is less than the total number " + f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." + ) + + schedule = schedule_class( + stages if looped_schedule else stages[0], + n_microbatches=n_microbatches, + loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches), + scale_grads=False, + ) + logger.info( + f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} " + f"with {n_microbatches} microbatches and {num_total_stages} stages." 
+ ) + + if pp_schedule_csv: + assert schedule_class in [ + PipelineScheduleSingle, + PipelineScheduleMulti, + _PipelineScheduleRuntime, + ], ( + "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " + "and _PipelineScheduleRuntime support csv schedules" + ) + schedule._load_csv(pp_schedule_csv) + + return schedule + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (tok_embeddings) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + + Returns: + List of lists containing module names for each model part + + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. " + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." 
+ ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. 
Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_degree = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(model, module_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) + + def _get_stage_indices() -> tuple[int]: + """ + Compute the stage ids for the stages that will run on this pp rank + for either a looped or V style schedule + """ + assert ( + num_stages % pp_degree == 0 + ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}" + stages_per_rank = num_stages // pp_degree + if style == "loop": + return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank)) + elif style == "v": + assert ( + stages_per_rank == 2 + ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" + stage_v_pairs = list( + zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) + ) + return stage_v_pairs[pp_rank] + + for stage_idx in _get_stage_indices(): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models diff --git 
a/torchtitan/experiments/autopartition/job_config.py b/torchtitan/experiments/autopartition/job_config.py new file mode 100644 index 0000000000..063a23b905 --- /dev/null +++ b/torchtitan/experiments/autopartition/job_config.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class CustomConfig: + auto_partition: bool = True + """Whether to use autopartition method to split module, default False""" + +@dataclass +class JobConfig: + custom_config: CustomConfig = field(default_factory=CustomConfig) diff --git a/torchtitan/experiments/autopartition/llama3/args.py b/torchtitan/experiments/autopartition/llama3/args.py new file mode 100644 index 0000000000..d83fb83102 --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3/args.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass, field + +from torch import nn + +from torchtitan.config import JobConfig +from torchtitan.models.utils import get_dense_model_nparams_and_flops +from torchtitan.protocols.model import BaseModelArgs +from torchtitan.tools.logging import logger + + +@dataclass +class RoPEScalingArgs: + scaling_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + +@dataclass +class TransformerModelArgs(BaseModelArgs): + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int = 128256 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + rope_scaling_args: RoPEScalingArgs = field(default_factory=RoPEScalingArgs) + + max_seq_len: int = 131072 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_dense_model_nparams_and_flops( + self, + model, + 2 * (self.dim // self.n_heads), + seq_len, + ) diff --git a/torchtitan/experiments/autopartition/llama3/model.py b/torchtitan/experiments/autopartition/llama3/model.py new file mode 100644 index 0000000000..124153f14c --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3/model.py @@ -0,0 +1,503 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import RoPEScalingArgs, TransformerModelArgs + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + scaling_args: RoPEScalingArgs = RoPEScalingArgs(), +) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + scaling_args (RoPEScalingArgs | None): RoPE scaling arguments. Defaults to None. + scaling_factor (float): RoPE scaling multiplier; larger values + stretch positions to support longer contexts. Defaults to 8.0. + low_freq_factor (float): Extra scaling applied to the low-frequency + (long-wavelength) RoPE bands. Defaults to 1.0. + high_freq_factor (float): Extra scaling applied to the high-frequency + (short-wavelength) RoPE bands. Defaults to 4.0. + original_max_position_embeddings (int): Maximum position embeddings + for original model. Defaults to 8192. + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
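+
+    Example (shapes only, illustrative):
+        >>> freqs_cis = precompute_freqs_cis(dim=128, end=4096)
+        >>> freqs_cis.shape, freqs_cis.dtype
+        (torch.Size([4096, 64]), torch.complex64)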
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + assert ( + isinstance(attention_masks, BlockMask) or attention_masks is None + ), attention_masks + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
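+        weight_init_std (float): Std used to initialize this block's output
+            projections. With ``depth_init=True`` it is
+            ``0.02 / (2 * (layer_id + 1)) ** 0.5`` (roughly 0.0071 for ``layer_id=3``);
+            otherwise the total number of layers is used instead of ``layer_id + 1``.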
+ + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module, ModelProtocol): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (Linear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + self.model_args.rope_scaling_args, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis, attention_masks=attention_masks) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output diff --git a/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py b/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py new file mode 100644 index 0000000000..2c386ece0d --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import re +from typing import Any + +logger = logging.getLogger() + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .args import TransformerModelArgs + + +class Llama3StateDictAdapter(StateDictAdapter): + def __init__( + self, + model_args: TransformerModelArgs, + hf_assets_path: str | None, + ): + super().__init__(model_args, hf_assets_path) + + self.model_args = model_args + self.hf_assets_path = hf_assets_path + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + # HuggingFace permutation function (exact copy from their conversion script) + def _permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + .clone() + ) + + def _reverse_permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. 
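+                # Note: ``_reverse_permute`` (used in ``from_hf``) inverts this
+                # permutation, so a to_hf -> from_hf round trip leaves the wq/wk
+                # weights unchanged.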
+ if abstract_key == "layers.{}.attention.wq.weight": + value = self._permute(value, n_heads) + if abstract_key == "layers.{}.attention.wk.weight": + key_value_dim = head_dim * n_kv_heads + value = self._permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = to_hf_map[key] + + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. + if abstract_key == "model.layers.{}.self_attn.q_proj.weight": + value = self._reverse_permute(value, n_heads) + if abstract_key == "model.layers.{}.self_attn.k_proj.weight": + key_value_dim = head_dim * n_kv_heads + value = self._reverse_permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = self.from_hf_map[key] + + state_dict[new_key] = value + return state_dict diff --git a/torchtitan/experiments/autopartition/llama3_tain_spec.py b/torchtitan/experiments/autopartition/llama3_tain_spec.py new file mode 100644 index 0000000000..785b46991a --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3_tain_spec.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
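+
+# Illustrative usage (mirrors how train.py in this experiment consumes the spec):
+#   spec = get_llama3_train_spec()
+#   with torch.device("meta"):
+#       model = spec.model_cls(spec.model_args["debugmodel"])
+#   # spec.pipelining_fn is ``pipeline_llm``; spec.parallelize_fn is ``parallelize_llama``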
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize_llama import parallelize_llama +from .infra.pipeline_parallel import pipeline_llm +from .llama3.args import TransformerModelArgs +from .llama3.model import Transformer +from .llama3.state_dict_adapter import Llama3StateDictAdapter + +__all__ = [ + "parallelize_llama", + "TransformerModelArgs", + "Transformer", + "llama3_args", +] + + +llama3_args = { + "debugmodel": TransformerModelArgs( + dim=256, n_layers=12, n_heads=16, vocab_size=2048, rope_theta=500000 + ), + "debugmodel_flex_attn": TransformerModelArgs( + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2048, + rope_theta=500000, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "8B": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": TransformerModelArgs( + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": TransformerModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +def get_llama3_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autopartition/train.py b/torchtitan/experiments/autopartition/train.py new file mode 100644 index 0000000000..d8d0c4a9c4 --- /dev/null +++ b/torchtitan/experiments/autopartition/train.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.ft import FTManager +from torchtitan.components.loss import rescale_accumulated_loss +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) +from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import utils as dist_utils +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer +from . 
import ( # noqa: F401 # type: ignore + get_deepseek_v3_train_spec, + get_llama3_train_spec, +) + + +class AutoPartitionTrainer(Trainer): + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: JobConfig): + torch._C._log_api_usage_once("torchtitan.train") + + self.job_config = job_config + + logger.info(f"Starting job: {job_config.job.description}") + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + # Device has to be set before creating TorchFT manager. + device_module.set_device(self.device) + + job_config.maybe_log() + + # init distributed and build meshes + self.parallel_dims = parallel_dims = self.init_distributed() + + world_mesh = parallel_dims.world_mesh + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + + self.ft_manager = FTManager(job_config.fault_tolerance) + dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + + # take control of garbage collection to avoid stragglers + self.gc_handler = utils.GarbageCollection( + gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug + ) + + # Set random seed, and maybe enable deterministic mode + # (mainly for debugging, expect perf loss). + dist_utils.set_determinism( + world_mesh, + self.device, + job_config.debug, + distinct_seed_mesh_dims=["pp"], + ) + self.train_spec = get_llama3_train_spec() + + # build tokenizer and dataloader + self.tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) + + self.dataloader = self.train_spec.build_dataloader_fn( + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=self.tokenizer, + job_config=job_config, + ) + + # build model (using meta init) + model_args = self.train_spec.model_args[job_config.model.flavor] + # set the model args from training job configs + model_args.update_from_config(job_config) + self.model_args = model_args + + logger.info( + f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" + ) + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), + ): + model = self.train_spec.model_cls(model_args) + + # Build the collection of model converters.
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) + + # metrics logging + build_metrics_processor_fn = ( + build_metrics_processor + if self.train_spec.build_metrics_processor_fn is None + else self.train_spec.build_metrics_processor_fn + ) + self.metrics_processor = build_metrics_processor_fn( + job_config, parallel_dims, model_args + ) + color = self.metrics_processor.color + + # calculate model size and flops per token + ( + model_param_count, + self.metrics_processor.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + + logger.info( + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) + + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.checkpoint.create_seed_checkpoint: + init_device = "cpu" + buffer_device = None + elif job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = device_type + else: + init_device = device_type + buffer_device = None + + self.loss_fn = self.train_spec.build_loss_fn( + job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager + ) + + # verify batch sizes + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert ( + global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + ) + + # calculate gradient accumulation steps + self.gradient_accumulation_steps = global_batch_size // ( + job_config.training.local_batch_size * dp_degree + ) + assert self.gradient_accumulation_steps > 0 + self.loss_fn = rescale_accumulated_loss( + self.loss_fn, self.gradient_accumulation_steps + ) + + # apply parallelisms and initialization + if parallel_dims.pp_enabled: + if not self.train_spec.pipelining_fn: + raise RuntimeError( + f"Pipeline Parallel is enabled but {job_config.model.name} " + f"does not support pipelining" + ) + + # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques + ( + self.pp_schedule, + self.model_parts, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) = self.train_spec.pipelining_fn( + model, + parallel_dims, + job_config, + self.device, + model_args, + self.train_spec.parallelize_fn, + self.loss_fn, + ) + # when PP is enabled, `model` obj is no longer used after this point, + # model_parts is used instead + del model + + for m in self.model_parts: + m.to_empty(device=init_device) + with torch.no_grad(): + m.init_weights(buffer_device=buffer_device) + m.train() + + # confirm that user will be able to view loss metrics on the console + ensure_pp_loss_visible(parallel_dims, job_config, color) + else: + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights(buffer_device=buffer_device) + model.train() + + self.model_parts = [model] + + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) + + # initialize device memory monitor and get peak 
flops for MFU calculation + device_memory_monitor = self.metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + # build optimizer after applying parallelisms to the model + self.optimizers = self.train_spec.build_optimizers_fn( + self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager + ) + self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.optimizers, job_config.lr_scheduler, job_config.training.steps + ) + # Post optimizer step model converters hook. + # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # where it issues a single all-reduce for all parameters at once for better performance + self.optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.model_parts + ) + ) + self.metrics_processor.optimizers = self.optimizers + self.metrics_processor.model_parts = self.model_parts + + # Initialize trainer states that will be saved in checkpoint. + # These attributes must be initialized before checkpoint loading. + self.step = 0 + self.ntokens_seen = 0 + + self.checkpointer = CheckpointManager( + dataloader=self.dataloader, + model_parts=self.model_parts, + optimizers=self.optimizers, + lr_schedulers=self.lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=self.ft_manager, + ) + + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + self.train_context = dist_utils.get_train_context(loss_parallel_enabled) + self.maybe_enable_amp = dist_utils.maybe_enable_amp( + parallel_dims, + job_config.training.mixed_precision_param, + device_type, + ) + + # Build validator if validation is configured + if job_config.validation.enable: + assert self.train_spec.build_validator_fn is not None + + pp_schedule, pp_has_first_stage, pp_has_last_stage = ( + ( + self.pp_schedule, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) + if parallel_dims.pp_enabled + else (None, None, None) + ) + + self.validator = self.train_spec.build_validator_fn( + job_config=job_config, + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=self.tokenizer, + parallel_dims=parallel_dims, + loss_fn=self.loss_fn, + validation_context=self.train_context, + maybe_enable_amp=self.maybe_enable_amp, + metrics_processor=self.metrics_processor, + pp_schedule=pp_schedule, + pp_has_first_stage=pp_has_first_stage, + pp_has_last_stage=pp_has_last_stage, + ) + + logger.info( + "Trainer is initialized with " + f"local batch size {job_config.training.local_batch_size}, " + f"global batch size {global_batch_size}, " + f"gradient accumulation steps {self.gradient_accumulation_steps}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.lr_scheduler.warmup_steps})" + ) + + +def main(trainer_class: type[Trainer]) -> None: + """Main entry point for training with a specified trainer class. 
+ + Args: + trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer) + """ + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + trainer: Trainer | None = None + + try: + trainer = trainer_class(config) + + if config.checkpoint.create_seed_checkpoint: + assert ( + int(os.environ["WORLD_SIZE"]) == 1 + ), "Must create seed checkpoint using a single device, to disable sharding." + assert ( + config.checkpoint.enable + ), "Must enable checkpointing when creating a seed checkpoint." + trainer.checkpointer.save(curr_step=0, last_step=True) + logger.info("Created seed checkpoint") + else: + trainer.train() + except Exception: + if trainer: + trainer.close() + raise + else: + trainer.close() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info("Process group destroyed") + + +if __name__ == "__main__": + main(AutoPartitionTrainer) diff --git a/torchtitan/experiments/autopartition/train_configs/debug_model.toml b/torchtitan/experiments/autopartition/train_configs/debug_model.toml new file mode 100644 index 0000000000..1b0cd08848 --- /dev/null +++ b/torchtitan/experiments/autopartition/train_configs/debug_model.toml @@ -0,0 +1,81 @@ +[job] +dump_folder = "./outputs" +description = "Llama 3 debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 32 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +pipeline_parallel_microbatch_size = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml b/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml new file mode 100644 index 0000000000..09cf0e5ac7 --- /dev/null +++ 
b/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml @@ -0,0 +1,79 @@ +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml new file mode 100644 index 0000000000..48c669e404 --- /dev/null +++ b/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml @@ -0,0 +1,70 @@ +# NOTE: this toml config is a preset for 128 H100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Llama 3 405B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "405B" +hf_assets_path = "./assets/hf/Llama-3.1-405B" +converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-5 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps + +[training] +local_batch_size = 2 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +enable_async_tensor_parallel = true +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[compile] +enable=true +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = true +precompute_float8_dynamic_scale_for_fsdp = true +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 500 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml new file mode 100644 index 0000000000..37fd35b5cb --- /dev/null +++ b/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml @@ -0,0 +1,69 @@ +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Llama 3 70B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "70B" +hf_assets_path = "./assets/hf/Llama-3.1-70B" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 1.5e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps + +[training] +local_batch_size = 8 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 1000 +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 # 8-way TP +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "full" + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 500 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml new file mode 100644 index 0000000000..ef86d783bf --- /dev/null +++ b/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml @@ -0,0 +1,70 @@ +# NOTE: this toml config is a preset for 64 A100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 3 8B training" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = true +save_tb_folder = "tb" + +[model] +name = "llama3" +flavor = "8B" +hf_assets_path = "./assets/hf/Llama-3.1-8B" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 3e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up + +[training] +local_batch_size = 1 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 1000 +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[compile] +enable=false +components = ["model", "loss"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 500 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192
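+ +# NOTE: this preset keeps pipeline_parallel_degree = 1, so the auto-partitioner has no pipeline +# stages to balance. A minimal sketch for exercising it on this flavor (assuming enough GPUs for +# a pipeline dimension of 4), mirroring the debug configs above: +# pipeline_parallel_degree = 4 +# pipeline_parallel_schedule = "1F1B"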