6 changes: 1 addition & 5 deletions torchtitan/experiments/vllm/model/qwen3.py
@@ -288,13 +288,9 @@ def forward(
# Get embeddings from 2D tokens
h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size]

# Get RoPE cache
seqlen = h.shape[1] # seq_len dimension
rope_cache = self.model.rope_cache[:seqlen]

# Pass through transformer layers
for layer in self.model.layers.values():
h = layer(h, rope_cache, attention_masks=None)
h = layer(h, self.model.rope_cache, attention_masks=None, positions=positions_2d)

# Convert output format back to vLLM expectations
# vLLM expects hidden_states in [total_tokens, hidden_size] format
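The wrapper above no longer slices rope_cache by the local sequence length; it hands the full cache to every layer together with explicit per-token positions (positions_2d). A minimal sketch of why that matters for vLLM-style flattened batches; the tensor names and values below are illustrative assumptions, not part of the PR:

import torch

# Two requests packed into one flattened stream, as vLLM does.
total_tokens, max_seqlen, head_dim = 7, 32, 8
tokens = torch.randint(0, 1000, (total_tokens,))
positions = torch.tensor([0, 1, 2, 3, 0, 1, 2])  # positions restart at each request

tokens_2d = tokens.unsqueeze(0)        # [1, total_tokens]
positions_2d = positions.unsqueeze(0)  # [1, total_tokens]

# Slicing rope_cache[:seqlen] would assume positions 0..seqlen-1 in order;
# indexing by position picks the correct RoPE row for every packed token.
rope_cache = torch.randn(max_seqlen, head_dim * 2)    # cos||sin per position
per_token_rope = rope_cache[positions_2d.squeeze(0)]  # [total_tokens, head_dim * 2]
print(tokens_2d.shape, per_token_rope.shape)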
69 changes: 53 additions & 16 deletions torchtitan/models/qwen3/model/model.py
@@ -54,20 +54,28 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)


def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None) -> torch.Tensor:
"""
Reshapes the RoPE frequency tensor to be broadcastable with the input tensor.
Reshape the frequency tensor (concatenated cos and sin) so it broadcasts with another tensor.

This function reshapes the frequency tensor to a shape that broadcasts against the target tensor 'x' during element-wise operations.

The input rope_cache tensor is assumed to have shape (max_seqlen, head_dim * 2). When positions is None,
the first seqlen rows are used; otherwise rows are selected by position. The head_dim must match that of x.

Args:
rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped.
x (torch.Tensor): Input tensor whose shape will determine the reshaping.
x (torch.Tensor): Target tensor for broadcasting compatibility.
positions (torch.Tensor | None): Position indices used to select rows from the RoPE cache.
Shape is (1, seqlen) or (bz, seqlen). Defaults to None.

Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert ndim > 1
_, seqlen, _, head_dim = x.shape
bz, seqlen, _, head_dim = x.shape

# Extend rope_cache if needed (e.g., during vLLM profiling with 2x max_seq_len)
if seqlen > rope_cache.shape[0]:
@@ -104,21 +112,41 @@ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Te
else:
rope_cache = extended_cache

rope_cache = rope_cache[0:seqlen]
# The shape of rope_cache is (seqlen, head_dim * 2) because we concatenate cos and sin
assert rope_cache.shape == (seqlen, head_dim * 2)
shape = [-1, seqlen, 1, head_dim * 2]
return rope_cache.view(*shape)
if positions is None:
rope_cache = rope_cache[0:seqlen]
# The shape of rope_cache is (seqlen, head_dim * 2) because we concatenate cos and sin
assert rope_cache.shape == (seqlen, head_dim * 2)
shape = [-1, seqlen, 1, head_dim * 2]
return rope_cache.view(*shape)
elif positions.size(0) == 1:
assert positions.shape == (1, seqlen)
rope_cache = rope_cache[positions.squeeze(0)]
# The shape of rope_cache is (seqlen, head_dim * 2)
assert rope_cache.shape == (seqlen, head_dim * 2)
shape = [-1, seqlen, 1, head_dim * 2]
return rope_cache.view(*shape)
else:
assert positions.shape == (bz, seqlen)
rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1)
rope_cache = torch.gather(
rope_cache_expanded,
dim=1,
index=positions.view(bz, seqlen, 1, 1).expand(
bz, seqlen, 1, head_dim * 2
),
)
# The shape of rope_cache is (bz, seqlen, 1, head_dim * 2)
assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2)
return rope_cache
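For the batched branch above, torch.gather picks one RoPE row per (batch, position) pair. A small self-contained sketch with toy sizes (all values are assumptions) showing that the gathered rows equal direct cache lookups:

import torch

bz, seqlen, head_dim, max_seqlen = 2, 3, 4, 10
rope_cache = torch.arange(max_seqlen * head_dim * 2, dtype=torch.float32).view(max_seqlen, head_dim * 2)
positions = torch.tensor([[0, 1, 2], [5, 6, 7]])  # per-sequence token positions

expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1)                # [bz, max_seqlen, 1, 2*head_dim]
index = positions.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, head_dim * 2)  # index along the seqlen dim
gathered = torch.gather(expanded, dim=1, index=index)                         # [bz, seqlen, 1, 2*head_dim]

# Each gathered row is exactly the cache entry for that token's position.
assert torch.equal(gathered[1, 0, 0], rope_cache[5])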


def apply_rotary_emb(
xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor
xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor, positions: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# input tensor x has shape [bsz, seq_len, num_heads, head_dim]
head_dim = xq.shape[-1]

# reshape for broadcast
rope_cache = reshape_for_broadcast(rope_cache, xq)
rope_cache = reshape_for_broadcast(rope_cache, xq, positions)

# [bsz, seq_len, 1, head_dim]
cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device)
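The cache packs cos in the first half and sin in the second half of the last dimension. The rest of apply_rotary_emb is collapsed in this diff, so the following is only a sketch of the standard rotate-half application implied by the cos slice above and the rotate_half helper defined earlier (the helper's split of x into halves is assumed); shapes and values are illustrative:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, seqlen, n_heads, head_dim = 1, 4, 2, 8
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
rope_cache = torch.randn(1, seqlen, 1, head_dim * 2)  # already reshaped for broadcast

cos = rope_cache[..., :head_dim]
sin = rope_cache[..., head_dim:]
xq_out = xq * cos + rotate_half(xq) * sin  # RoPE rotation, broadcast over the heads dim
print(xq_out.shape)  # torch.Size([1, 4, 2, 8])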
@@ -216,12 +244,16 @@ def forward(
x: torch.Tensor,
rope_cache: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Forward pass of the attention module.

Args:
x (torch.Tensor): Input tensor.
rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor after attention.
@@ -246,7 +278,7 @@ def forward(
xk = self.k_norm(xk)

# Apply rotary embedding
xq, xk = apply_rotary_emb(xq, xk, rope_cache)
xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions)

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
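repeat_kv expands the key/value heads so grouped-query attention can match the number of query heads. Its definition sits outside this diff, so the following is a minimal sketch of the usual expand-and-reshape implementation, not necessarily the one in model.py:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [bs, seqlen, n_kv_heads, head_dim] -> [bs, seqlen, n_kv_heads * n_rep, head_dim]
    if n_rep == 1:
        return x
    bs, seqlen, n_kv_heads, head_dim = x.shape
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    )

xk = torch.randn(1, 4, 2, 8)   # 2 KV heads
print(repeat_kv(xk, 4).shape)  # torch.Size([1, 4, 8, 8]) -> matches 8 query heads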
@@ -362,19 +394,21 @@ def forward(
x: torch.Tensor,
rope_cache: torch.Tensor,
attention_masks: AttentionMasksType | None,
positions: torch.Tensor | None = None,
):
"""
Perform a forward pass through the TransformerBlock.

Args:
x (torch.Tensor): Input tensor.
rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks, positions)

if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
@@ -502,6 +536,7 @@ def forward(
self,
tokens: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
positions: torch.Tensor | None = None,
):
"""
Perform a forward pass through the Transformer model.
@@ -511,6 +546,8 @@ def forward(
If pipeline parallelism is enabled, this will be the input token indices
for the ranks on the first pipeline stage. This will be the activation of the
previous pipeline stage if the current rank is not on the first stage.
attention_masks (AttentionMasksType | None): Masks used when calculating attention scores.
positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None.

Returns:
torch.Tensor: Output logits after applying the Transformer model.
@@ -520,7 +557,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.rope_cache, attention_masks)
h = layer(h, self.rope_cache, attention_masks, positions)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
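End to end, positions defaults to None, in which case the layers fall back to slicing rope_cache by sequence length, while callers such as the vLLM wrapper pass explicit positions for packed sequences. A usage sketch with the model construction elided; the tensors and the commented call are assumptions:

import torch

# Build restart-at-zero positions for two packed requests of lengths 4 and 2.
seq_lens = [4, 2]
positions = torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)  # [[0, 1, 2, 3, 0, 1]]
tokens = torch.tensor([[11, 12, 13, 14, 21, 22]])

# logits = model(tokens, attention_masks=None, positions=positions)
print(positions)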