[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration
#27895
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from dataclasses import asdict | ||
| from typing import NamedTuple | ||
|
|
||
| import pytest | ||
| from PIL.Image import Image | ||
| from transformers import AutoProcessor | ||
|
|
||
| from vllm import LLM, EngineArgs, SamplingParams | ||
| from vllm.multimodal.utils import encode_image_base64 | ||
|
|
||
| MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview" | ||
|
|
||
| QUESTION = "What is the content of each image?" | ||
|
|
||
|
|
||
| class ModelRequestData(NamedTuple): | ||
| engine_args: EngineArgs | ||
| prompt: str | ||
| image_data: list[Image] | ||
| stop_token_ids: list[int] | None = None | ||
| chat_template: str | None = None | ||
| sampling_params: SamplingParams | None = None | ||
|
|
||
|
|
||
| @pytest.mark.core_model | ||
| @pytest.mark.parametrize("question", [QUESTION]) | ||
| def test_keye_vl( | ||
| image_assets, | ||
| question: str, | ||
| ): | ||
| images = [asset.pil_image for asset in image_assets] | ||
|
|
||
| image_urls = [ | ||
| f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images | ||
| ] | ||
|
|
||
| engine_args = EngineArgs( | ||
| model=MODEL_NAME, | ||
| trust_remote_code=True, | ||
| max_model_len=8192, | ||
| max_num_seqs=5, | ||
| limit_mm_per_prompt={"image": len(image_urls)}, | ||
| ) | ||
|
|
||
| placeholders = [{"type": "image", "image": url} for url in image_urls] | ||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| *placeholders, | ||
| {"type": "text", "text": question}, | ||
| ], | ||
| }, | ||
| ] | ||
|
|
||
| processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) | ||
|
|
||
| prompt = processor.apply_chat_template( | ||
| messages, tokenize=False, add_generation_prompt=True | ||
| ) | ||
|
|
||
| engine_args = asdict(engine_args) | {"seed": 42} | ||
| llm = LLM(**engine_args) | ||
|
|
||
| sampling_params = SamplingParams( | ||
| temperature=0.0, max_tokens=256, stop_token_ids=None | ||
| ) | ||
|
|
||
| outputs = llm.generate( | ||
| { | ||
| "prompt": prompt, | ||
| "multi_modal_data": {"image": images}, | ||
| }, | ||
| sampling_params=sampling_params, | ||
| ) | ||
|
|
||
| print("-" * 50) | ||
| for o in outputs: | ||
| generated_text = o.outputs[0].text | ||
| print(generated_text) | ||
| print("-" * 50) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,9 @@ | |
| from transformers.utils import torch_int | ||
|
|
||
| from vllm.attention.backends.registry import _Backend | ||
| from vllm.attention.layer import check_upstream_fa_availability | ||
| from vllm.attention.layer import ( | ||
| maybe_get_vit_flash_attn_backend, | ||
| ) | ||
| from vllm.config import VllmConfig | ||
| from vllm.config.multimodal import BaseDummyOptions | ||
| from vllm.distributed import get_tensor_model_parallel_world_size | ||
|
|
@@ -56,12 +58,14 @@ | |
| PromptUpdate, | ||
| ) | ||
| from vllm.multimodal.profiling import BaseDummyInputsBuilder | ||
| from vllm.platforms import current_platform | ||
| from vllm.sequence import IntermediateTensors | ||
| from vllm.utils.tensor_schema import TensorSchema, TensorShape | ||
|
|
||
| from .interfaces import ( | ||
| MultiModalEmbeddings, | ||
| SupportsLoRA, | ||
| SupportsMRoPE, | ||
| SupportsMultiModal, | ||
| SupportsPP, | ||
| ) | ||
|
|
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt( | |
| cos = cos.chunk(2, dim=-1)[0].contiguous() | ||
| sin = sin.chunk(2, dim=-1)[0].contiguous() | ||
|
|
||
| from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb | ||
| if current_platform.is_cuda(): | ||
| from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb | ||
| elif current_platform.is_rocm(): | ||
| from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb | ||
|
|
||
| q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) | ||
| k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) | ||
|
|
@@ -398,18 +405,28 @@ def __init__( | |
| attn_backend_override=attn_backend_override, | ||
| ) | ||
|
|
||
| self.use_upstream_fa = False | ||
| if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( | ||
| torch.get_default_dtype() | ||
| ): | ||
| self.attn_backend = _Backend.FLASH_ATTN | ||
| self.use_upstream_fa = True | ||
| self.attn_backend, self.flash_attn_varlen_func = ( | ||
| maybe_get_vit_flash_attn_backend( | ||
| self.attn_backend, | ||
| use_upstream_fa=False, | ||
| attn_backend_override=attn_backend_override, | ||
| ) | ||
| ) | ||
|
|
||
| if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: | ||
| if self.attn_backend not in { | ||
| _Backend.FLASH_ATTN, | ||
| _Backend.XFORMERS, | ||
| _Backend.ROCM_AITER_FA, | ||
| }: | ||
| raise RuntimeError( | ||
| f"Keye-VL does not support {self.attn_backend} backend now." | ||
| ) | ||
|
|
||
| self.is_flash_attn_backend = self.attn_backend in { | ||
| _Backend.FLASH_ATTN, | ||
| _Backend.ROCM_AITER_FA, | ||
| } | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
|
|
@@ -457,15 +474,10 @@ def forward( | |
| self.head_dim, | ||
| ) | ||
|
|
||
| if self.attn_backend == _Backend.FLASH_ATTN: | ||
| if self.use_upstream_fa: | ||
| from flash_attn import flash_attn_varlen_func | ||
| else: | ||
| from vllm.vllm_flash_attn import flash_attn_varlen_func | ||
|
|
||
| if self.is_flash_attn_backend: | ||
| q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) | ||
|
|
||
| output = flash_attn_varlen_func( | ||
| output = self.flash_attn_varlen_func( | ||
| q, | ||
| k, | ||
| v, | ||
|
|
@@ -1542,7 +1554,7 @@ def get_mm_mapping(self) -> MultiModelKeys: | |
| dummy_inputs=KeyeDummyInputsBuilder, | ||
| ) | ||
| class KeyeForConditionalGeneration( | ||
| BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP | ||
| BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE | ||
| ): | ||
| def _build_projector( | ||
| self, | ||
|
|
@@ -1611,3 +1623,142 @@ def _process_video_input( | |
| return tuple( | ||
| self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) | ||
| ) | ||
|
|
||
| def get_mrope_input_positions( | ||
| self, | ||
| input_tokens: list[int], | ||
| hf_config: PretrainedConfig, | ||
| image_grid_thw: list[list[int]] | torch.Tensor, | ||
| video_grid_thw: list[list[int]] | torch.Tensor, | ||
| context_len: int = 0, | ||
| seq_len: int | None = None, | ||
| second_per_grid_ts: list[float] | None = None, | ||
| audio_feature_lengths: torch.Tensor | None = None, | ||
| use_audio_in_video: bool = False, | ||
| ) -> tuple[torch.Tensor, int]: | ||
| if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: | ||
| video_grid_thw = video_grid_thw[0] | ||
|
Comment on lines +1639 to +1640
Contributor
This conditional statement appears to incorrectly handle the case where `video_grid_thw` is a list with more than one entry: only the first element is kept, so grid information for any additional videos is dropped.
||
| """Get mrope input positions and delta value (Keye series).""" | ||
|
|
||
| def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: | ||
| """ | ||
| Split grid_thw along the t dimension. | ||
|
|
||
| Args: | ||
| grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. | ||
|
|
||
| Returns: | ||
| List of [1, h, w] rows, repeated t times for each original row. | ||
| """ | ||
|
|
||
| if isinstance(grid_thw, list): | ||
| grid_thw = torch.tensor(grid_thw, dtype=torch.long) | ||
|
|
||
| if grid_thw.numel() == 0: | ||
| return [] | ||
|
|
||
| t, hw = grid_thw[:, 0], grid_thw[:, 1:] | ||
| ones = torch.ones_like(hw[:, :1]) # [N,1] | ||
| out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) | ||
| return out.tolist() | ||
|
|
||
| video_grid_thw = split_thw(video_grid_thw) | ||
|
|
||
| image_token_id = hf_config.image_token_id | ||
| video_token_id = hf_config.video_token_id | ||
| spatial_merge_size = hf_config.vision_config.spatial_merge_size | ||
|
|
||
| image_nums = len(image_grid_thw) | ||
| frame_nums = len(video_grid_thw) | ||
| llm_pos_ids_list: list = [] | ||
|
|
||
| st = 0 | ||
| remain_images, remain_frames = image_nums, frame_nums | ||
|
|
||
| image_index, video_index = 0, 0 | ||
| for _ in range(image_nums + frame_nums): | ||
| if remain_images > 0: | ||
| try: | ||
| ed_image = input_tokens.index(image_token_id, st) | ||
| except ValueError: | ||
| ed_image = len(input_tokens) + 1 | ||
| else: | ||
| ed_image = len(input_tokens) + 1 | ||
| if remain_frames > 0: | ||
| try: | ||
| ed_video = input_tokens.index(video_token_id, st) | ||
| except ValueError: | ||
| ed_video = len(input_tokens) + 1 | ||
| else: | ||
| ed_video = len(input_tokens) + 1 | ||
|
|
||
| if ed_image < ed_video: | ||
| t, h, w = ( | ||
| image_grid_thw[image_index][0], | ||
| image_grid_thw[image_index][1], | ||
| image_grid_thw[image_index][2], | ||
| ) | ||
| image_index += 1 | ||
| remain_images -= 1 | ||
| ed = ed_image | ||
| else: | ||
| t, h, w = ( | ||
| video_grid_thw[video_index][0], | ||
| video_grid_thw[video_index][1], | ||
| video_grid_thw[video_index][2], | ||
| ) | ||
| video_index += 1 | ||
| remain_frames -= 1 | ||
| ed = ed_video | ||
|
|
||
| llm_grid_t, llm_grid_h, llm_grid_w = ( | ||
| t, | ||
| h // spatial_merge_size, | ||
| w // spatial_merge_size, | ||
| ) | ||
| text_len = ed - st | ||
|
|
||
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | ||
| llm_pos_ids_list.append( | ||
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx | ||
| ) | ||
|
|
||
| t_index = ( | ||
| ( | ||
| torch.arange(llm_grid_t) | ||
| .view(-1, 1) | ||
| .expand(-1, llm_grid_h * llm_grid_w) | ||
| ) | ||
| .long() | ||
| .flatten() | ||
| ) | ||
|
|
||
| h_index = ( | ||
| torch.arange(llm_grid_h) | ||
| .view(1, -1, 1) | ||
| .expand(llm_grid_t, -1, llm_grid_w) | ||
| .flatten() | ||
| ) | ||
| w_index = ( | ||
| torch.arange(llm_grid_w) | ||
| .view(1, 1, -1) | ||
| .expand(llm_grid_t, llm_grid_h, -1) | ||
| .flatten() | ||
| ) | ||
| llm_pos_ids_list.append( | ||
| torch.stack([t_index, h_index, w_index]) + text_len + st_idx | ||
| ) | ||
| st = ed + llm_grid_t * llm_grid_h * llm_grid_w | ||
|
|
||
| if st < len(input_tokens): | ||
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | ||
| text_len = len(input_tokens) - st | ||
| llm_pos_ids_list.append( | ||
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx | ||
| ) | ||
|
|
||
| llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) | ||
| mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() | ||
| llm_positions = llm_positions[:, context_len:seq_len] | ||
|
|
||
| return llm_positions, mrope_position_delta | ||
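To make the new MRoPE layout concrete, here is a small standalone walk-through (not part of the PR) that mirrors the position logic above for one image: two text tokens, four image placeholder tokens produced by a 1×4×4 grid with `spatial_merge_size = 2`, and one trailing text token. The placeholder token id is hypothetical. Videos would first pass through `split_thw`, so a `[2, h, w]` video grid is handled as two `[1, h, w]` frames before reaching the same loop.

```python
# Standalone walk-through mirroring the MRoPE position layout computed above.
# Hypothetical ids: 151655 = image placeholder token, 0 = ordinary text token.
import torch

image_token_id = 151655
input_tokens = [0, 0, image_token_id, image_token_id, image_token_id, image_token_id, 0]
t, h, w = 1, 4, 4               # image_grid_thw for the single image
spatial_merge_size = 2          # -> LLM-side grid of 1 x 2 x 2 = 4 placeholder tokens

llm_t, llm_h, llm_w = t, h // spatial_merge_size, w // spatial_merge_size
pos_chunks: list[torch.Tensor] = []

# Text before the image: positions advance together on all three axes.
ed = input_tokens.index(image_token_id)
pos_chunks.append(torch.arange(ed).view(1, -1).expand(3, -1))

# Image tokens: t/h/w axes advance independently, offset past the text block.
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
pos_chunks.append(torch.stack([t_idx, h_idx, w_idx]) + ed)

# Trailing text: resume from the maximum position so far + 1.
st = ed + llm_t * llm_h * llm_w
start = pos_chunks[-1].max() + 1
pos_chunks.append(torch.arange(len(input_tokens) - st).view(1, -1).expand(3, -1) + start)

positions = torch.cat(pos_chunks, dim=1)
delta = (positions.max() + 1 - len(input_tokens)).item()
print(positions)
# tensor([[0, 1, 2, 2, 2, 2, 4],
#         [0, 1, 2, 2, 3, 3, 4],
#         [0, 1, 2, 3, 2, 3, 4]])
print(delta)  # -2
```

Text positions advance in lockstep on all three axes, image positions advance independently per axis, and the returned delta records how far the final position lags behind the sequence length.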
In `apply_rotary_pos_emb_flashatt`, the ROCm branch imports `flash_attn.ops.triton.rotary.apply_rotary`, which expects both the `q` and `k` tensors and returns the rotated pair. The current implementation calls this kernel twice with only `(tensor, cos, sin)`, just like the CUDA wrapper. On ROCm this will raise a `TypeError` for the missing argument and prevents rotary embeddings from being applied. The ROCm path should invoke `apply_rotary(q, k, cos, sin)` once and unpack the returned tensors, mirroring the existing usage in `layers/rotary_embedding/common.py`.