
Commit 7ff6370

bbeckca and gemini-code-assist[bot] authored and committed
Migrate Qwen2 inputs to TensorSchema (vllm-project#23475)
Signed-off-by: Benji Beck <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent a80d287 commit 7ff6370

File tree

4 files changed: 257 additions & 164 deletions


vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 31 additions & 21 deletions
@@ -25,7 +25,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -41,15 +41,13 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
     Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
     Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
     Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
 from vllm.model_executor.models.qwen2_audio import (
-    Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
-    _get_feat_extract_output_lengths)
+    Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -66,9 +64,9 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -81,6 +79,26 @@
 logger = init_logger(__name__)
 
 
+class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - na: Number of audios
+        - nmb: Number of mel bins
+        - msl: Maximum sequence length
+        - tsl: Total sequence length
+    """
+    type: Literal["audio_features"]
+    input_features: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("nmb", "tsl"),
+    ]
+
+    feature_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("na", "msl"),
+    ]
+
+
 def create_qwen2_5_omni_thinker_field_factory(
     spatial_merge_size: int
 ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
@@ -536,7 +554,7 @@ def _validate_and_reshape_mm_tensor(self,
         return torch.concat(mm_input, dim=dim)
 
     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
+            self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
         input_audio_features = kwargs.pop('input_audio_features', None)
         audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -550,7 +568,8 @@ def _parse_and_validate_audio_input(
         if not isinstance(input_audio_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio input features. "
                              f"Got type: {type(input_audio_features)}")
-        return Qwen2AudioFeatureInputs(
+        return Qwen2_5OmniAudioFeatureInputs(
+            type="audio_features",
             input_features=input_audio_features,
             audio_feature_lengths=audio_feature_lengths,
             feature_attention_mask=feature_attention_mask)
@@ -633,7 +652,7 @@ def _parse_and_validate_video_input(
 
     def _process_audio_input(
         self,
-        audio_input: Qwen2AudioFeatureInputs,
+        audio_input: Qwen2_5OmniAudioFeatureInputs,
         audio_hashes: list[str] = None,
         cached_audio_features: torch.Tensor = None,
     ) -> torch.Tensor:
@@ -660,8 +679,8 @@ def _process_audio_input(
             feature_lens=audio_feature_lengths,
             aftercnn_lens=audio_feat_lengths,
         )
-        audio_features = audio_outputs.last_hidden_state
-        return audio_features.split(audio_output_lengths.tolist())
+        return audio_outputs.last_hidden_state.split(
+            audio_output_lengths.tolist())
 
     def _process_image_input(
         self,
@@ -707,7 +726,7 @@ def _process_video_input(
     dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
 )
 class Qwen2_5OmniThinkerForConditionalGeneration(
-        nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
+        nn.Module, SupportsMultiModal, SupportsPP,
         Qwen2_5OmniConditionalGenerationMixin):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -800,15 +819,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
     def get_language_model(self) -> torch.nn.Module:
        return self.language_model
 
-    def get_mm_mapping(self) -> MultiModelKeys:
-        """Get module prefix for multimodal models to filter LoRA modules."""
-        return MultiModelKeys.from_string_field(
-            language_model="language_model",
-            connector=[],  # No explicit connector in this model
-            tower_model=["visual",
-                         "audio_tower"],  # Exclude vision and audio towers
-        )
-
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:

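The new Qwen2_5OmniAudioFeatureInputs class declares its expected tensor shapes as Annotated fields instead of prose docstrings and hand-written isinstance checks. As a rough, self-contained sketch of that pattern (not vLLM's actual TensorSchema implementation, which lives in vllm.utils.tensor_schema; MiniSchema, Shape, and AudioFeatureInputs below are hypothetical names), a schema base class can bind named dimensions across fields and reject mismatched inputs:

# Hypothetical illustration of declarative tensor-shape checking; this is a
# sketch of the pattern the migration adopts, not vLLM's TensorSchema.
from typing import Annotated, get_type_hints

import torch


class Shape:
    """Symbolic shape: strings name dimensions, ints pin exact sizes."""

    def __init__(self, *dims):
        self.dims = dims


class MiniSchema:
    """Checks each annotated tensor field against its declared Shape."""

    def __init__(self, **fields):
        hints = get_type_hints(type(self), include_extras=True)
        bound = {}  # named dimension -> size seen so far
        for name, value in fields.items():
            setattr(self, name, value)
            hint = hints.get(name)
            if hint is None or not hasattr(hint, "__metadata__"):
                continue  # fields without a Shape annotation are stored as-is
            shape = next(m for m in hint.__metadata__ if isinstance(m, Shape))
            if value.ndim != len(shape.dims):
                raise ValueError(f"{name}: expected {len(shape.dims)} dims, "
                                 f"got {value.ndim}")
            for dim, size in zip(shape.dims, value.shape):
                if isinstance(dim, int):
                    if size != dim:
                        raise ValueError(f"{name}: dim must be {dim}")
                elif bound.setdefault(dim, size) != size:
                    raise ValueError(f"{name}: dimension '{dim}' mismatch")


class AudioFeatureInputs(MiniSchema):
    # Same dimension names as Qwen2_5OmniAudioFeatureInputs above.
    input_features: Annotated[torch.Tensor, Shape("nmb", "tsl")]
    feature_attention_mask: Annotated[torch.Tensor, Shape("na", "msl")]


# Consistent shapes construct fine; a wrong rank or size raises ValueError.
AudioFeatureInputs(
    input_features=torch.zeros(128, 3000),       # (nmb, tsl)
    feature_attention_mask=torch.ones(2, 1500),  # (na, msl)
)
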
vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 103 additions & 70 deletions
@@ -27,7 +27,7 @@
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial
-from typing import Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Callable, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -64,6 +64,7 @@
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
@@ -80,84 +81,125 @@
 # === Vision Inputs === #
 
 
-class Qwen2_5_VLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2_5_VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+
+    Historical context:
+        - pixel_values shape: (num_patches, num_channels * patch_size *
+          patch_size)
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", "cps"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
 
 
-class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+
+    Historical context:
+        - image_embeds shape: (num_image_features, hidden_size)
+        - num_image_features varies based on the number and resolution of the
+          images.
+        - hidden_size must match the hidden size of language model backbone.
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["image_embeds"]
+
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
 
 
 Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
                               Qwen2_5_VLImageEmbeddingInputs]
 
 
-class Qwen2_5_VLVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-      num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-    second_per_grid_ts: torch.Tensor
-    """
-    The video time interval (in seconds) for each grid along the temporal
-    dimension in the 3D position IDs. Returned when `videos` is not `None`.
-    """
+class Qwen2_5_VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - nv: Number of videos
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+
+    Historical context:
+        - pixel_values_videos shape: (num_patches, num_channels *
+          temporal_patch_size * patch_size * patch_size)
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+        - second_per_grid_ts: The video time interval (in seconds) for each
+          grid along the temporal dimension in the 3D position IDs. Returned
+          when `videos` is not `None`.
+    """
+    type: Literal["pixel_values_videos"]
+
+    pixel_values_videos: Annotated[
+        torch.Tensor,
+        TensorShape("np", "ctps"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
+
+    second_per_grid_ts: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("nv"),
+    ]
 
 
-class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
-    type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-        Each tensor holds an video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-        (concatenation of all videos' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size
+        - nv: Number of videos
+
+    Historical context:
+        - video_embeds shape: (num_video_features, hidden_size)
+        - num_video_features varies based on the number and resolution of the
+          videos.
+        - hidden_size must match the hidden size of language model backbone.
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
+    type: Literal["video_embeds"]
+
+    video_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
 
 
 Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
@@ -936,10 +978,6 @@ def _parse_and_validate_image_input(
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")
 
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                               pixel_values=pixel_values,
                                               image_grid_thw=image_grid_thw)
@@ -950,9 +988,6 @@ def _parse_and_validate_image_input(
             image_grid_thw = self._validate_and_reshape_mm_tensor(
                 image_grid_thw, "image grid_thw")
 
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -973,7 +1008,8 @@ def _parse_and_validate_video_input(
                 pixel_values_videos, "video pixel values")
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")
-
+            if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
+                second_per_grid_ts = second_per_grid_ts.squeeze(-1)
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -987,9 +1023,6 @@ def _parse_and_validate_video_input(
             video_grid_thw = self._validate_and_reshape_mm_tensor(
                 video_grid_thw, "video grid_thw")
 
-            if not isinstance(video_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
             return Qwen2_5_VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

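With this migration, the parse-and-validate helpers simply construct the schema classes with keyword arguments, as in the hunks above, and leave shape checking to the schema itself. A rough standalone usage sketch follows; it assumes TensorSchema validates the annotated shapes at construction time, and the tensor sizes are illustrative only, not taken from a real checkpoint.

import torch

from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImagePixelInputs

# Shapes follow the docstring above: pixel_values is
# (num_patches, num_channels * patch_size * patch_size) and image_grid_thw is
# (num_images, 3). A (grid_t, grid_h, grid_w) = (1, 16, 16) grid gives
# 256 patches; these sizes are illustrative only.
pixel_values = torch.randn(256, 3 * 14 * 14)
image_grid_thw = torch.tensor([[1, 16, 16]])

image_input = Qwen2_5_VLImagePixelInputs(
    type="pixel_values",
    pixel_values=pixel_values,
    image_grid_thw=image_grid_thw,
)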