83 changes: 83 additions & 0 deletions tests/models/multimodal/generation/test_keye.py
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

import pytest
from PIL.Image import Image
from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"

QUESTION = "What is the content of each image?"


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    image_data: list[Image]
    stop_token_ids: list[int] | None = None
    chat_template: str | None = None
    sampling_params: SamplingParams | None = None


@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
    image_assets,
    question: str,
):
    images = [asset.pil_image for asset in image_assets]

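    # The base64 data URLs are only used to build the chat-template
    # placeholders and to size limit_mm_per_prompt; the raw PIL images are
    # what actually reach llm.generate via multi_modal_data below.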
    image_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
    ]

    engine_args = EngineArgs(
        model=MODEL_NAME,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

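    # Flatten EngineArgs to plain kwargs so a fixed seed can be merged in
    # for reproducibility.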
    engine_args = asdict(engine_args) | {"seed": 42}
    llm = LLM(**engine_args)

    sampling_params = SamplingParams(
        temperature=0.0, max_tokens=256, stop_token_ids=None
    )

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params,
    )

    print("-" * 50)
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        print("-" * 50)
185 changes: 168 additions & 17 deletions vllm/model_executor/models/keye.py
@@ -17,7 +17,9 @@
from transformers.utils import torch_int

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
Comment on lines +344 to 350

P1: Call ROCm rotary kernel with wrong signature

In apply_rotary_pos_emb_flashatt, the ROCm branch imports flash_attn.ops.triton.rotary.apply_rotary, which expects both the q and k tensors and returns the rotated pair. The current implementation calls this kernel twice with only (tensor, cos, sin), just like the CUDA wrapper. On ROCm this raises a TypeError for the missing argument and prevents rotary embeddings from being applied. The ROCm path should invoke apply_rotary(q, k, cos, sin) once and unpack the returned tensors, mirroring the existing usage in layers/rotary_embedding/common.py.
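A minimal sketch of the dispatch this comment suggests, assuming the (q, k, cos, sin) signature and paired return value it describes for the ROCm triton kernel; an illustration, not the PR's code:

```python
from vllm.platforms import current_platform

if current_platform.is_cuda():
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    # The CUDA wrapper rotates one tensor per call.
    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
elif current_platform.is_rocm():
    from flash_attn.ops.triton.rotary import apply_rotary

    # Assumed per the review comment: the ROCm kernel rotates q and k
    # together in a single call and returns both tensors.
    q_embed, k_embed = apply_rotary(q.float(), k.float(), cos.float(), sin.float())
    q_embed, k_embed = q_embed.type_as(q), k_embed.type_as(k)
```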

@@ -398,18 +405,28 @@ def __init__(
            attn_backend_override=attn_backend_override,
        )

        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa=False,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@ def forward(
            self.head_dim,
        )

        if self.attn_backend == _Backend.FLASH_ATTN:
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
@@ -1542,7 +1554,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    def _build_projector(
        self,
@@ -1611,3 +1623,142 @@ def _process_video_input(
        return tuple(
            self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor,
        video_grid_thw: list[list[int]] | torch.Tensor,
        context_len: int = 0,
        seq_len: int | None = None,
        second_per_grid_ts: list[float] | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
            video_grid_thw = video_grid_thw[0]
Comment on lines +1639 to +1640

critical: This conditional incorrectly handles the case where video_grid_thw is a list[list[int]]. If video_grid_thw is a list of multiple video grids (e.g., [[t1, h1, w1], [t2, h2, w2]]), this line slices it down to just the first grid ([t1, h1, w1]). When that 1-D list is passed to split_thw, it is converted to a 1-D tensor, causing an indexing error at grid_thw[:, 0] and crashing execution. Since split_thw already handles a list[list[int]] by converting it to a 2-D tensor, this slicing is both incorrect and unnecessary; removing it ensures correct behavior for all valid input types.
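To make the failure mode concrete, here is split_thw (copied from the diff) run standalone on both input shapes; the grid values are made up:

```python
import torch


def split_thw(grid_thw):
    # Same logic as the diff's helper: expand each [t, h, w] row
    # into t rows of [1, h, w].
    if isinstance(grid_thw, list):
        grid_thw = torch.tensor(grid_thw, dtype=torch.long)
    if grid_thw.numel() == 0:
        return []
    t, hw = grid_thw[:, 0], grid_thw[:, 1:]
    ones = torch.ones_like(hw[:, :1])
    return torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0).tolist()


print(split_thw([[2, 4, 4], [3, 4, 4]]))  # five [1, 4, 4] rows: 2-D input works
try:
    split_thw([2, 4, 4])  # what video_grid_thw[0] hands over: a 1-D list
except IndexError as err:
    print("1-D input fails:", err)  # grid_thw[:, 0] needs a 2-D tensor
```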

"""Get mrope input positions and delta value (Keye series)."""

def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.

Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].

Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""

if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)

if grid_thw.numel() == 0:
return []

t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()

        video_grid_thw = split_thw(video_grid_thw)

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        image_nums = len(image_grid_thw)
        frame_nums = len(video_grid_thw)
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_frames = image_nums, frame_nums

        image_index, video_index = 0, 0
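        # Scan the token stream left to right, consuming whichever
        # placeholder (image token or video frame token) appears next.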
        for _ in range(image_nums + frame_nums):
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_frames > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1

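            # The nearer placeholder decides which modality's grid is consumed.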
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_frames -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

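            # Text tokens before this placeholder advance all three rope axes
            # (t, h, w) in lockstep.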
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

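            # Vision tokens get explicit t/h/w coordinates, offset past the
            # preceding text span.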
            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                )
                .long()
                .flatten()
            )

            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

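        # Text after the final placeholder continues the position sequence.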
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta