diff --git a/torchtitan/experiments/vllm/model/qwen3.py b/torchtitan/experiments/vllm/model/qwen3.py index 6abaef480d..92549822c9 100644 --- a/torchtitan/experiments/vllm/model/qwen3.py +++ b/torchtitan/experiments/vllm/model/qwen3.py @@ -288,13 +288,9 @@ def forward( # Get embeddings from 2D tokens h = self.model.tok_embeddings(tokens_2d) # [1, total_tokens, hidden_size] - # Get RoPE cache - seqlen = h.shape[1] # seq_len dimension - rope_cache = self.model.rope_cache[:seqlen] - # Pass through transformer layers for layer in self.model.layers.values(): - h = layer(h, rope_cache, attention_masks=None) + h = layer(h, self.model.rope_cache, attention_masks=None, positions=positions_2d) # Convert output format back to vLLM expectations # vLLM expects hidden_states in [total_tokens, hidden_size] format diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 3acd012afd..aa3ab87782 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -54,20 +54,28 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None) -> torch.Tensor: """ - Reshapes the RoPE frequency tensor to be broadcastable with the input tensor. + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), + and the first seqlen elements will be sliced, but dim must match x. Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. - x (torch.Tensor): Input tensor whose shape will determine the reshaping. + x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. """ ndim = x.ndim assert ndim > 1 - _, seqlen, _, head_dim = x.shape + bz, seqlen, _, head_dim = x.shape # Extend rope_cache if needed (e.g., during vLLM profiling with 2x max_seq_len) if seqlen > rope_cache.shape[0]: @@ -104,21 +112,41 @@ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Te else: rope_cache = extended_cache - rope_cache = rope_cache[0:seqlen] - # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin - assert rope_cache.shape == (seqlen, head_dim * 2) - shape = [-1, seqlen, 1, head_dim * 2] - return rope_cache.view(*shape) + if positions is None: + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + rope_cache = rope_cache[positions.squeeze(0)] + # The shape of rope_cache is (seqlen, head_dim * 2) + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + else: + assert positions.shape == (bz, seqlen) + rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1) + rope_cache = torch.gather( + rope_cache_expanded, + dim=1, + index=positions.view(bz, seqlen, 1, 1).expand( + bz, seqlen, 1, head_dim * 2 + ), + ) + # The shape of rope_cache is (bz, seqlen, 1, head_dim * 2) + assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2) + return rope_cache def apply_rotary_emb( - xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor, positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # input tensor x has shape [bsz, seq_len, num_heads, head_dim] head_dim = xq.shape[-1] - # reshape for broadcast - rope_cache = reshape_for_broadcast(rope_cache, xq) + rope_cache = reshape_for_broadcast(rope_cache, xq, positions) # [bsz, seq_len, 1, head_dim] cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) @@ -216,12 +244,16 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. Args: x (torch.Tensor): Input tensor. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -246,7 +278,7 @@ def forward( xk = self.k_norm(xk) # Apply rotary embedding - xq, xk = apply_rotary_emb(xq, xk, rope_cache) + xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -362,6 +394,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -369,12 +402,13 @@ def forward( Args: x (torch.Tensor): Input tensor. rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. - """ - x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) + x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks, positions) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -502,6 +536,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -511,6 +546,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -520,7 +557,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache, attention_masks) + h = layer(h, self.rope_cache, attention_masks, positions) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h