
Commit e34a7ae

alex-jw-brooks authored and sumitd2 committed
[Model] Add min_pixels / max_pixels to Qwen2VL as mm_processor_kwargs (vllm-project#9612)
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
1 parent 530b4ab commit e34a7ae
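
As a quick orientation, here is a minimal sketch (not part of the diff) of what the new kwargs look like at engine construction time; the model name comes from the new tests and the pixel bounds from the example change below:

from vllm import LLM

# Sketch only: pass Qwen2-VL image-processor overrides when building the engine.
# min_pixels / max_pixels bound the area an input image is resized to before it
# is split into vision patches.
llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",
    max_model_len=8192,
    max_num_seqs=5,
    mm_processor_kwargs={
        "min_pixels": 28 * 28,
        "max_pixels": 1280 * 28 * 28,
    },
)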

File tree

3 files changed: +236 -18 lines changed


examples/offline_inference_vision_language.py

Lines changed: 5 additions & 0 deletions
@@ -267,6 +267,11 @@ def run_qwen2_vl(question: str, modality: str):
         model=model_name,
         max_model_len=8192,
         max_num_seqs=5,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+        },
     )
 
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
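
The new comment in this example notes that the same kwargs can also be supplied per request. A hedged sketch of what that could look like with generate(), assuming llm, prompt, and image are defined as in the surrounding example; the request-level field name is an assumption and is not shown in this diff:

from vllm import SamplingParams

# Sketch only: per-request override of the image-processor pixel bounds.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
        "mm_processor_kwargs": {"min_pixels": 64**2, "max_pixels": 512**2},
    },
    sampling_params=SamplingParams(max_tokens=64),
)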
Lines changed: 160 additions & 0 deletions (new test file)
@@ -0,0 +1,160 @@
+from typing import Any, Dict, Tuple
+
+import pytest
+import torch
+from PIL.Image import Image
+from transformers import AutoTokenizer
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.multimodal import MultiModalRegistry
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+MODEL = "Qwen/Qwen2-VL-2B-Instruct"
+MIN_PIXELS = "min_pixels"
+MAX_PIXELS = "max_pixels"
+
+
+# Fixtures lazy import to avoid initializing CUDA during test collection
+# NOTE: Qwen2vl supports multiple input modalities, so it registers multiple
+# input mappers.
+@pytest.fixture()
+def image_input_mapper_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import (
+        image_input_mapper_for_qwen2_vl)
+    return image_input_mapper_for_qwen2_vl
+
+
+@pytest.fixture()
+def input_processor_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import (
+        input_processor_for_qwen2_vl)
+    return input_processor_for_qwen2_vl
+
+
+@pytest.fixture()
+def qwen2_vl_context() -> InputContext:
+    return build_model_context(model_name=MODEL)
+
+
+@pytest.fixture()
+def get_max_qwen2_vl_image_tokens():
+    from vllm.model_executor.models.qwen2_vl import (
+        get_max_qwen2_vl_image_tokens)
+    return get_max_qwen2_vl_image_tokens
+
+
+@pytest.fixture()
+def dummy_data_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
+    return dummy_data_for_qwen2_vl
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
+    ({}, 1225),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 324),
+])
+def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
+                                   qwen2_vl_context: InputContext,
+                                   mm_processor_kwargs: Dict[str, Any],
+                                   expected_max_tokens: int):
+    """Ensure that the max token calc handles min/max pixels properly."""
+    actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
+                                                      **mm_processor_kwargs)
+    assert actual_max_tokens == expected_max_tokens
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
+    [{}, 1225, (980, 980)],
+    [{
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 324, (504, 504)],
+])
+def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
+                             qwen2_vl_context: InputContext,
+                             mm_processor_kwargs: Dict[str, Any],
+                             token_count: int, img_size: Tuple[int, int]):
+    """Ensure that the dummy data handles min/max pixels properly."""
+    seq_len = 3000
+    hf_config = qwen2_vl_context.get_hf_config()
+    image_token_id = hf_config.image_token_id
+
+    # NOTE: video value is required, but isn't actually used
+    # when making the dummy data except for error handling currently
+    seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, {
+        "image": 1,
+        "video": 0
+    }, **mm_processor_kwargs)
+
+    # Ensure we have the right number of placeholders for min/max pixel values
+    assert seq_data.get_token_ids().count(image_token_id) == token_count
+
+    # Ensure the images were resized correctly
+    image = mm_data["image"]
+    assert isinstance(image, Image)
+    assert image.size == img_size
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
+    ({}, 1426),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, 330),
+])
+def test_input_processor(input_processor_for_qwen2_vl,
+                         qwen2_vl_context: InputContext,
+                         image_assets: _ImageAssets, num_placeholders: int,
+                         mm_processor_kwargs: Dict[str, Any]):
+    """Ensure that the image processor handles min/max pixels properly."""
+    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
+
+    image = image_assets[0].pil_image
+    hf_config = qwen2_vl_context.get_hf_config()
+    image_token_id = hf_config.image_token_id
+
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": [image]})
+
+    processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
+                                                    **mm_processor_kwargs)
+    assert processed_inputs["prompt_token_ids"].count(
+        image_token_id) == num_placeholders
+    assert len(processed_inputs["multi_modal_data"]["image"]) == 1
+
+
+@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
+    ({}, [5704, 1176]),
+    ({
+        MIN_PIXELS: 64**2,
+        MAX_PIXELS: 512**2
+    }, [1320, 1176]),
+])
+def test_image_mapper_override(qwen2_vl_context: InputContext,
+                               image_assets: _ImageAssets,
+                               mm_processor_kwargs: Dict[str, Any],
+                               pixels_shape: Tuple[int, int]):
+    """Ensure that the image mapper handles min/max pixels properly."""
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)
+
+    image = image_assets[0].pil_image
+
+    mapped_output = mm_registry.map_input(
+        qwen2_vl_context.model_config,
+        {"image": image},
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    # Dimension 0 of pixel values should match the product of image_grid_thw
+    actual_pixels_shape = mapped_output["pixel_values"].shape
+    assert list(actual_pixels_shape) == pixels_shape
+    assert actual_pixels_shape[0] == torch.prod(
+        mapped_output["image_grid_thw"])
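
The expected values in these tests follow from Qwen2-VL's patching scheme: images are resized so each side is a multiple of 28 (14-pixel patches merged 2x2), and each 28x28 block contributes one LLM token. A quick sanity check of the dummy-data parameters above (a sketch, not part of the diff):

# One token per 28x28 block of the resized image, assuming the standard
# Qwen2-VL patch size (14) and spatial merge size (2).
def qwen2_vl_image_tokens(height: int, width: int, patch: int = 28) -> int:
    return (height // patch) * (width // patch)

assert qwen2_vl_image_tokens(980, 980) == 1225  # default pixel bounds
assert qwen2_vl_image_tokens(504, 504) == 324   # max_pixels = 512 ** 2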

vllm/model_executor/models/qwen2_vl.py

Lines changed: 71 additions & 18 deletions
@@ -549,6 +549,9 @@ def mm_input_mapper_for_qwen2_vl(
     ctx: InputContext,
     data: MultiModalData[object],
     data_type_key: str,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
     if data_type_key == "image" and isinstance(data, dict):
@@ -557,8 +560,19 @@ def mm_input_mapper_for_qwen2_vl(
             "image_grid_thw": data.get("image_grid_thw"),
         })
     model_config = ctx.model_config
+    # Handle mm processor kwargs; we pass these at creation time
+    # because preprocess() in transformers doesn't expose them
+    mm_processor_kwargs = {}
+    if min_pixels:
+        mm_processor_kwargs["min_pixels"] = min_pixels
+    if max_pixels:
+        mm_processor_kwargs["max_pixels"] = max_pixels
+
     image_processor = cached_get_image_processor(
-        model_config.model, trust_remote_code=model_config.trust_remote_code)
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
+    )
     if image_processor is None:
         raise RuntimeError("No HuggingFace processor is available "
                            "to process the image object")
@@ -631,25 +645,36 @@ def _get_max_image_info(
     image_processor,
     data_type_key: str = "image",
     mm_count: int = 1,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ):
+    # Limit min / max pixels unless they're explicitly provided
+    if min_pixels is None:
+        min_pixels = max(image_processor.min_pixels, 28 * 28)
+    if max_pixels is None:
+        max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
+
     return _get_vision_info(
         image_processor,
         height=9999999,
         width=9999999,
-
-        # Limit min / max pixels.
-        min_pixels=max(image_processor.min_pixels, 28 * 28),
-        max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
         data_type_key=data_type_key,
         mm_count=mm_count,
     )
 
 
-def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
+def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
+                               data_type_key: str,
+                               *,
+                               min_pixels=None,
+                               max_pixels=None) -> int:
     image_processor = cached_get_image_processor(ctx.model_config.model)
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key=data_type_key,
-                            mm_count=1)
+                            mm_count=1, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     return max_llm_image_tokens
 
 
@@ -660,14 +685,20 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
 
 
 def dummy_data_for_qwen2_vl(
-    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
+    ctx: InputContext,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None
 ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
     image_processor = cached_get_image_processor(ctx.model_config.model)
 
     num_images = mm_counts["image"]
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key="image",
-                            mm_count=num_images)
+                            mm_count=num_images, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_image_tokens - 2 < 0:
         raise RuntimeError(
             f"Qwen2-VL cannot process {num_images} images in a prompt, "
@@ -678,10 +709,11 @@ def dummy_data_for_qwen2_vl(
     num_videos = mm_counts["video"]
     max_resized_height, max_resized_width, max_llm_video_tokens = \
         _get_max_image_info(image_processor, data_type_key="video",
-                            mm_count=num_videos)
+                            mm_count=num_videos, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_video_tokens - 2 < 0:
         raise RuntimeError(
-            f"Qwen2-VL cannot process {num_images} videos in a prompt, "
+            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
             "please increase max_model_len or reduce video limit by "
             "--limit-mm-per-prompt.")
 
@@ -706,6 +738,8 @@ def _get_llm_num_vision_tokens(
     mm_inputs: list,
     data_type_key: str,
     image_processor,
+    min_pixels: int,
+    max_pixels: int,
 ):
     """Get number of vision tokens of multimodal inputs.
 
@@ -715,12 +749,13 @@ def _get_llm_num_vision_tokens(
     image = to_numpy_array(mm_inputs[0])
     input_data_format = infer_channel_dimension_format(image)
     height, width = get_image_size(image, channel_dim=input_data_format)
+
     _, _, llm_num_vision_tokens = _get_vision_info(
         image_processor,
         height=height,
         width=width,
-        min_pixels=image_processor.min_pixels,
-        max_pixels=image_processor.max_pixels,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
         do_resize=image_processor.do_resize,
         data_type_key=data_type_key,
         mm_count=len(mm_inputs),
@@ -730,7 +765,8 @@
 
 def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
                        data_type_key: str, image_processor: Any,
-                       prompt_token_ids: List[int]) -> List[int]:
+                       prompt_token_ids: List[int], min_pixels: Optional[int],
+                       max_pixels: Optional[int]) -> List[int]:
     """
     Expand pad tokens for multi-modal inputs (e.g., images or videos).
 
@@ -741,6 +777,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
         data_type_key (str): The type of the multi-modal input.
         image_processor (Any): The image processor used to process the inputs.
         prompt_token_ids (List[int]): The list of token IDs in the prompt.
+        min_pixels (int): min pixels to used for img processing
+        max_pixels (int): max pixels to be used for img processing
 
     Returns:
         List[int]: The list of token IDs for the multi-modal inputs.
@@ -757,6 +795,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
             [data] if data_type_key == "image" else data,
             data_type_key=data_type_key,
             image_processor=image_processor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
         )
         if cnt == 0:
             end_idx = indices[cnt]
@@ -773,6 +813,9 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
 def input_processor_for_qwen2_vl(
     ctx: InputContext,
     inputs: DecoderOnlyInputs,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ) -> DecoderOnlyInputs:
     multi_modal_data = inputs.get("multi_modal_data", None)
     if multi_modal_data is None:
@@ -783,6 +826,10 @@ def input_processor_for_qwen2_vl(
 
     processor = cached_get_processor(ctx.model_config.model)
     image_processor = processor.image_processor
+    # Apply processor kwarg overrides for image processor options
+    min_pixels = min_pixels if min_pixels else image_processor.min_pixels
+    max_pixels = max_pixels if max_pixels else image_processor.max_pixels
+
     hf_config = ctx.get_hf_config(Qwen2VLConfig)
 
     # To avoid redundant processing of vision objects (resize, rescale, etc.),
@@ -830,16 +877,22 @@ def input_processor_for_qwen2_vl(
         else:
             prompt_token_ids = _expand_pad_tokens(image_inputs,
                                                   hf_config.image_token_id,
-                                                  make_batched_images, "image",
+                                                  make_batched_images,
+                                                  "image",
                                                   image_processor,
-                                                  prompt_token_ids)
+                                                  prompt_token_ids,
+                                                  min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
 
     if video_inputs is not None:
         prompt_token_ids = _expand_pad_tokens(video_inputs,
                                               hf_config.video_token_id,
-                                              make_batched_videos, "video",
+                                              make_batched_videos,
+                                              "video",
                                               image_processor,
-                                              prompt_token_ids)
+                                              prompt_token_ids,
+                                              min_pixels=min_pixels,
+                                              max_pixels=max_pixels)
 
     return token_inputs(
         prompt_token_ids=prompt_token_ids,
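
Because the mapper simply forwards these overrides to the HuggingFace image processor, the effect of a given min_pixels / max_pixels pair can be previewed outside vLLM. A hedged sketch using transformers directly; the model path comes from the new tests and the output keys match the ones asserted there:

from PIL import Image
from transformers import AutoImageProcessor

# Sketch only: build the HF Qwen2-VL image processor with the same overrides
# the input mapper passes through cached_get_image_processor.
processor = AutoImageProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=64**2,
    max_pixels=512**2,
)
image = Image.new("RGB", (1024, 1024))
out = processor(images=[image], return_tensors="pt")
# Flattened patches; dimension 0 equals the product of image_grid_thw.
print(out["pixel_values"].shape)
print(out["image_grid_thw"])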

0 commit comments