Merged
52 changes: 31 additions & 21 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -25,7 +25,7 @@
from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -41,15 +41,13 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
from vllm.model_executor.models.qwen2_audio import (
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
_get_feat_extract_output_lengths)
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -66,9 +64,9 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -81,6 +79,26 @@
logger = init_logger(__name__)


class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- na: Number of audios
- nmb: Number of mel bins
- msl: Maximum sequence length
- tsl: Total sequence length
"""
type: Literal["audio_features"]
input_features: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("nmb", "tsl"),
]

feature_attention_mask: Annotated[
torch.Tensor,
TensorShape("na", "msl"),
]

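For reference, a minimal sketch of how this new schema is meant to be constructed (not part of the diff; the sizes are made up, and it assumes TensorSchema subclasses accept their annotated fields as keyword arguments and check them against the declared TensorShape at construction time, as the call site in _parse_and_validate_audio_input below suggests):

import torch

num_audios, num_mel_bins = 2, 128           # "na", "nmb" (illustrative values)
seq_lens = [300, 250]                       # per-audio mel-frame counts
max_seq_len = max(seq_lens)                 # "msl"

# input_features: all audios concatenated along time, shape (nmb, tsl)
input_features = torch.randn(num_mel_bins, sum(seq_lens))

# feature_attention_mask: padding mask over mel frames, shape (na, msl)
feature_attention_mask = torch.zeros(num_audios, max_seq_len, dtype=torch.long)
for i, sl in enumerate(seq_lens):
    feature_attention_mask[i, :sl] = 1

audio_input = Qwen2_5OmniAudioFeatureInputs(
    type="audio_features",
    input_features=input_features,
    feature_attention_mask=feature_attention_mask,
)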

def create_qwen2_5_omni_thinker_field_factory(
spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
@@ -536,7 +554,7 @@ def _validate_and_reshape_mm_tensor(self,
return torch.concat(mm_input, dim=dim)

def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
input_audio_features = kwargs.pop('input_audio_features', None)
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -550,7 +568,8 @@ def _parse_and_validate_audio_input(
if not isinstance(input_audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_audio_features)}")
return Qwen2AudioFeatureInputs(
return Qwen2_5OmniAudioFeatureInputs(
type="audio_features",
input_features=input_audio_features,
audio_feature_lengths=audio_feature_lengths,
feature_attention_mask=feature_attention_mask)
Expand Down Expand Up @@ -633,7 +652,7 @@ def _parse_and_validate_video_input(

def _process_audio_input(
self,
audio_input: Qwen2AudioFeatureInputs,
audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None,
) -> torch.Tensor:
@@ -660,8 +679,8 @@ def _process_audio_input(
feature_lens=audio_feature_lengths,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
return audio_features.split(audio_output_lengths.tolist())
return audio_outputs.last_hidden_state.split(
audio_output_lengths.tolist())

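The return-value change above relies on torch.Tensor.split accepting a list of per-audio lengths. A standalone illustration with made-up sizes (not part of the diff):

import torch

hidden_size = 4
audio_output_lengths = torch.tensor([5, 3, 7])   # per-audio encoder output lengths
last_hidden_state = torch.randn(int(audio_output_lengths.sum()), hidden_size)

# Splitting along dim 0 yields one (length_i, hidden_size) tensor per audio,
# which is what _process_audio_input now returns directly.
per_audio = last_hidden_state.split(audio_output_lengths.tolist())
assert [t.shape[0] for t in per_audio] == audio_output_lengths.tolist()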
def _process_image_input(
self,
@@ -707,7 +726,7 @@ def _process_video_input(
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
nn.Module, SupportsMultiModal, SupportsPP,
Qwen2_5OmniConditionalGenerationMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -800,15 +819,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector=[], # No explicit connector in this model
tower_model=["visual",
"audio_tower"], # Exclude vision and audio towers
)

def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:

173 changes: 103 additions & 70 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -27,7 +27,7 @@
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping
from functools import lru_cache, partial
from typing import Callable, Literal, Optional, TypedDict, Union
from typing import Annotated, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -64,6 +64,7 @@
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -80,84 +81,125 @@
# === Vision Inputs === #


class Qwen2_5_VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
class Qwen2_5_VLImagePixelInputs(TensorSchema):
"""

image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: Number of patches
- ni: Number of images
- cps: Number of channels * patch_size * patch_size

Historical context:
- pixel_values shape: (num_patches, num_channels * patch_size *
patch_size)
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
format
"""
type: Literal["pixel_values"]

pixel_values: Annotated[
torch.Tensor,
TensorShape("np", "cps"),
]

image_grid_thw: Annotated[
torch.Tensor,
TensorShape("ni", 3),
]

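To make the "np"/"cps" symbols above concrete, a sketch with made-up sizes (not part of the diff; it assumes the schema accepts its annotated fields as keyword arguments, as the call site in _parse_and_validate_image_input later in this diff does):

import torch

num_channels, patch_size = 3, 14                       # illustrative values
image_grid_thw = torch.tensor([[1, 4, 6],              # (grid_t, grid_h, grid_w) per image
                               [1, 8, 8]])             # "ni" = 2
num_patches = int(image_grid_thw.prod(dim=-1).sum())   # "np" = 1*4*6 + 1*8*8 = 88
cps = num_channels * patch_size * patch_size           # "cps" = 3 * 14 * 14 = 588

image_input = Qwen2_5_VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(num_patches, cps),
    image_grid_thw=image_grid_thw,
)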
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
"""Supported types:
- list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).

Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
"""

image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images

Historical context:
- image_embeds shape: (num_image_features, hidden_size)
- num_image_features varies based on the number and resolution of the
images.
- hidden_size must match the hidden size of language model backbone.
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
format
"""
type: Literal["image_embeds"]

image_embeds: Annotated[
torch.Tensor,
TensorShape("nf", "hs"),
]

image_grid_thw: Annotated[
torch.Tensor,
TensorShape("ni", 3),
]


Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLImageEmbeddingInputs]


class Qwen2_5_VLVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
"""
Dimensions:
- np: Number of patches
- nv: Number of videos
- ctps: Number of channels * temporal_patch_size * patch_size *
patch_size

Historical context:
- pixel_values_videos shape: (num_patches, num_channels *
temporal_patch_size * patch_size * patch_size)
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
format
- second_per_grid_ts: The video time interval (in seconds) for each
grid along the temporal dimension in the 3D position IDs. Returned
when `videos` is not `None`.
"""
type: Literal["pixel_values_videos"]

video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("np", "ctps"),
]

This should be in `(grid_t, grid_h, grid_w)` format.
"""
video_grid_thw: Annotated[
torch.Tensor,
TensorShape("nv", 3),
]

second_per_grid_ts: torch.Tensor
"""
The video time interval (in seconds) for each grid along the temporal
dimension in the 3D position IDs. Returned when `videos` is not `None`.
"""
second_per_grid_ts: Annotated[
Optional[torch.Tensor],
TensorShape("nv"),
]


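Note that second_per_grid_ts is declared with TensorShape("nv"), i.e. one scalar per video; the _parse_and_validate_video_input change later in this diff squeezes a batched (nv, 1) tensor down to that shape. A sketch of the expected shapes (made-up values, not part of the diff, under the same keyword-argument assumption as above):

import torch

video_grid_thw = torch.tensor([[4, 12, 16]])              # one video, "nv" = 1
second_per_grid_ts = torch.tensor([[0.5]])                # may arrive as (nv, 1) after batching
if second_per_grid_ts.ndim == 2:
    second_per_grid_ts = second_per_grid_ts.squeeze(-1)   # -> shape (nv,)

num_channels, temporal_patch_size, patch_size = 3, 2, 14
ctps = num_channels * temporal_patch_size * patch_size * patch_size   # "ctps" = 1176
num_patches = int(video_grid_thw.prod(dim=-1).sum())                  # "np" = 768

video_input = Qwen2_5_VLVideoPixelInputs(
    type="pixel_values_videos",
    pixel_values_videos=torch.randn(num_patches, ctps),
    video_grid_thw=video_grid_thw,
    second_per_grid_ts=second_per_grid_ts,
)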
class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
video_embeds: torch.Tensor
"""Supported types:
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors).

Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the videos.
- `hidden_size` must match the hidden size of language model backbone.
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
"""

video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- nf: Number of video features
- hs: Hidden size
- nv: Number of videos

Historical context:
- video_embeds shape: (num_video_features, hidden_size)
- num_video_features varies based on the number and resolution of the
videos.
- hidden_size must match the hidden size of language model backbone.
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
format
"""
type: Literal["video_embeds"]

video_embeds: Annotated[
torch.Tensor,
TensorShape("nf", "hs"),
]

video_grid_thw: Annotated[
torch.Tensor,
TensorShape("nv", 3),
]


Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
@@ -936,10 +978,6 @@ def _parse_and_validate_image_input(
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return Qwen2_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
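The hand-written isinstance checks removed above are expected to be covered by TensorSchema validation when the inputs object is constructed. A hedged illustration (not part of the diff; the exact exception type depends on vllm.utils.tensor_schema internals):

import torch

bad_grid_thw = torch.randn(2, 4)     # declared shape is ("ni", 3)
try:
    Qwen2_5_VLImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.randn(10, 3 * 14 * 14),
        image_grid_thw=bad_grid_thw,
    )
except Exception as exc:             # broad catch: illustration only
    print(f"schema rejected mismatched image_grid_thw: {exc}")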
@@ -950,9 +988,6 @@ def _parse_and_validate_image_input(
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
@@ -973,7 +1008,8 @@ def _parse_and_validate_video_input(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")

if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
second_per_grid_ts = second_per_grid_ts.squeeze(-1)
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
@@ -987,9 +1023,6 @@ def _parse_and_validate_video_input(
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")

if not isinstance(video_embeds, torch.Tensor):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,