Commit 1d48e92

ywang96 authored and DarkLight1337 committed
[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (vllm-project#11685)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 888d414 commit 1d48e92

File tree: 17 files changed, +636 additions, −282 deletions

17 files changed

+636
-282
lines changed
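As context for the diffs below, here is a rough sketch of the kind of request this change enables: a single Llava-OneVision prompt mixing image and video data in one pass. The chat-template tokens, the dummy inputs, and the VLLM_USE_V1 opt-in are illustrative assumptions, not taken from this commit.

```python
# Hedged sketch: mixed-modality offline inference with Llava-OneVision.
# Prompt template and inputs are illustrative; V1 was opt-in at the time.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: V1 engine enabled via env var

import numpy as np
from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

image = Image.new("RGB", (336, 336))                # dummy image
video = np.zeros((8, 336, 336, 3), dtype=np.uint8)  # 8 dummy video frames

outputs = llm.generate({
    "prompt": ("<|im_start|>user <image><video>\n"
               "Describe the image and the video.<|im_end|>"
               "<|im_start|>assistant\n"),
    "multi_modal_data": {"image": image, "video": video},
})
print(outputs[0].outputs[0].text)
```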

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
   -
   - ✅︎
-  -
+  - ✅︎
 * - `MiniCPMV`
   - MiniCPM-V
   - T + I<sup>E+</sup>

tests/multimodal/test_utils.py

Lines changed: 208 additions & 1 deletion
@@ -2,16 +2,22 @@
 import mimetypes
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Dict, Tuple
+from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple
 
 import numpy as np
 import pytest
 from PIL import Image, ImageChops
 from transformers import AutoConfig, AutoTokenizer
 
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
+                                   merge_and_sort_multimodal_metadata,
                                    repeat_and_pad_placeholder_tokens)
 
+if TYPE_CHECKING:
+    from vllm.multimodal.hasher import MultiModalHashDict
+    from vllm.multimodal.inputs import MultiModalPlaceholderDict
+
 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
 TEST_IMAGE_URLS = [
     "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
     assert new_prompt == expected_prompt
     assert new_token_ids == expected_token_ids
     assert ranges == expected_ranges
+
+
+# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
+class TestCase(NamedTuple):
+    mm_positions: "MultiModalPlaceholderDict"
+    mm_hashes: Optional["MultiModalHashDict"]
+    expected_modalities: list[str]
+    expected_ranges: list[PlaceholderRange]
+    expected_hashes: Optional[list[str]]
+
+
+def test_merge_and_sort_multimodal_metadata():
+
+    test_cases = [
+        # Single modality should return result as is but flattened
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=3, length=2),
+                ]
+            },
+            mm_hashes={"image": ["hash1", "hash2"]},
+            expected_modalities=["image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=3, length=2),
+            ],
+            expected_hashes=["hash1", "hash2"],
+        ),
+
+        # Single modality without hashes returns None for mm hash.
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=2, length=2),
+                ]
+            },
+            mm_hashes=None,
+            expected_modalities=["image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=2),
+            ],
+            expected_hashes=None,
+        ),
+
+        # Multiple modalities with hashes should return sorted modalities
+        # and flattened ranges and hashes.
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=7, length=4),
+                    PlaceholderRange(offset=11, length=5),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=2, length=3),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1", "audio_hash2"],
+            },
+            expected_modalities=["audio", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=3),
+                PlaceholderRange(offset=7, length=4),
+                PlaceholderRange(offset=11, length=5),
+            ],
+            expected_hashes=[
+                "audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
+            ],
+        ),
+
+        # Multiple modalities without hashes should return sorted modalities
+        # and flattened ranges and None.
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=7, length=4),
+                    PlaceholderRange(offset=11, length=5),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=2, length=3),
+                ]
+            },
+            mm_hashes=None,
+            expected_modalities=["audio", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=2, length=3),
+                PlaceholderRange(offset=7, length=4),
+                PlaceholderRange(offset=11, length=5),
+            ],
+            expected_hashes=None,
+        ),
+
+        # Three modalities
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=15, length=7),
+                    PlaceholderRange(offset=22, length=8),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=0, length=2),
+                ],
+                "video": [
+                    PlaceholderRange(offset=3, length=4),
+                    PlaceholderRange(offset=7, length=5),
+                    PlaceholderRange(offset=12, length=6),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1"],
+                "video": ["video_hash1", "video_hash2", "video_hash3"]
+            },
+            expected_modalities=["audio", "video", "image"],
+            expected_ranges=[
+                PlaceholderRange(offset=0, length=2),
+                PlaceholderRange(offset=3, length=4),
+                PlaceholderRange(offset=7, length=5),
+                PlaceholderRange(offset=12, length=6),
+                PlaceholderRange(offset=15, length=7),
+                PlaceholderRange(offset=22, length=8),
+            ],
+            expected_hashes=[
+                "audio_hash1", "video_hash1", "video_hash2", "video_hash3",
+                "image_hash1", "image_hash2"
+            ],
+        ),
+    ]
+
+    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
+         expected_hashes) in test_cases:
+        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
+            mm_positions, mm_hashes)
+
+        assert modalities == expected_modalities
+        assert ranges == expected_ranges
+        assert hashes == expected_hashes
+
+
+def test_merge_and_sort_multimodal_metadata_with_interleaving():
+
+    test_cases = [
+
+        # <image> <audio> <image> <audio>
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=4),
+                    PlaceholderRange(offset=8, length=2),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=5, length=2),
+                    PlaceholderRange(offset=11, length=4),
+                ]
+            },
+            mm_hashes={
+                "image": ["image_hash1", "image_hash2"],
+                "audio": ["audio_hash1", "audio_hash2"],
+            },
+            expected_modalities=[],
+            expected_ranges=[],
+            expected_hashes=None,
+        ),
+
+        # <image> <image> <video> <audio> <image>
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=2),
+                    PlaceholderRange(offset=2, length=3),
+                    PlaceholderRange(offset=20, length=4),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=5, length=2),
+                ],
+                "video": [
+                    PlaceholderRange(offset=8, length=5),
+                ]
+            },
+            mm_hashes=None,
+            expected_modalities=[],
+            expected_ranges=[],
+            expected_hashes=None,
+        ),
+    ]
+
+    for case in test_cases:
+        with pytest.raises(ValueError) as ex_info:
+            merge_and_sort_multimodal_metadata(case.mm_positions,
+                                               case.mm_hashes)
+
+        assert "Interleaved mixed-modality" in str(ex_info.value)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 11 additions & 7 deletions
@@ -1,6 +1,6 @@
 import pytest
 
-from vllm.inputs import token_inputs
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                          KVCacheBlock,
@@ -14,14 +14,18 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None):
+    if mm_positions is None:
+        multi_modal_inputs = None
+    else:
+        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+
     return Request(
         request_id=request_id,
-        inputs=token_inputs(
-            prompt_token_ids=prompt_token_ids,
-            multi_modal_placeholders={"image": mm_positions}
-            if mm_positions else None,
-            multi_modal_hashes=mm_hashes,
-        ),
+        prompt=None,
+        prompt_token_ids=prompt_token_ids,
+        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_hashes=mm_hashes,
+        multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,
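A hypothetical call to the reworked helper, with made-up placeholder positions and hashes, just to show the shape of the new flattened arguments:

```python
# Hypothetical usage of the reworked make_request; values are illustrative.
from vllm.multimodal.inputs import PlaceholderRange

req = make_request(
    request_id="req-0",
    prompt_token_ids=list(range(16)),
    mm_positions=[
        PlaceholderRange(offset=0, length=4),
        PlaceholderRange(offset=8, length=4),
    ],
    mm_hashes=["hash1", "hash2"],
)
# One empty MultiModalKwargs is created per placeholder, so this request
# carries two multimodal inputs alongside its 16 prompt tokens.
```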

tests/v1/core/test_prefix_caching.py

Lines changed: 11 additions & 6 deletions
@@ -1,8 +1,7 @@
 """Compare with and without prefix caching."""
 import pytest
 
-from vllm.inputs import token_inputs
-from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import cdiv
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
@@ -13,12 +12,18 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None):
+    if mm_positions is None:
+        multi_modal_inputs = None
+    else:
+        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+
     return Request(
         request_id=request_id,
-        inputs=token_inputs(prompt_token_ids=prompt_token_ids,
-                            multi_modal_placeholders={"image": mm_positions}
-                            if mm_positions else None,
-                            multi_modal_hashes=mm_hashes),
+        prompt=None,
+        prompt_token_ids=prompt_token_ids,
+        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_hashes=mm_hashes,
+        multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,

vllm/model_executor/models/interfaces.py

Lines changed: 5 additions & 1 deletion
@@ -39,8 +39,12 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
 
         The output embeddings must be one of the following formats:
         - A list or tuple of 2D tensors, where each tensor corresponds to
-          each input image.
+          each input multimodal data item (e.g., image).
        - A single 3D tensor, with the batch dimension grouping the 2D tensors.
+
+        NOTE: The returned multimodal embeddings must be in the same order as
+        the appearances of their corresponding multimodal data items in the
+        input prompt.
         """
         ...
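To make the new ordering note concrete, here is a toy implementation (not from the vLLM codebase; the kwarg names and the stand-in encoder are invented) that returns one 2D tensor per multimodal item, in prompt-appearance order:

```python
# Toy sketch of the documented contract; not vLLM code. Assumes all audio
# items precede all image items in the prompt, as in a non-interleaved
# "<audio> ... <image> <image>" layout.
from typing import Optional

import torch


class ToyMultiModalModel:
    hidden_size = 16

    def _encode(self, items: list) -> list[torch.Tensor]:
        # Stand-in encoder: one (num_tokens, hidden_size) 2D tensor per item.
        return [torch.zeros(4, self.hidden_size) for _ in items]

    def get_multimodal_embeddings(self,
                                  **kwargs) -> Optional[list[torch.Tensor]]:
        audio = kwargs.get("audio", [])
        images = kwargs.get("images", [])
        if not audio and not images:
            return None
        # Prompt order: audio first, then images -> embeddings in that order.
        return [*self._encode(audio), *self._encode(images)]


model = ToyMultiModalModel()
embeds = model.get_multimodal_embeddings(audio=[object()],
                                         images=[object(), object()])
assert len(embeds) == 3 and all(e.ndim == 2 for e in embeds)
```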
