Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,8 @@ def forward(
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
q = apply_rotary_pos_emb_vision(q,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
Expand Down
9 changes: 4 additions & 5 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
Expand Down Expand Up @@ -230,14 +230,13 @@ def apply_rotary_emb_torch(x: torch.Tensor,


def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor,
use_flash_attn=False) -> torch.Tensor:
freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn:
from flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output

Expand Down