5 changes: 5 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/base.py
@@ -83,6 +83,11 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[:seqlen]
cos, sin = cos_sin.chunk(2, dim=-1)
return cos, sin


class RotaryEmbedding(RotaryEmbeddingBase):
def __init__(
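
For reference, the new get_cos_sin helper relies on the layout vLLM's RotaryEmbedding uses for cos_sin_cache: a (max_position, rotary_dim) tensor whose last dimension holds the cos half followed by the sin half. A minimal standalone sketch of that layout and of what the helper returns (not part of the diff; illustrative sizes only):

import torch

# Assumed illustrative sizes; the real values come from the get_rope(...)
# call in glm4_1v.py further down.
max_position, rotary_dim, base = 8192, 64, 10000.0

# Same construction RotaryEmbedding uses: cos and sin halves concatenated
# along the last dimension into one (max_position, rotary_dim) cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
freqs = torch.outer(torch.arange(max_position, dtype=torch.float), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

# get_cos_sin(seqlen) slices the cache and splits it back into its halves,
# so cos and sin each come out as (seqlen, rotary_dim // 2).
seqlen = 128
cos, sin = cos_sin_cache[:seqlen].chunk(2, dim=-1)
assert cos.shape == sin.shape == (seqlen, rotary_dim // 2)
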
88 changes: 38 additions & 50 deletions vllm/model_executor/models/glm4_1v.py
@@ -65,6 +65,7 @@
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -341,7 +342,8 @@ def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
@@ -353,10 +355,12 @@ def forward(
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
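
The call above now receives cos and sin tensors directly instead of the raw rotary frequencies. As a hedged illustration of what applying them looks like, here is the standard neox-style rotate-half formula with the shapes from the comment above (a sketch of the idea, not necessarily the exact apply_rotary_pos_emb_vision kernel):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style "rotate half": swap the two halves and negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x:        (batch, seq, heads, head_dim), e.g. the concatenated q/k above.
    # cos, sin: (seq, head_dim // 2); duplicate them to cover the full
    #           head_dim, then broadcast over batch and heads.
    cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]
    sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
    return (x * cos) + (rotate_half(x) * sin)
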
@@ -454,14 +458,16 @@ def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: int | None = None, # Only used for Flash Attention
seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
@@ -660,44 +666,6 @@ def forward(
return embeddings


class Glm4vVisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None

def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0,
self.dim,
2,
dtype=torch.float,
device=self.inv_freq.device,
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs

def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]


class Glm4vVisionTransformer(nn.Module):
def __init__(
self,
@@ -731,7 +699,13 @@ def __init__(

norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
[
Glm4vVisionBlock(
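
Since get_rope is called with rotary_dim=head_dim // 2, each spatial axis of the vision grid later gets half of the head's rotary channels. A small hedged usage sketch of the shared rope together with the new get_cos_sin helper (head_dim=128 is assumed purely for illustration):

from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,   # assumed example head_dim
    rotary_dim=64,   # head_dim // 2, later split between the h and w axes
    max_position=8192,
    base=10000.0,
    is_neox_style=True,
)
cos, sin = rope.get_cos_sin(32)  # each of shape (32, rotary_dim // 2) == (32, 32)
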
@@ -789,7 +763,9 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
def rot_pos_emb(
self, grid_thw: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
Expand Down Expand Up @@ -817,9 +793,18 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids

# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
return cos_combined, sin_combined, pos_ids
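
To make the shape bookkeeping above concrete: with rotary_dim = head_dim // 2, cos and sin each carry rotary_dim // 2 = head_dim // 4 channels, so gathering one row per axis and concatenating gives every token head_dim // 2 channels (height channels first, then width). A standalone sketch with made-up sizes, ignoring the spatial-merge permutation used in the real pos_ids construction:

import torch

head_dim, max_grid = 64, 32                  # hypothetical sizes
rotary_dim = head_dim // 2                   # matches get_rope(rotary_dim=head_dim // 2)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
freqs = torch.outer(torch.arange(max_grid, dtype=torch.float), inv_freq)
cos, sin = freqs.cos(), freqs.sin()          # each (max_grid, rotary_dim // 2)

# One (h, w) index pair per visual token, e.g. a plain 4x6 patch grid.
h, w = 4, 6
hpos = torch.arange(h).unsqueeze(1).expand(-1, w).flatten()
wpos = torch.arange(w).unsqueeze(0).expand(h, -1).flatten()
pos_ids = torch.stack([hpos, wpos], dim=-1)  # (num_tokens, 2)

# Gather each axis' table by its index column, then concatenate the halves.
cos_combined = torch.cat([cos[pos_ids[:, 0]], cos[pos_ids[:, 1]]], dim=-1)
sin_combined = torch.cat([sin[pos_ids[:, 0]], sin[pos_ids[:, 1]]], dim=-1)
assert cos_combined.shape == (h * w, rotary_dim)  # == (num_tokens, head_dim // 2)
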

def compute_attn_mask_seqlen(
self,
Expand Down Expand Up @@ -848,7 +833,9 @@ def forward(
x = self.post_conv_layernorm(x)

# compute position embedding
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -867,7 +854,8 @@ def forward(
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
seqlens=seqlens,
)