diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 69273654e3..760022a59c 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -225,9 +225,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), + input_layouts=(Shard(1), Replicate(), None, None), + desired_input_layouts=(Replicate(), Replicate(), None, None), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 7d7635a4ad..5b17ad0acf 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -126,20 +126,71 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: return freqs_cis -def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Applies rotary positional embeddings to the input tensor. Args: x (torch.Tensor): Input tensor with positional embeddings to be applied. freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. 
+ positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Tensor with rotary embeddings applied. """ dtype = x.dtype x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + freqs_cis = reshape_for_broadcast(freqs_cis, x, positions) y = torch.view_as_real(x * freqs_cis).flatten(3) return y.to(dtype) @@ -196,6 +247,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. @@ -203,6 +255,8 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. @@ -222,7 +276,7 @@ def forward( q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) - q_pe = apply_rotary_emb(q_pe, freqs_cis) + q_pe = apply_rotary_emb(q_pe, freqs_cis, positions) q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) # Key-value projection @@ -230,7 +284,7 @@ def forward( kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = apply_rotary_emb( - k_pe.unsqueeze(2), freqs_cis + k_pe.unsqueeze(2), freqs_cis, positions ) # (bsz, seqlen, 1, qk_rope_head_dim) kv = self.wkv_b( @@ -312,6 +366,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer block. @@ -319,11 +374,15 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + x = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: @@ -413,6 +472,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer model. @@ -422,6 +482,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
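
Note (illustration, not part of the diff): the per-sample branch of the new reshape_for_broadcast selects one RoPE cache row per (batch, position) pair via torch.gather. A minimal standalone sketch, with made-up sizes and a real-valued stand-in for the complex cache, showing that the gather is equivalent to plain row indexing:

import torch

max_seqlen, half_dim = 32, 8                              # cache is (max_seqlen, dim // 2)
bsz, seqlen = 4, 16
freqs_cis = torch.randn(max_seqlen, half_dim)             # real stand-in for the complex cache
positions = torch.randint(0, max_seqlen, (bsz, seqlen))   # per-sample position ids

# Gather path, as in the per-sample (positions.size(0) > 1) branch of the diff:
expanded = freqs_cis[None, :, None, :].expand(bsz, -1, -1, -1)   # (bsz, max_seqlen, 1, half_dim)
gathered = torch.gather(
    expanded,
    dim=1,
    index=positions.view(bsz, seqlen, 1, 1).expand(bsz, seqlen, 1, half_dim),
)                                                                 # (bsz, seqlen, 1, half_dim)

# Reference: direct row indexing per sample.
reference = freqs_cis[positions].unsqueeze(2)                     # (bsz, seqlen, 1, half_dim)
assert torch.equal(gathered, reference)

The check only demonstrates the shapes and semantics of the gather; the values it picks are exactly freqs_cis[positions[b, s]] for every (b, s).
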
@@ -430,7 +492,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h = layer(h, self.freqs_cis, attention_masks, positions) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 1c381883b1..432237996f 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -207,9 +207,11 @@ def apply_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 74b862bf76..8982fcca9f 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -88,19 +88,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. 
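
Note (illustration, not part of the diff): the NOTE added to the parallelize plans says that when positions is actually fed to the attention block, the fourth slot of both layout tuples should become Replicate() rather than None. A hedged sketch of what that entry might look like, written against the underlying PrepareModuleInput style from torch.distributed.tensor.parallel (the diff itself goes through torchtitan's prepare_module_input helper):

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# Hypothetical plan entry for when `positions` is passed; only the fourth slot
# changes relative to the llama3-style plan shown above.
attention_input_plan = PrepareModuleInput(
    input_layouts=(Shard(1), None, None, Replicate()),
    desired_input_layouts=(Replicate(), None, None, Replicate()),
)
# This would replace the "attention" entry in `layer_plan` before calling
# parallelize_module(...); it is shown here only to spell out the NOTE.

None slots mark inputs that are passed through as plain tensors, while Replicate() marks a DTensor input that every tensor-parallel rank needs in full, which is why integer position ids would be replicated rather than sharded.
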
@@ -108,16 +112,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -131,13 +154,14 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -213,6 +237,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -220,6 +245,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -236,7 +263,7 @@ def forward( xk = xk.view(bs, seqlen, -1, self.head_dim) xv = xv.view(bs, seqlen, -1, self.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -360,6 +387,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -367,12 +395,16 @@ def forward( Args: x (torch.Tensor): Input tensor. 
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -519,6 +551,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -528,6 +561,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -537,7 +572,9 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks=attention_masks) + h = layer( + h, self.freqs_cis, attention_masks=attention_masks, positions=positions + ) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 28418d842e..b9690edeb6 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -242,9 +242,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 6b9d2d2d9e..7c4f073e19 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -86,19 +86,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. 
Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. @@ -106,16 +110,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -129,13 +152,14 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -219,6 +243,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -226,6 +251,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. 
@@ -243,7 +270,7 @@ def forward( xv = xv.view(bs, seqlen, -1, self.head_dim) if self.use_rope: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -393,6 +420,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -400,12 +428,16 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -540,6 +572,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -549,6 +582,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. 
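
Note (illustration, not part of the diff): with contiguous positions, the new (1, seqlen) lookup reduces to the old freqs_cis[:seqlen] slice, so omitting positions and passing torch.arange(seqlen) are expected to produce identical rotary embeddings. A tiny check of that reduction, using a real-valued stand-in for the cache and made-up sizes:

import torch

max_seqlen, half_dim, seqlen = 64, 8, 16
freqs_cis = torch.randn(max_seqlen, half_dim)      # stand-in for the complex RoPE cache

old_path = freqs_cis[:seqlen]                      # previous behaviour: slice the prefix
positions = torch.arange(seqlen).unsqueeze(0)      # (1, seqlen), shared across the batch
new_path = freqs_cis[positions.squeeze(0)]         # new behaviour: index by position ids

assert torch.equal(old_path, new_path)
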
@@ -558,7 +593,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h = layer(h, self.freqs_cis, attention_masks, positions) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 12aca42777..dbcfea4edc 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -241,9 +241,11 @@ def apply_non_moe_tp( for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), + input_layouts=(Shard(1), Replicate(), None, None), + desired_input_layouts=(Replicate(), Replicate(), None, None), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index fa8fd454b1..62b5d0c381 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -57,7 +57,9 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. @@ -70,28 +72,51 @@ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Te Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. 
""" ndim = x.ndim assert ndim > 1 - _, seqlen, _, head_dim = x.shape - rope_cache = rope_cache[0:seqlen] - # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin - assert rope_cache.shape == (seqlen, head_dim * 2) - shape = [-1, seqlen, 1, head_dim * 2] - return rope_cache.view(*shape) + bz, seqlen, _, head_dim = x.shape + if positions is None: + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + rope_cache = rope_cache[positions.squeeze(0)] + # The shape of rope_cache is (seqlen, head_dim * 2) + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + else: + assert positions.shape == (bz, seqlen) + rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1) + rope_cache = torch.gather( + rope_cache_expanded, + dim=1, + index=positions.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, head_dim * 2), + ) + # The shape of rope_cache is (bz, seqlen, 1, head_dim * 2) + assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2) + return rope_cache def apply_rotary_emb( - xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor + xq: torch.Tensor, + xk: torch.Tensor, + rope_cache: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # input tensor x has shape [bsz, seq_len, num_heads, head_dim] head_dim = xq.shape[-1] - # reshape for broadcast - rope_cache = reshape_for_broadcast(rope_cache, xq) + rope_cache = reshape_for_broadcast(rope_cache, xq, positions) # [bsz, seq_len, 1, head_dim] cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) @@ -194,12 +219,16 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. Args: x (torch.Tensor): Input tensor. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -224,7 +253,7 @@ def forward( xk = self.k_norm(xk) # Apply rotary embedding - xq, xk = apply_rotary_emb(xq, xk, rope_cache) + xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -350,6 +379,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -357,12 +387,16 @@ def forward( Args: x (torch.Tensor): Input tensor. rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. 
""" - x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) + x = x + self.attention( + self.attention_norm(x), rope_cache, attention_masks, positions + ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -515,6 +549,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -524,6 +559,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -533,7 +570,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache, attention_masks) + h = layer(h, self.rope_cache, attention_masks, positions) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h