6 changes: 4 additions & 2 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -225,9 +225,11 @@ def apply_non_moe_tp(
for transformer_block in model.layers.values():
layer_plan = {
"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate()
"attention": prepare_module_input(
input_layouts=(Shard(1), Replicate(), None),
desired_input_layouts=(Replicate(), Replicate(), None),
input_layouts=(Shard(1), Replicate(), None, None),
Contributor:
Note that when positions is not None, this makes the implicit assumption that positions has the expected sharding when it's used, namely: sharded on the batch dim by DP, replicated on the TP mesh, and sharded on the seq dim by CP.

I don't have a good solution right now -- making it Replicate by default will fail here when positions is None (https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L521), but clearly this leaves a footgun. I'd suggest we add a comment for now.

Contributor Author @acisseJZhong, Dec 7, 2025:
Yeah, when testing positions of shape [1, seq_len] and [bz, seqlen] in dp4tp2, I need to manually change both layouts to Replicate(). But for the default case it should be None.

For CP, I think we need to manually change https://fburl.com/v2rn2s48.
For FSDP, I'm not sure how the sharding info is specified today, but it looks like it's already handled?

Will add a comment for now.

desired_input_layouts=(Replicate(), Replicate(), None, None),
),
# NOTE: use_local_output=False makes the output a DTensor instead of a plain Tensor
# so that the intermediate result k is generated as a DTensor and its gradient is
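Editor's illustration (not part of the PR): the NOTE above and the review thread imply that once positions is actually fed to the attention module under TP, its placement can no longer stay None. A minimal sketch of that variant, assuming PyTorch's public PrepareModuleInput and placement APIs; the PR itself goes through torchtitan's prepare_module_input helper and keeps None for the default positions=None case.

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# Assumed variant for runs that pass positions (e.g. [1, seq_len] or [bz, seqlen]):
# the fourth entry becomes Replicate()/Replicate(), matching the NOTE above.
# Keeping Replicate() as the default would fail whenever positions is None,
# because the prepared module input cannot annotate an absent argument
# (see the style.py link in the review thread).
attention_plan_with_positions = PrepareModuleInput(
    input_layouts=(Shard(1), Replicate(), None, Replicate()),
    desired_input_layouts=(Replicate(), Replicate(), None, Replicate()),
)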
74 changes: 68 additions & 6 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -126,20 +126,71 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
def reshape_for_broadcast(
freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None
) -> torch.Tensor:
"""
Reshape frequency tensor for broadcasting it with another tensor.

This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2),
and the first seqlen elements will be sliced; its last dim must match x's last dim.

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache.
Shape is (1, seqlen) or (bz, seqlen). Defaults to None.

Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert ndim > 1
seqlen = x.shape[1]
if positions is None:
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
elif positions.size(0) == 1:
assert positions.shape == (1, seqlen)
freqs_cis = freqs_cis[positions.squeeze(0)]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
else:
assert positions.shape == (x.shape[0], seqlen)
freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1)
freqs_cis = torch.gather(
freqs_cis_expanded,
dim=1,
index=positions.view(x.shape[0], seqlen, 1, 1).expand(
x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1]
),
)
return freqs_cis


def apply_rotary_emb(
x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor | None = None
) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.

Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
freqs_cis = reshape_for_broadcast(freqs_cis, x, positions)
y = torch.view_as_real(x * freqs_cis).flatten(3)
return y.to(dtype)

@@ -196,13 +247,16 @@ def forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.

Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor with the same shape as the input.
@@ -222,15 +276,15 @@ def forward(
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
q_pe = apply_rotary_emb(q_pe, freqs_cis, positions)
q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim)

# Key-value projection
kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

k_pe = apply_rotary_emb(
k_pe.unsqueeze(2), freqs_cis
k_pe.unsqueeze(2), freqs_cis, positions
) # (bsz, seqlen, 1, qk_rope_head_dim)

kv = self.wkv_b(
@@ -312,18 +366,23 @@ def forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Forward pass for the Transformer block.

Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
x = x + self.attention(
self.attention_norm(x), freqs_cis, attention_masks, positions
)
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
else:
@@ -413,6 +472,7 @@ def forward(
self,
tokens: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
positions: torch.Tensor | None = None,
):
"""
Forward pass for the Transformer model.
@@ -422,6 +482,8 @@ def forward(
If pipeline parallelism is enabled, this will be the input token indices
for the ranks on the first pipeline stage. This will be the activation of the
previous pipeline stage if the current rank is not on the first stage.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
@@ -430,7 +492,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks)
h = layer(h, self.freqs_cis, attention_masks, positions)
h = self.norm(h) if self.norm is not None else h
output = self.output(h) if self.output is not None else h
return output
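Editor's illustration (not part of the PR): a minimal, self-contained shape check that mirrors the three positions branches of reshape_for_broadcast above. All sizes are made-up toy values, and the tensors are built directly here rather than imported from torchtitan.

import torch

bz, seqlen, n_heads, rope_dim = 2, 8, 4, 16
max_seqlen = 32
# complex RoPE cache of shape (max_seqlen, rope_dim // 2), as the docstring assumes
freqs_cis = torch.polar(
    torch.ones(max_seqlen, rope_dim // 2), torch.randn(max_seqlen, rope_dim // 2)
)
# x as seen inside apply_rotary_emb after view_as_complex: (bz, seqlen, n_heads, rope_dim // 2)
x = torch.view_as_complex(torch.randn(bz, seqlen, n_heads, rope_dim // 2, 2))

# positions is None -> slice the first seqlen rows, broadcast over batch and heads
out_none = freqs_cis[:seqlen].view(1, seqlen, 1, rope_dim // 2)
assert (x * out_none).shape == x.shape

# positions of shape (1, seqlen) -> shared gather, same broadcast shape
positions_shared = torch.arange(seqlen).unsqueeze(0)
out_shared = freqs_cis[positions_shared.squeeze(0)].view(1, seqlen, 1, rope_dim // 2)
assert (x * out_shared).shape == x.shape

# positions of shape (bz, seqlen) -> per-sample gather along the sequence dim
positions_per_sample = torch.arange(seqlen).repeat(bz, 1)
expanded = freqs_cis[None, :, None, :].expand(bz, -1, -1, -1)
out_batched = torch.gather(
    expanded,
    dim=1,
    index=positions_per_sample.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, rope_dim // 2),
)
assert (x * out_batched).shape == x.shape
# with arange positions, the gather path reproduces the positions=None slice
assert torch.equal(out_batched, out_none.expand(bz, -1, -1, -1))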
6 changes: 4 additions & 2 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -207,9 +207,11 @@ def apply_tp(
for transformer_block in model.layers.values():
layer_plan = {
"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate()
"attention": prepare_module_input(
input_layouts=(Shard(1), None, None),
desired_input_layouts=(Replicate(), None, None),
input_layouts=(Shard(1), None, None, None),
desired_input_layouts=(Replicate(), None, None, None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),
57 changes: 47 additions & 10 deletions torchtitan/models/llama3/model/model.py
@@ -88,36 +88,59 @@ def precompute_freqs_cis(
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def reshape_for_broadcast(
freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None
) -> torch.Tensor:
"""
Reshape frequency tensor for broadcasting it with another tensor.

This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2),
and the first seqlen elements will be sliced; its last dim must match x's last dim.

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache.
Shape is (1, seqlen) or (bz, seqlen). Defaults to None.

Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert ndim > 1
seqlen = x.shape[1]
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
if positions is None:
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
elif positions.size(0) == 1:
assert positions.shape == (1, seqlen)
freqs_cis = freqs_cis[positions.squeeze(0)]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
else:
assert positions.shape == (x.shape[0], seqlen)
freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1)
freqs_cis = torch.gather(
freqs_cis_expanded,
dim=1,
index=positions.view(x.shape[0], seqlen, 1, 1).expand(
x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1]
),
)
return freqs_cis


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
positions: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
@@ -131,13 +154,14 @@ def apply_rotary_emb(
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
@@ -213,13 +237,16 @@ def forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Forward pass of the attention module.

Args:
x (torch.Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed frequency tensor.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor after attention.
@@ -236,7 +263,7 @@ def forward(
xk = xk.view(bs, seqlen, -1, self.head_dim)
xv = xv.view(bs, seqlen, -1, self.head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions)

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
@@ -360,19 +387,24 @@ def forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Perform a forward pass through the TransformerBlock.

Args:
x (torch.Tensor): Input tensor.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
h = x + self.attention(
self.attention_norm(x), freqs_cis, attention_masks, positions
)
out = h + self.feed_forward(self.ffn_norm(h))
return out

@@ -519,6 +551,7 @@ def forward(
self,
tokens: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
positions: torch.Tensor | None = None,
):
"""
Perform a forward pass through the Transformer model.
@@ -528,6 +561,8 @@ def forward(
If pipeline parallelism is enabled, this will be the input token indices
for the ranks on the first pipeline stage. This will be the activation of the
previous pipeline stage if the current rank is not on the first stage.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output logits after applying the Transformer model.
@@ -537,7 +572,9 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks=attention_masks)
h = layer(
h, self.freqs_cis, attention_masks=attention_masks, positions=positions
)
h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
return output
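Editor's illustration (not part of the PR): the PR threads positions through the models but does not prescribe how a trainer builds it. One hedged sketch, assuming document-packed batches where positions restarts after an end-of-document token; EOS_ID, the token values, and the commented model call are all hypothetical.

import torch

EOS_ID = 0  # assumed end-of-document token id (hypothetical)
tokens = torch.tensor([
    [5, 7, 9, EOS_ID, 3, 4, 8, 2],       # two packed documents
    [1, 2, EOS_ID, 6, 6, 6, EOS_ID, 4],  # three packed documents
])
bz, seqlen = tokens.shape

idx = torch.arange(seqlen).expand(bz, seqlen)
# a segment starts right after each EOS; record that start index, 0 elsewhere
seg_start = torch.zeros_like(tokens)
seg_start[:, 1:] = torch.where(
    tokens[:, :-1] == EOS_ID, idx[:, 1:], torch.zeros_like(idx[:, 1:])
)
# carry the most recent segment start forward along the sequence
seg_start = seg_start.cummax(dim=1).values
positions = idx - seg_start  # shape (bz, seqlen); restarts at 0 after each EOS
# e.g. positions[0] == [0, 1, 2, 3, 0, 1, 2, 3]

# a shared (1, seqlen) positions tensor is also accepted by reshape_for_broadcast:
# positions = torch.arange(seqlen).unsqueeze(0)

# hypothetical call into the model updated by this PR:
# logits = model(tokens, attention_masks=masks, positions=positions)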
6 changes: 4 additions & 2 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -242,9 +242,11 @@ def apply_non_moe_tp(
for transformer_block in model.layers.values():
layer_plan = {
"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate()
"attention": prepare_module_input(
input_layouts=(Shard(1), None, None),
desired_input_layouts=(Replicate(), None, None),
input_layouts=(Shard(1), None, None, None),
desired_input_layouts=(Replicate(), None, None, None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),