Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,12 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
return q, k, v

def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
Expand All @@ -285,7 +287,6 @@ def forward(

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
Expand Down Expand Up @@ -321,7 +322,6 @@ def forward(
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None,
device=q.device)
Expand Down Expand Up @@ -364,11 +364,20 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp")

def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb)
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens)

x = x + self.mlp(self.norm2(x))
return x

Expand Down Expand Up @@ -528,6 +537,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

@property
def dtype(self) -> torch.dtype:
Expand Down Expand Up @@ -633,14 +643,25 @@ def forward(

# transformers
hidden_states = hidden_states.unsqueeze(1)

max_seqlen = None
seqlens = None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = blk(hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)

# For Qwen2.5-VL-3B, float16 will overflow at last block
# for long visual tokens sequences.
Expand Down
49 changes: 37 additions & 12 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,12 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
return q, k, v

def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:

# [s, b, c] --> [s, b, 3 * head * head_dim]
Expand All @@ -329,7 +331,6 @@ def forward(

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
Expand Down Expand Up @@ -365,7 +366,6 @@ def forward(
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None,
device=q.device)
Expand Down Expand Up @@ -409,11 +409,22 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp")

def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)

x = x + self.mlp(self.norm2(x))
return x

Expand Down Expand Up @@ -570,6 +581,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

@property
def dtype(self) -> torch.dtype:
Expand Down Expand Up @@ -624,8 +636,21 @@ def forward(

# transformers
x = x.unsqueeze(1)

max_seqlen = None
seqlens = None
if self.attn_backend == _Backend.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)

# adapter
x = self.merger(x)
Expand Down