From df2181c60cd6a243b79bc79f9dd7f7e0bb1418ee Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 23 Aug 2025 18:02:27 -0700 Subject: [PATCH 1/5] Migrate Qwen2 inputs to TensorSchema Signed-off-by: Benji Beck --- vllm/model_executor/models/qwen2_5_vl.py | 160 ++++++++++++---------- vllm/model_executor/models/qwen2_audio.py | 25 +++- vllm/model_executor/models/qwen2_vl.py | 147 +++++++++++--------- 3 files changed, 194 insertions(+), 138 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 0a89f86fc738..1f9168015905 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -27,7 +27,7 @@ """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -64,6 +64,7 @@ from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -80,84 +81,115 @@ # === Vision Inputs === # -class Qwen2_5_VLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` +class Qwen2_5_VLImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) formatnum_channels * patch_size * patch_size """ + type: Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] -class Qwen2_5_VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. - """ - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. +class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the images. + - hidden_size must match the hidden size of language model backbone. 
+ - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format + """ + type: Literal["image_embeds"] + + image_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs] -class Qwen2_5_VLVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` +class Qwen2_5_VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format + - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal + dimension in the 3D position IDs. Returned when `videos` is not `None`. """ + type: Literal["pixel_values_videos"] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] - This should be in `(grid_t, grid_h, grid_w)` format. - """ + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] - second_per_grid_ts: torch.Tensor - """ - The video time interval (in seconds) for each grid along the temporal - dimension in the 3D position IDs. Returned when `videos` is not `None`. - """ + second_per_grid_ts: Annotated[ + torch.Tensor, + TensorShape("nv"), + ] -class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. +class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["video_embeds"] + + video_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, @@ -936,10 +968,6 @@ def _parse_and_validate_image_input( image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. 
" - f"Got type: {type(pixel_values)}") - return Qwen2_5_VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -950,9 +978,6 @@ def _parse_and_validate_image_input( image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -987,9 +1012,6 @@ def _parse_and_validate_video_input( video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 86b4a9a018c7..2a896708850f 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Optional, Union import torch import torch.nn as nn @@ -47,6 +47,7 @@ PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, @@ -54,13 +55,23 @@ # # === Audio Inputs === # -class Qwen2AudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - input_features: torch.Tensor - """Shape: `(num_audios, num_mel_bins, 3000)`""" +class Qwen2AudioInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + """ + + input_features: Annotated[ + torch.Tensor, + TensorShape("na", "nmb", 3000), + ] + - feature_attention_mask: torch.Tensor - """Shape: `(num_audios, 3000)`""" + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", 3000), + ] class Qwen2AudioEmbeddingInputs(TypedDict): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b708719e4f9b..9907050e4298 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -26,7 +26,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -70,6 +70,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -86,78 +87,110 @@ # === Vision Inputs === # -class Qwen2VLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, 
num_channels * patch_size * patch_size)` +class Qwen2VLImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["pixel_values"] + pixel_values: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("np", "cps"), + ] -class Qwen2VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2VLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the images. + - hidden_size must match the hidden size of language model backbone. + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["image_embeds"] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + image_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] -class Qwen2VLVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` +class Qwen2VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each video over each prompt in + the batch + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + - nv: Number of videos + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["pixel_values_videos"] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("np", "ctps"), + ] - This should be in `(grid_t, grid_h, grid_w)` format. - """ + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] -class Qwen2VLVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. 
- - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). +class Qwen2VLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["video_embeds"] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + video_embeds: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, @@ -1126,10 +1159,6 @@ def _parse_and_validate_image_input( image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - return Qwen2VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -1140,9 +1169,6 @@ def _parse_and_validate_image_input( image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Qwen2VLImageEmbeddingInputs(type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw) @@ -1174,9 +1200,6 @@ def _parse_and_validate_video_input( video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. 
" - f"Got type: {type(video_embeds)}") return Qwen2VLVideoEmbeddingInputs(type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) From bccb636ca8d9b8a0dd78b607439364357eeec11a Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 24 Aug 2025 09:44:58 -0700 Subject: [PATCH 2/5] Update type annotations for fields using _validate_and_reshape_mm_tensor Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Benji Beck --- vllm/model_executor/models/qwen2_5_vl.py | 4 ++-- vllm/model_executor/models/qwen2_vl.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 1f9168015905..35440734d36e 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -121,7 +121,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] image_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("nf", "hs"), ] @@ -182,7 +182,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): type: Literal["video_embeds"] video_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("nf", "hs"), ] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9907050e4298..b5a5b257f649 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -102,7 +102,7 @@ class Qwen2VLImagePixelInputs(TensorSchema): type: Literal["pixel_values"] pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("np", "cps"), ] @@ -128,7 +128,7 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] image_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("nf", "hs"), ] @@ -157,7 +157,7 @@ class Qwen2VLVideoPixelInputs(TensorSchema): type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("np", "ctps"), ] @@ -183,7 +183,7 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): type: Literal["video_embeds"] video_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor, TensorShape("nf", "hs"), ] From 2f9f793beb43ea453f339f0ad29ebed5406e8e8d Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 24 Aug 2025 09:55:27 -0700 Subject: [PATCH 3/5] Fix precommit Signed-off-by: Benji Beck --- vllm/model_executor/models/qwen2_5_vl.py | 32 ++++++++++++++++-------- vllm/model_executor/models/qwen2_vl.py | 27 +++++++++++++------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 35440734d36e..ff0208d6a6ad 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -89,8 +89,10 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema): - cps: Number of channels * patch_size * patch_size Historical context: - - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) formatnum_channels * patch_size * patch_size + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + formatnum_channels * patch_size * patch_size """ type: 
Literal["pixel_values"] @@ -114,9 +116,11 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): Historical context: - image_embeds shape: (num_image_features, hidden_size) - - num_image_features varies based on the number and resolution of the images. + - num_image_features varies based on the number and resolution of the + images. - hidden_size must match the hidden size of language model backbone. - - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["image_embeds"] @@ -140,13 +144,17 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): Dimensions: - np: Number of patches - nv: Number of videos - - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size Historical context: - - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) - - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format - - second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal - dimension in the 3D position IDs. Returned when `videos` is not `None`. + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + - second_per_grid_ts: The video time interval (in seconds) for each + grid along the temporal dimension in the 3D position IDs. Returned + when `videos` is not `None`. """ type: Literal["pixel_values_videos"] @@ -175,9 +183,11 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): Historical context: - video_embeds shape: (num_video_features, hidden_size) - - num_video_features varies based on the number and resolution of the videos. + - num_video_features varies based on the number and resolution of the + videos. - hidden_size must match the hidden size of language model backbone. - - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["video_embeds"] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b5a5b257f649..f00b214b1ef1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -96,8 +96,10 @@ class Qwen2VLImagePixelInputs(TensorSchema): - cps: Number of channels * patch_size * patch_size Historical context: - - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["pixel_values"] @@ -121,9 +123,11 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): Historical context: - image_embeds shape: (num_image_features, hidden_size) - - num_image_features varies based on the number and resolution of the images. + - num_image_features varies based on the number and resolution of the + images. - hidden_size must match the hidden size of language model backbone. 
- - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["image_embeds"] @@ -147,12 +151,15 @@ class Qwen2VLVideoPixelInputs(TensorSchema): Dimensions: - np: The total number of patches over each video over each prompt in the batch - - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size - nv: Number of videos Historical context: - - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) - - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["pixel_values_videos"] @@ -176,9 +183,11 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): Historical context: - video_embeds shape: (num_video_features, hidden_size) - - num_video_features varies based on the number and resolution of the videos. + - num_video_features varies based on the number and resolution of the + videos. - hidden_size must match the hidden size of language model backbone. - - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ type: Literal["video_embeds"] From eab0afa08195298e6c6fc28e158afdfede023d57 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sun, 31 Aug 2025 11:00:06 -0700 Subject: [PATCH 4/5] Migrate Qwen2AudioEmbeddingInputs to TensorSchema Signed-off-by: Benji Beck --- .../models/qwen2_5_omni_thinker.py | 23 +++++++++------- vllm/model_executor/models/qwen2_5_vl.py | 5 ++-- vllm/model_executor/models/qwen2_audio.py | 27 ++++++++++++------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 29563540a794..0f3bf1246d82 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -542,15 +542,20 @@ def _parse_and_validate_audio_input( feature_attention_mask = kwargs.pop('feature_attention_mask', None) if input_audio_features is None: return None - input_audio_features = self._validate_and_reshape_mm_tensor( - input_audio_features, 'input_audio_features', dim=1) + input_audio_features = torch.stack([ + x[:, :3000] if x.size(1) >= 3000 else torch.nn.functional.pad( + x, (0, 3000 - x.size(1))) for x in input_audio_features + ], + dim=0) if feature_attention_mask is not None: - feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') - if not isinstance(input_audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. 
" - f"Got type: {type(input_audio_features)}") + feature_attention_mask = torch.stack( + [(m.squeeze(0)[::10] + if m.numel() == 30000 else m.squeeze(0))[:3000] + for m in feature_attention_mask], + dim=0).to(torch.long) + return Qwen2AudioFeatureInputs( + type="audio_features", input_features=input_audio_features, audio_feature_lengths=audio_feature_lengths, feature_attention_mask=feature_attention_mask) @@ -660,8 +665,8 @@ def _process_audio_input( feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - audio_features = audio_outputs.last_hidden_state - return audio_features.split(audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split( + audio_output_lengths.tolist()) def _process_image_input( self, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ff0208d6a6ad..afef86fbaa02 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -169,7 +169,7 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ] second_per_grid_ts: Annotated[ - torch.Tensor, + Optional[torch.Tensor], TensorShape("nv"), ] @@ -1008,7 +1008,8 @@ def _parse_and_validate_video_input( pixel_values_videos, "video pixel values") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - + if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: + second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 2a896708850f..4a36eab5a90f 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn @@ -55,31 +55,38 @@ # # === Audio Inputs === # -class Qwen2AudioInputs(TensorSchema): +class Qwen2AudioFeatureInputs(TensorSchema): """ Dimensions: - na: Number of audios - nmb: Number of mel bins """ - + type: Literal["audio_features"] input_features: Annotated[ - torch.Tensor, + [torch.Tensor, list[torch.Tensor]], TensorShape("na", "nmb", 3000), ] - feature_attention_mask: Annotated[ torch.Tensor, TensorShape("na", 3000), ] -class Qwen2AudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: list[torch.Tensor] - """Shape: `(num_audio_features, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. 
+class Qwen2AudioEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] From bb78574422a73ac1c39e2e87ba23af2cec3f2c9e Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 6 Sep 2025 09:45:49 -0700 Subject: [PATCH 5/5] Add Qwen2_OmniThinkerAudioFeatureInputs Signed-off-by: Benji Beck --- .../models/qwen2_5_omni_thinker.py | 65 ++++++++++--------- vllm/model_executor/models/qwen2_audio.py | 2 +- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 0f3bf1246d82..d05eb76cdf6f 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -41,15 +41,13 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo, - _get_feat_extract_output_lengths) + Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -66,9 +64,9 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -81,6 +79,26 @@ logger = init_logger(__name__) +class Qwen2_5OmniAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + - msl: Maximum sequence length + - tsl: Total sequence length + """ + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nmb", "tsl"), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", "msl"), + ] + + def create_qwen2_5_omni_thinker_field_factory( spatial_merge_size: int ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, @@ -536,25 +554,21 @@ def _validate_and_reshape_mm_tensor(self, return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( 
- self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]: + self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) if input_audio_features is None: return None - input_audio_features = torch.stack([ - x[:, :3000] if x.size(1) >= 3000 else torch.nn.functional.pad( - x, (0, 3000 - x.size(1))) for x in input_audio_features - ], - dim=0) + input_audio_features = self._validate_and_reshape_mm_tensor( + input_audio_features, 'input_audio_features', dim=1) if feature_attention_mask is not None: - feature_attention_mask = torch.stack( - [(m.squeeze(0)[::10] - if m.numel() == 30000 else m.squeeze(0))[:3000] - for m in feature_attention_mask], - dim=0).to(torch.long) - - return Qwen2AudioFeatureInputs( + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + if not isinstance(input_audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}") + return Qwen2_5OmniAudioFeatureInputs( type="audio_features", input_features=input_audio_features, audio_feature_lengths=audio_feature_lengths, @@ -638,7 +652,7 @@ def _parse_and_validate_video_input( def _process_audio_input( self, - audio_input: Qwen2AudioFeatureInputs, + audio_input: Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: @@ -712,7 +726,7 @@ def _process_video_input( dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, + nn.Module, SupportsMultiModal, SupportsPP, Qwen2_5OmniConditionalGenerationMixin): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -805,15 +819,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_mm_mapping(self) -> MultiModelKeys: - """Get module prefix for multimodal models to filter LoRA modules.""" - return MultiModelKeys.from_string_field( - language_model="language_model", - connector=[], # No explicit connector in this model - tower_model=["visual", - "audio_tower"], # Exclude vision and audio towers - ) - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 4a36eab5a90f..54ec7b862748 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -63,7 +63,7 @@ class Qwen2AudioFeatureInputs(TensorSchema): """ type: Literal["audio_features"] input_features: Annotated[ - [torch.Tensor, list[torch.Tensor]], + Union[torch.Tensor, list[torch.Tensor]], TensorShape("na", "nmb", 3000), ]