diff --git a/tests/models/multimodal/generation/test_keye.py b/tests/models/multimodal/generation/test_keye.py
new file mode 100644
index 000000000000..6f98bde1d91e
--- /dev/null
+++ b/tests/models/multimodal/generation/test_keye.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import asdict
+from typing import NamedTuple
+
+import pytest
+from PIL.Image import Image
+from transformers import AutoProcessor
+
+from vllm import LLM, EngineArgs, SamplingParams
+from vllm.multimodal.utils import encode_image_base64
+
+MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
+
+QUESTION = "What is the content of each image?"
+
+
+class ModelRequestData(NamedTuple):
+    engine_args: EngineArgs
+    prompt: str
+    image_data: list[Image]
+    stop_token_ids: list[int] | None = None
+    chat_template: str | None = None
+    sampling_params: SamplingParams | None = None
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("question", [QUESTION])
+def test_keye_vl(
+    image_assets,
+    question: str,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    image_urls = [
+        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
+    ]
+
+    engine_args = EngineArgs(
+        model=MODEL_NAME,
+        trust_remote_code=True,
+        max_model_len=8192,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                *placeholders,
+                {"type": "text", "text": question},
+            ],
+        },
+    ]
+
+    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+    prompt = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    engine_args = asdict(engine_args) | {"seed": 42}
+    llm = LLM(**engine_args)
+
+    sampling_params = SamplingParams(
+        temperature=0.0, max_tokens=256, stop_token_ids=None
+    )
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {"image": images},
+        },
+        sampling_params=sampling_params,
+    )
+
+    print("-" * 50)
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+        assert len(generated_text) > 10, (
+            f"Generated text is too short: {generated_text}"
+        )
+        print("-" * 50)
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index acfd51a6d0cc..5f8659a3064e 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -17,7 +17,9 @@
 from transformers.utils import torch_int
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (
+    maybe_get_vit_flash_attn_backend,
+)
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@
     PromptUpdate,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
     MultiModalEmbeddings,
     SupportsLoRA,
+    SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
 )
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
 
-    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    if current_platform.is_cuda():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    elif current_platform.is_rocm():
+        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
 
     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@@ -398,18 +405,28 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )
 
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-            torch.get_default_dtype()
-        ):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa=False,
+                attn_backend_override=attn_backend_override,
+            )
+        )
 
-        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
+        if self.attn_backend not in {
+            _Backend.FLASH_ATTN,
+            _Backend.XFORMERS,
+            _Backend.ROCM_AITER_FA,
+        }:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now."
             )
 
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN,
+            _Backend.ROCM_AITER_FA,
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@ def forward(
             self.head_dim,
         )
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-
+        if self.is_flash_attn_backend:
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -1542,7 +1554,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
     dummy_inputs=KeyeDummyInputsBuilder,
 )
 class KeyeForConditionalGeneration(
-    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
+    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     def _build_projector(
         self,
@@ -1611,3 +1623,142 @@ def _process_video_input(
         return tuple(
             self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
         )
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: list[list[int]] | torch.Tensor,
+        video_grid_thw: list[list[int]] | torch.Tensor,
+        context_len: int = 0,
+        seq_len: int | None = None,
+        second_per_grid_ts: list[float] | None = None,
+        audio_feature_lengths: torch.Tensor | None = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value (Keye series)."""
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
+
+        def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
+            """
+            Split grid_thw along the t dimension.
+
+            Args:
+                grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
+
+            Returns:
+                List of [1, h, w] rows, repeated t times for each original row.
+            """
+
+            if isinstance(grid_thw, list):
+                grid_thw = torch.tensor(grid_thw, dtype=torch.long)
+
+            if grid_thw.numel() == 0:
+                return []
+
+            t, hw = grid_thw[:, 0], grid_thw[:, 1:]
+            ones = torch.ones_like(hw[:, :1])  # [N,1]
+            out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
+            return out.tolist()
+
+        video_grid_thw = split_thw(video_grid_thw)
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        image_nums = len(image_grid_thw)
+        frame_nums = len(video_grid_thw)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_frames = image_nums, frame_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + frame_nums):
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_frames > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_frames -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                (
+                    torch.arange(llm_grid_t)
+                    .view(-1, 1)
+                    .expand(-1, llm_grid_h * llm_grid_w)
+                )
+                .long()
+                .flatten()
+            )
+
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
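For context on what the new `get_mrope_input_positions` computes, here is a minimal standalone sketch (not part of the patch) of the same interleaving scheme: text tokens advance all three M-RoPE axes in lockstep, an image block contributes a t/h/w grid offset by the length of the text before it, and the returned delta records how far the final position index lags behind the raw token count. The token id and grid shape below are invented for illustration; only `torch` is required.

```python
import torch

# Hypothetical values, chosen only for this illustration.
IMAGE_TOKEN = 151655
spatial_merge_size = 2

# "<text> <text> <4 image placeholders> <text>"
input_tokens = [101, 102] + [IMAGE_TOKEN] * 4 + [103]
t, h, w = 1, 4, 4  # one frame of 4x4 patches -> 2x2 positions after merging

positions = []

# Text before the image: the same index on all three axes.
positions.append(torch.arange(2).view(1, -1).expand(3, -1))

# Image block: t/h/w indices of a 1x2x2 grid, offset by the preceding text length.
llm_h, llm_w = h // spatial_merge_size, w // spatial_merge_size
t_idx = torch.zeros(llm_h * llm_w, dtype=torch.long)
h_idx = torch.arange(llm_h).repeat_interleave(llm_w)
w_idx = torch.arange(llm_w).repeat(llm_h)
positions.append(torch.stack([t_idx, h_idx, w_idx]) + 2)

# Trailing text resumes one past the largest position used so far.
st_idx = positions[-1].max() + 1
positions.append(torch.arange(1).view(1, -1).expand(3, -1) + st_idx)

llm_positions = torch.cat(positions, dim=1)
print(llm_positions)
# tensor([[0, 1, 2, 2, 2, 2, 4],
#         [0, 1, 2, 2, 3, 3, 4],
#         [0, 1, 2, 3, 2, 3, 4]])

# Delta is negative because the 4 image tokens only consumed a 2x2 grid of positions.
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
print(mrope_position_delta)  # -2
```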