diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index ce4f40680b0a..4114b21168cc 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -83,6 +83,11 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: ): self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_sin = self.cos_sin_cache[:seqlen] + cos, sin = cos_sin.chunk(2, dim=-1) + return cos, sin + class RotaryEmbedding(RotaryEmbeddingBase): def __init__( diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 6953b805653b..65c3fc2d9e97 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -65,6 +65,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -341,7 +342,8 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: @@ -353,10 +355,12 @@ def forward( batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) - if rotary_pos_emb is not None: + if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = apply_rotary_pos_emb_vision( + qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + ) q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: @@ -454,14 +458,16 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) @@ -660,44 +666,6 @@ def forward( return embeddings -class Glm4vVisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - self.dim = dim - self.theta = theta - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._freqs_cached = None - - def update_freqs_cache(self, seqlen: int) -> None: - if seqlen > self._seq_len_cached: - seqlen *= 2 - self._seq_len_cached = seqlen - self.inv_freq = 1.0 / ( - self.theta - ** ( - torch.arange( - 0, - self.dim, - 2, - dtype=torch.float, - device=self.inv_freq.device, - ) - / self.dim - ) - ) - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) - freqs = torch.outer(seq, self.inv_freq) - 
self._freqs_cached = freqs - - def forward(self, seqlen: int) -> torch.Tensor: - self.update_freqs_cache(seqlen) - return self._freqs_cached[:seqlen] - - class Glm4vVisionTransformer(nn.Module): def __init__( self, @@ -731,7 +699,13 @@ def __init__( norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) self.blocks = nn.ModuleList( [ Glm4vVisionBlock( @@ -789,7 +763,9 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + def rot_pos_emb( + self, grid_thw: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pos_ids = [] for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) @@ -817,9 +793,18 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + return cos_combined, sin_combined, pos_ids def compute_attn_mask_seqlen( self, @@ -848,7 +833,9 @@ def forward( x = self.post_conv_layernorm(x) # compute position embedding - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) + rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb( + grid_thw + ) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -867,7 +854,8 @@ def forward( x = blk( x, cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 897dd7ef29f1..2e4fd9645d88 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -64,6 +64,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.vision import should_torch_compile_mm_vit @@ -363,7 +364,8 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: @@ -378,13 +380,15 @@ def forward( head=self.num_attention_heads_per_partition, ) - if rotary_pos_emb is not None: + if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: qk, v = qkv[:, :, 
:2], qkv[:, :, 2] qk_reshaped = einops.rearrange( qk, "b s two head head_dim -> (two b) s head head_dim", two=2 ) - qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb) + qk_rotated = apply_rotary_pos_emb_vision( + qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin + ) qk_rotated = qk_rotated.view( 2, batch_size, @@ -434,7 +438,8 @@ def forward( dynamic_arg_dims={ "x": 0, "cu_seqlens": 0, - "rotary_pos_emb": 0, + "rotary_pos_emb_cos": 0, + "rotary_pos_emb_sin": 0, "seqlens": 0, }, mark_unbacked_dims={"seqlens": 0}, @@ -485,14 +490,16 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) @@ -588,42 +595,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -class Qwen2_5_VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - self.dim = dim - self.theta = theta - inv_freq = 1.0 / ( - theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._freqs_cached = None - - def update_freqs_cache(self, seqlen: int) -> None: - if seqlen > self._seq_len_cached: - seqlen *= 2 - self._seq_len_cached = seqlen - self.inv_freq = 1.0 / ( - self.theta - ** ( - torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device - ) - / self.dim - ) - ) - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) - freqs = torch.outer(seq, self.inv_freq) - self._freqs_cached = freqs - - def forward(self, seqlen: int) -> torch.Tensor: - self.update_freqs_cache(seqlen) - return self._freqs_cached[:seqlen] - - class Qwen2_5_VisionTransformer(nn.Module): def __init__( self, @@ -666,7 +637,13 @@ def __init__( norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) use_upstream_fa = False self.attn_backend = get_vit_attn_backend( @@ -757,15 +734,30 @@ def rotary_pos_emb_thw(self, t, h, w): ) pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) max_size = max(h, w) - rotary_pos_emb_full = self.rotary_pos_emb(max_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - rotary_pos_emb = rotary_pos_emb.reshape( - rotary_pos_emb.shape[0] // self.spatial_merge_unit, + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + cos_combined = cos_combined.reshape( + cos_combined.shape[0] // self.spatial_merge_unit, + self.spatial_merge_unit, + -1, + ) + sin_combined = sin_combined.reshape( + sin_combined.shape[0] // self.spatial_merge_unit, 
self.spatial_merge_unit, -1, ) - return rotary_pos_emb + return cos_combined, sin_combined def get_window_index_thw(self, grid_t, grid_h, grid_w): vit_merger_window_size = ( @@ -807,14 +799,19 @@ def get_window_index_thw(self, grid_t, grid_h, grid_w): @lru_cache(maxsize=1024) # noqa: B019 def get_rope_by_thw(self, t, h, w): window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) - rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) - rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] - rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) + cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w) + + cos_thw = cos_thw[window_index_thw, :, :] + cos_thw = cos_thw.flatten(start_dim=0, end_dim=1) + sin_thw = sin_thw[window_index_thw, :, :] + sin_thw = sin_thw.flatten(start_dim=0, end_dim=1) + cu_seqlens_thw = torch.repeat_interleave( torch.tensor([h * w], dtype=torch.int32), t ) return ( - rotary_pos_emb_thw, + cos_thw, + sin_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw, @@ -849,7 +846,8 @@ def forward( ) -> torch.Tensor: # patchify seq_len, _ = x.size() - rotary_pos_emb = [] + rotary_pos_emb_cos = [] + rotary_pos_emb_sin = [] window_index: list = [] cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)] cu_seqlens: list = [] @@ -865,7 +863,8 @@ def forward( llm_w = w // self.spatial_merge_size ( - rotary_pos_emb_thw, + cos_thw, + sin_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw, @@ -878,11 +877,13 @@ def forward( cu_window_seqlens_last = cu_seqlens_window_thw[-1] cu_window_seqlens.append(cu_seqlens_window_thw) - rotary_pos_emb.append(rotary_pos_emb_thw) + rotary_pos_emb_cos.append(cos_thw) + rotary_pos_emb_sin.append(sin_thw) cu_seqlens.append(cu_seqlens_thw) - rotary_pos_emb = torch.cat(rotary_pos_emb) + rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos) + rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin) window_index = torch.cat(window_index) # compute reverse indices reverse_indices = self.invert_permutation(window_index) @@ -901,7 +902,12 @@ def forward( cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) - rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True) + rotary_pos_emb_cos = rotary_pos_emb_cos.to( + device=self.device, non_blocking=True + ) + rotary_pos_emb_sin = rotary_pos_emb_sin.to( + device=self.device, non_blocking=True + ) window_index = window_index.to(device=hidden_states.device, non_blocking=True) reverse_indices = reverse_indices.to( device=hidden_states.device, non_blocking=True @@ -928,7 +934,8 @@ def forward( hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen_now, seqlens=seqlens_now, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 5d21e249fc4c..53df5972a8fe 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -32,7 +32,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( @@ -59,7 +59,9 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization 
import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding.common import ( + apply_rotary_emb_torch, dispatch_rotary_emb_function, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -275,47 +277,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +def apply_rotary_pos_emb_vision( + t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + rotary_emb_function = dispatch_rotary_emb_function( + default=partial(apply_rotary_emb_torch, is_neox_style=True) ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = rotary_emb_function(t_, cos, sin).type_as(t) + output = rotary_emb_function(t, cos, sin).type_as(t) return output @@ -412,7 +380,8 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: @@ -424,11 +393,13 @@ def forward( batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) - if rotary_pos_emb is not None: - # [2 * b, s, heads, head_dim] - qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) - q, k = torch.chunk(qk_rotated, 2, dim=0) + + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision( + qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + ) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) @@ -534,14 +505,16 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) @@ -628,40 +601,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -class Qwen2VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - self.dim = dim - self.theta = theta - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self._seq_len_cached = 0 - self._freqs_cached = None - - def update_freqs_cache(self, seqlen: int) -> None: - if seqlen > self._seq_len_cached: - seqlen *= 2 - self._seq_len_cached = seqlen - self.inv_freq = 1.0 / ( - self.theta - ** ( - torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device - ) - / self.dim - ) - ) - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) - freqs = torch.outer(seq, self.inv_freq) - self._freqs_cached = freqs - - def forward(self, seqlen: int) -> torch.Tensor: - self.update_freqs_cache(seqlen) - return self._freqs_cached[:seqlen] - - class Qwen2VisionTransformer(nn.Module): def __init__( self, @@ -700,7 +639,13 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads - self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) self.blocks = nn.ModuleList( [ @@ -744,7 +689,9 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: + def rot_pos_emb( + self, grid_thw: list[list[int]] + ) -> tuple[torch.Tensor, torch.Tensor]: pos_ids = [] max_grid_size = 0 for t, h, w in grid_thw: @@ -773,9 +720,18 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) max_grid_size = max(max_grid_size, h, w) pos_ids = torch.cat(pos_ids, dim=0) - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + return cos_combined, sin_combined def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor @@ -806,7 +762,7 @@ def forward( grid_thw_list = grid_thw.tolist() # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw_list) + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( @@ -824,7 +780,8 @@ def forward( x = blk( x, 
cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 40b80ce2387c..8274b92138f7 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -60,6 +60,7 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo @@ -90,7 +91,6 @@ ) from .qwen2_5_vl import ( Qwen2_5_VisionAttention, - Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLProcessingInfo, ) from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel @@ -221,14 +221,16 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) @@ -332,7 +334,13 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) self.blocks = nn.ModuleList( [ @@ -416,9 +424,19 @@ def rot_pos_emb(self, grid_thw): pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + return cos_combined, sin_combined def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side @@ -508,7 +526,7 @@ def forward( if self.apply_vit_abs_pos_embed: pos_embeds = self.fast_pos_embed_interpolate(grid_thw) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw) cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -519,7 +537,8 @@ def forward( cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device) + rotary_pos_emb_sin = 
rotary_pos_emb_sin.to(hidden_states.device) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) hidden_states_list = [] @@ -529,7 +548,8 @@ def forward( hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 7f0c9372991d..99a4007ef7f2 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -63,6 +63,7 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -95,7 +96,6 @@ ) from .qwen2_5_vl import ( Qwen2_5_VisionAttention, - Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, @@ -232,14 +232,16 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, ) @@ -339,7 +341,13 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) self.merger = Qwen3_VisionPatchMerger( d_model=vision_config.out_hidden_size, @@ -452,9 +460,19 @@ def rot_pos_emb(self, grid_thw: list[list[int]]): for t, h, w in grid_thw ] pos_ids = torch.cat(pos_ids, dim=0) - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + + return cos_combined, sin_combined def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side @@ -547,8 +565,13 @@ def forward( pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw_list) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True) + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + rotary_pos_emb_cos = rotary_pos_emb_cos.to( + hidden_states.device, non_blocking=True + ) + rotary_pos_emb_sin = rotary_pos_emb_sin.to( + hidden_states.device, non_blocking=True + ) cu_seqlens = 
torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -564,7 +587,8 @@ def forward( hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, seqlens=seqlens, )
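
The new RotaryEmbeddingBase.get_cos_sin helper relies on the layout of the existing cos_sin_cache, where the cosine and sine tables for every position are concatenated along the last dimension. Below is a minimal standalone sketch of that layout and of the slice-then-chunk step the helper performs; build_cos_sin_cache and the free-standing get_cos_sin are illustrative reimplementations under that layout assumption, not the vLLM classes themselves.

    import torch


    def build_cos_sin_cache(rotary_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies over half the rotary dimension.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position, dtype=torch.float)
        freqs = torch.outer(t, inv_freq)                      # (max_position, rotary_dim // 2)
        return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_position, rotary_dim)


    def get_cos_sin(cos_sin_cache: torch.Tensor, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the new helper: take the first `seqlen` rows, then split the
        # last dimension into its cos half and its sin half.
        cos, sin = cos_sin_cache[:seqlen].chunk(2, dim=-1)
        return cos, sin


    cache = build_cos_sin_cache(rotary_dim=32, max_position=8192)
    cos, sin = get_cos_sin(cache, seqlen=64)
    assert cos.shape == (64, 16) and sin.shape == (64, 16)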
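The deleted Glm4vVisionRotaryEmbedding / Qwen2VisionRotaryEmbedding / Qwen2_5_VisionRotaryEmbedding modules produced a raw angle table, and the old apply_rotary_pos_emb_vision called .cos()/.sin() on it at apply time; the get_rope replacement precomputes the same values once in its cache. A quick numerical check of that equivalence, assuming the cache is cos and sin concatenated on the last dimension as in the sketch above:

    import torch

    head_dim, theta, seqlen = 80, 10000.0, 32

    # Old path: the vision module was built with dim = head_dim // 2 and returned angles;
    # cos/sin were taken inside apply_rotary_pos_emb_vision.
    dim = head_dim // 2
    inv_freq_old = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq_old)
    old_cos, old_sin = freqs.cos(), freqs.sin()

    # New path: get_rope(head_size=head_dim, rotary_dim=head_dim // 2, base=theta) uses the
    # same frequency formula; get_cos_sin splits the assumed [cos | sin] cache back apart.
    rotary_dim = head_dim // 2
    inv_freq_new = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    angles = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq_new)
    cache = torch.cat([angles.cos(), angles.sin()], dim=-1)
    new_cos, new_sin = cache[:seqlen].chunk(2, dim=-1)

    assert torch.allclose(old_cos, new_cos) and torch.allclose(old_sin, new_sin)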
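Each vision tower's rot_pos_emb now indexes the 1-D cos/sin tables with the per-token height and width ids and concatenates the two halves, instead of gathering rows of a raw angle table and deferring cos()/sin() to the attention layer. A self-contained sketch of that gather-and-concat step, using a hypothetical combine_2d_rotary helper and assuming cos/sin of shape (max_grid_size, rotary_dim // 2) as returned by get_cos_sin:

    import torch


    def combine_2d_rotary(
        cos: torch.Tensor, sin: torch.Tensor, pos_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """cos, sin: (max_grid_size, rotary_dim // 2); pos_ids: (num_tokens, 2) of (h, w) ids."""
        cos_h, cos_w = cos[pos_ids[:, 0]], cos[pos_ids[:, 1]]
        sin_h, sin_w = sin[pos_ids[:, 0]], sin[pos_ids[:, 1]]
        # First half of the rotary channels follows the height position,
        # the second half follows the width position.
        return torch.cat([cos_h, cos_w], dim=-1), torch.cat([sin_h, sin_w], dim=-1)


    # Toy usage: a 2x3 patch grid flattened row-major.
    h, w = 2, 3
    pos_ids = torch.stack(
        [
            torch.arange(h).unsqueeze(1).expand(-1, w).flatten(),
            torch.arange(w).unsqueeze(0).expand(h, -1).flatten(),
        ],
        dim=-1,
    )
    cos = torch.randn(max(h, w), 8)
    sin = torch.randn(max(h, w), 8)
    cos_combined, sin_combined = combine_2d_rotary(cos, sin, pos_ids)
    assert cos_combined.shape == (h * w, 16) and sin_combined.shape == (h * w, 16)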
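In qwen2_5_vl the per-token cos and sin are reshaped into spatial_merge_unit-sized groups, permuted by the window index, and flattened back, mirroring what rotary_pos_emb_thw and get_rope_by_thw previously did to the single freqs tensor; the only change is that the same reordering now has to be applied to two tensors. A small sketch of that pattern with a hypothetical reorder_by_window helper:

    import torch


    def reorder_by_window(x: torch.Tensor, window_index: torch.Tensor, merge_unit: int) -> torch.Tensor:
        """x: (num_tokens, d). Group into merge_unit-sized blocks, permute blocks, flatten back."""
        x = x.reshape(x.shape[0] // merge_unit, merge_unit, -1)
        x = x[window_index, :, :]
        return x.flatten(start_dim=0, end_dim=1)


    merge_unit = 4
    cos = torch.randn(16, 40)
    sin = torch.randn(16, 40)
    window_index = torch.randperm(16 // merge_unit)
    cos_win = reorder_by_window(cos, window_index, merge_unit)
    sin_win = reorder_by_window(sin, window_index, merge_unit)
    assert cos_win.shape == cos.shape and sin_win.shape == sin.shape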
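qwen2_vl now routes rotation through apply_rotary_emb_torch(..., is_neox_style=True) from rotary_embedding.common instead of the local rotate_half / apply_rotary_emb_torch pair this diff removes. The removed code is the clearest reference for what the non-interleaved ("neox") rotation computes; below is a standalone sketch of that math, written as an illustration rather than the library function, with rotate_half_neox and apply_rotary_neox as assumed names:

    import torch


    def rotate_half_neox(x: torch.Tensor) -> torch.Tensor:
        # Non-interleaved ("neox") rotation: split the last dim in half, swap with a sign flip.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq, heads, head_dim); cos, sin: (seq, rotary_dim // 2)."""
        rotary_dim = cos.shape[-1] * 2
        # Duplicate cos/sin across both halves and broadcast over batch and heads.
        cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]
        sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x_rot = x_rot * cos + rotate_half_neox(x_rot) * sin
        return torch.cat([x_rot, x_pass], dim=-1)


    q = torch.randn(2, 8, 4, 64)  # (batch, seq, heads, head_dim)
    cos = torch.randn(8, 32)      # head_dim // 2 wide, so the whole head rotates, as in these vision towers
    sin = torch.randn(8, 32)
    assert apply_rotary_neox(q, cos, sin).shape == q.shape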