diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index be81c3883340..faac2b97722b 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -254,7 +254,7 @@ Multimodal Language Models
-
* - :code:`QWenLMHeadModel`
- Qwen-VL
- - Image\ :sup:`E`
+ - Image\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`Qwen2VLForConditionalGeneration`
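Per the table legend, :sup:`E` means pre-computed embeddings can be fed for the image modality and :sup:`+` means multiple items can be passed per text prompt, so this row now records Qwen-VL's multi-image support. A minimal offline-inference sketch of what that enables (editor's illustration, not part of the patch; the image URLs are placeholders and the full version is the example script updated below):

    # Illustration only: multi-image prompting with Qwen-VL-Chat.
    # The URLs are placeholders; see the updated example script for details.
    from vllm import LLM, SamplingParams
    from vllm.multimodal.utils import fetch_image

    image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]

    llm = LLM(
        model="Qwen/Qwen-VL-Chat",
        trust_remote_code=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))
    outputs = llm.generate(
        {
            "prompt": f"{placeholders}\nCan you compare these images?",
            "multi_modal_data": {"image": [fetch_image(url) for url in image_urls]},
        },
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)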
diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py
index ed7e886d5780..454872c62837 100644
--- a/examples/offline_inference_vision_language_multi_image.py
+++ b/examples/offline_inference_vision_language_multi_image.py
@@ -19,7 +19,39 @@
]
-def load_phi3v(question, image_urls: List[str]):
+def load_qwenvl_chat(question: str, image_urls: List[str]):
+ model_name = "Qwen/Qwen-VL-Chat"
+ llm = LLM(
+ model=model_name,
+ trust_remote_code=True,
+ max_num_seqs=5,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+ placeholders = "".join(f"Picture {i}:
\n"
+ for i, _ in enumerate(image_urls, start=1))
+
+ # This model does not have a chat_template attribute on its tokenizer,
+ # so we need to explicitly pass it. We use ChatML since it's used in the
+ # generation utils of the model:
+ # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ trust_remote_code=True)
+
+ # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
+ chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
+
+ messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+ prompt = tokenizer.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ chat_template=chat_template)
+
+ stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
+ stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+ return llm, prompt, stop_token_ids, None, chat_template
+
+
+def load_phi3v(question: str, image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
@@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]):
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
- return llm, prompt, stop_token_ids, None
+ return llm, prompt, stop_token_ids, None, None
-def load_internvl(question, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]):
model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
@@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]):
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
- return llm, prompt, stop_token_ids, None
+ return llm, prompt, stop_token_ids, None, None
def load_qwen2_vl(question, image_urls: List[str]):
@@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]):
else:
image_data, _ = process_vision_info(messages)
- return llm, prompt, stop_token_ids, image_data
+ return llm, prompt, stop_token_ids, image_data, None
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"qwen2_vl": load_qwen2_vl,
+ "qwen_vl_chat": load_qwenvl_chat,
}
def run_generate(model, question: str, image_urls: List[str]):
- llm, prompt, stop_token_ids, image_data = model_example_map[model](
+ llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
question, image_urls)
if image_data is None:
image_data = [fetch_image(url) for url in image_urls]
@@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]):
def run_chat(model: str, question: str, image_urls: List[str]):
- llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
+ llm, _, stop_token_ids, _, chat_template = model_example_map[model](
+ question, image_urls)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)
-
- outputs = llm.chat([{
- "role":
- "user",
- "content": [
- {
- "type": "text",
- "text": question,
- },
- *({
- "type": "image_url",
- "image_url": {
- "url": image_url
+ outputs = llm.chat(
+ [{
+ "role":
+ "user",
+ "content": [
+ {
+ "type": "text",
+ "text": question,
},
- } for image_url in image_urls),
- ],
- }],
- sampling_params=sampling_params)
+ *({
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ },
+ } for image_url in image_urls),
+ ],
+ }],
+ sampling_params=sampling_params,
+ chat_template=chat_template,
+ )
for o in outputs:
generated_text = o.outputs[0].text
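The loaders now return a fifth value because Qwen-VL-Chat's tokenizer ships no chat_template of its own; run_chat forwards the ChatML template to llm.chat, while the other loaders simply return None for it. A condensed sketch of that path (editor's illustration; image_urls is a placeholder list):

    # Condensed from the run_chat changes above; not part of the patch.
    from vllm import SamplingParams

    llm, _, stop_token_ids, _, chat_template = load_qwenvl_chat(
        "Can you compare these images?", image_urls)
    outputs = llm.chat(
        [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Can you compare these images?"},
                *({"type": "image_url", "image_url": {"url": url}}
                  for url in image_urls),
            ],
        }],
        sampling_params=SamplingParams(temperature=0.0,
                                       max_tokens=128,
                                       stop_token_ids=stop_token_ids),
        # Without an explicit template, llm.chat has nothing to fall back on
        # for this tokenizer.
        chat_template=chat_template,
    )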
diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py
index 05f5cbf8c343..5e7f1de99d6c 100644
--- a/tests/models/test_qwen.py
+++ b/tests/models/test_qwen.py
@@ -1,11 +1,17 @@
import pathlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type, Union
import pytest
+import torch
+from PIL.Image import Image
-from vllm.multimodal.utils import rescale_image_size
+from vllm.config import ModelConfig
+from vllm.inputs import InputContext, LLMInputs
+from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
+ VllmRunner, _ImageAssets)
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
@@ -23,19 +29,205 @@
"Picture 1:
\nWhat is the season?: ",
})
+HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nCan you compare these images?\n" # noqa: E501
+HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nDescribe the two images in detail.\n" # noqa: E501
+### Multimodal preprocessing tests
+SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
+# These values are specific to Qwen-VL/Chat; they could also be read from the
+# model config, but they are hardcoded here to keep the parametrization and
+# fixtures easy to read.
+IMG_START_ID = 151857
+IMG_END_ID = 151858
+IMG_PAD_ID = 151859
+TOKS_PER_IMG = 256
+VIS_ENC_DIM = 4096
+IMG_SIZE = 448
+
+
+def build_model_context(model_name: str,
+ tokenizer_name: Optional[str] = None,
+ trust_remote_code: bool = False):
+ """Creates an InputContext for a given model.
+
+ Args:
+ model_name: Name of the model being considered.
+ tokenizer_name: Name of the tokenizer being considered.
+ trust_remote_code: Whether or not to allow loading remote code.
+
+ Returns:
+ InputContext for the model being considered.
+ """
+ if tokenizer_name is None:
+ tokenizer_name = model_name
+ model_config = ModelConfig(
+ model_name,
+ tokenizer_name,
+ tokenizer_mode="auto",
+ trust_remote_code=trust_remote_code,
+ dtype="float32",
+ seed=0,
+ )
+ return InputContext(model_config)
+
+
+@pytest.fixture()
+def input_mapper_for_qwen():
+ # Lazy import to avoid initializing CUDA during test collection
+ from vllm.model_executor.models.qwen import input_mapper_for_qwen
+ return input_mapper_for_qwen
+
+
+@pytest.fixture()
+def input_processor_for_qwen():
+ # Lazy import to avoid initializing CUDA during test collection
+ from vllm.model_executor.models.qwen import input_processor_for_qwen
+ return input_processor_for_qwen
+
+
+@pytest.fixture()
+def qwen_vl_context() -> InputContext:
+ """Get an InputContext for Qwen-VL."""
+ return build_model_context(model_name="Qwen/Qwen-VL",
+ trust_remote_code=True)
+
+
+# Happy path tests for single/multi-image scenarios for the multimodal
+# input processor and mapper, respectively
+@pytest.mark.parametrize("num_images", [1, 2])
+def test_input_processor_valid_mm_data(input_processor_for_qwen,
+ qwen_vl_context: InputContext,
+ num_images: int):
+ """Happy cases for image inputs to Qwen's multimodal input processor."""
+ prompt = "".join(
+ [f"Picture {num}:
\n" for num in range(1, num_images + 1)])
+ inputs = LLMInputs(
+ prompt=prompt,
+ # When processing multimodal data for a multimodal model, the qwen
+ # input processor will overwrite the provided prompt_token_ids with
+ # the image prompts
+ prompt_token_ids=None,
+ multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
+ )
+ proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
+ assert isinstance(proc_inputs, dict)
+
+ # Each image should have one image-start / image-end token and a fixed
+ # context of 256 image-pad tokens
+ proc_tokens = proc_inputs["prompt_token_ids"]
+ assert proc_tokens.count(IMG_START_ID) == num_images
+ assert proc_tokens.count(IMG_END_ID) == num_images
+ assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
+
+
+@pytest.mark.parametrize(
+ "img_data,expected_shape",
+ [
+ # single / multi-image
+ (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
+ (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
+ # single / multi-image embeddings
+ (torch.rand(
+ (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+ (torch.rand(
+ (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+ (torch.rand(
+ (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
+ ])
+def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
+ qwen_vl_context: InputContext,
+ img_data: Union[torch.Tensor, List[Image],
+ Image],
+ expected_shape: List[int]):
+ """Happy cases for image inputs to Qwen's multimodal input mapper."""
+ mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
+ # Ensure that we get the appropriately shaped pixel_values
+ # for images and image embeddings, respectively.
+ assert isinstance(mapped_img_data, MultiModalInputs)
+ assert "pixel_values" in mapped_img_data
+ assert mapped_img_data["pixel_values"].shape == expected_shape
+
+
+# Sad path tests for the multimodal input processor and mapper, respectively
+@pytest.mark.parametrize("mm_data", [
+ {
+ "image": torch.rand((5))
+ },
+ {
+ "image": torch.rand((5, 5, 5, 5, 5))
+ },
+])
+def test_input_processor_invalid_mm_data(input_processor_for_qwen,
+ qwen_vl_context: InputContext,
+ mm_data: Dict[str, torch.Tensor]):
+ """Test sad cases validated in Qwen's multimodal input processor."""
+ tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
+ trust_remote_code=True)
+ prompt = "Picture 1:
\n"
+ prompt_token_ids = tokenizer.encode(prompt)
+ inputs = LLMInputs(prompt=prompt,
+ prompt_token_ids=prompt_token_ids,
+ multi_modal_data=mm_data)
+ # Should fail since we have too many or too few dimensions for embeddings
+ with pytest.raises(ValueError):
+ input_processor_for_qwen(qwen_vl_context, inputs)
+
+
+@pytest.mark.parametrize(
+ "img_data",
+ [
+ # Wrong context length
+ torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
+ # Wrong visual encoder output size
+ torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
+ ])
+def test_input_mapper_invalid_mm_data(
+ input_mapper_for_qwen,
+ qwen_vl_context: InputContext,
+ img_data: Union[torch.Tensor, List[Image], Image],
+):
+ """Sad cases validated in Qwen VL's multimodal input mapper."""
+ with pytest.raises(ValueError):
+ input_mapper_for_qwen(qwen_vl_context, img_data)
+
+
+### End-to-end generation tests
+def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str,
+ assets: Union[_ImageAssets, List[ImageAsset]]) -> str:
+ """Given a temporary dir path, export one or more image assets into the
+ tempdir & replace its contents with the local path to the string so that
+ the HF version of Qwen-VL can resolve the path and load the image ni its
+ forward() call.
+
+ Args:
+ tmp_path: Tempdir for test under consideration.
+ prompt: Prompt with image placeholders.
+ assets: List of image assets whose len equals the num placeholders.
+ """
+ # Ensure that the number of placeholders matches the number of assets;
+ # If this is not true, the test is probably written incorrectly.
+ assert prompt.count("
") == len(assets)
+
+ # Replace the placeholders with local paths to the exported assets
+ for asset in assets:
+ image_tmp_path = tmp_path / f"{asset.name}.jpg"
+ asset.pil_image.save(image_tmp_path)
+ prompt = prompt.replace(
+ "
",
+ f"
{image_tmp_path}",
+ 1,
+ )
+ return prompt
+
-### Tests for multimodal Qwen models
def run_test(
- tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
- image_assets: _ImageAssets,
+ inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
- size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
+ mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
@@ -48,23 +240,6 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
- images = [asset.pil_image for asset in image_assets]
-
- # Export the images to a tempdir and substitute it into the hf prompt;
- # the contents between <img>/</img> will be ignored by VLLM, but the
- # transformers implementation for the visual transformer parses this to
- # reload it in the forward call; the contents are treated as a URL or a
- # local path.
- for idx, asset in enumerate(image_assets):
- image_tmp_path = tmp_path / f"{asset.name}.jpg"
- asset.pil_image.save(image_tmp_path)
- HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
- "
", f"
{image_tmp_path}")
-
- inputs_per_image = [(
- [prompt for _ in size_factors],
- [rescale_image_size(image, factor) for factor in size_factors],
- ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
@@ -72,11 +247,12 @@ def run_test(
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
- # Qwen encodes images into a fixed content size of 256
+ # Qwen encodes each image into a fixed content size of 256
with vllm_runner(model,
- max_model_len=300,
+ max_model_len=1024,
max_num_seqs=1,
dtype=dtype,
+ limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
@@ -85,7 +261,7 @@ def run_test(
max_tokens,
num_logprobs=num_logprobs,
images=images)
- for prompts, images in inputs_per_image
+ for prompts, images in inputs
]
with hf_runner(model, dtype=dtype) as hf_model:
@@ -94,7 +270,7 @@ def run_test(
max_tokens,
num_logprobs=num_logprobs,
images=images)
- for prompts, images in inputs_per_image
+ for prompts, images in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -125,19 +301,81 @@ def run_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
-def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
- model, size_factors, dtype, max_tokens,
- num_logprobs) -> None:
+def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath,
+ hf_runner: Type[HfRunner],
+ vllm_runner: Type[VllmRunner],
+ image_assets: _ImageAssets, model: str,
+ size_factors: List[float], dtype: str,
+ max_tokens: int,
+ num_logprobs: int) -> None:
+ """Tests multimodal models with single image prompts."""
+ images = [asset.pil_image for asset in image_assets]
+
+ prompts = [
+ get_prompt_with_path(tmp_path, prompt, [asset])
+ for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets)
+ ]
+
+ inputs = [(
+ [prompt for _ in size_factors],
+ [rescale_image_size(image, factor) for factor in size_factors],
+ ) for image, prompt in zip(images, prompts)]
+
+ run_test(
+ hf_runner,
+ vllm_runner,
+ inputs,
+ model,
+ dtype=dtype,
+ max_tokens=max_tokens,
+ num_logprobs=num_logprobs,
+ mm_limit=1,
+ tensor_parallel_size=1,
+ )
+
+
+@pytest.mark.parametrize("model", multimodal_models)
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # No image
+ [],
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath,
+ hf_runner: Type[HfRunner],
+ vllm_runner: Type[VllmRunner],
+ image_assets: _ImageAssets, model: str,
+ size_factors: List[float], dtype: str,
+ max_tokens: int,
+ num_logprobs: int) -> None:
+ """Tests multimodal models with multi-image prompts."""
+ images = [asset.pil_image for asset in image_assets]
+ # Put all of the images into one prompt.
+ prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT,
+ image_assets)
+ inputs = [([prompt for _ in size_factors],
+ [[rescale_image_size(image, factor) for image in images]
+ for factor in size_factors])]
+
run_test(
- tmp_path,
hf_runner,
vllm_runner,
- image_assets,
+ inputs,
model,
- size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
+ mm_limit=2,
tensor_parallel_size=1,
)
@@ -150,7 +388,7 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
@pytest.mark.parametrize("num_logprobs", [5])
def test_text_only_qwen_model_can_be_loaded_and_run(
vllm_runner: Type[VllmRunner],
- example_prompts,
+ example_prompts: List[str],
model: str,
*,
dtype: str,
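The test refactor leans on the <img></img> placeholder convention: vLLM ignores whatever sits between the tags, while the HF Qwen-VL implementation parses it as a URL or local path and reloads the image in its forward() call. A standalone sketch of the string transformation get_prompt_with_path performs (editor's illustration; the path is hypothetical):

    # Illustration only; the path is hypothetical and the prompt mirrors
    # HF_IMAGE_PROMPTS above.
    import pathlib

    prompt = "Picture 1: <img></img>\nWhat is the season?: "
    image_tmp_path = pathlib.Path("/tmp/qwen_test/stop_sign.jpg")  # hypothetical
    hf_prompt = prompt.replace("<img></img>", f"<img>{image_tmp_path}</img>", 1)
    print(hf_prompt)
    # Picture 1: <img>/tmp/qwen_test/stop_sign.jpg</img>
    # What is the season?: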
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index a726ec10984c..18bc6b303f48 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -47,6 +47,7 @@
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
+from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
@@ -684,9 +685,12 @@ def input_processor_for_qwen(ctx: InputContext,
raise ValueError(
f"Expected img embeds to be have 3 dimensions, got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0]
- else:
- # TODO - handle multiple image inputs once the API is solidified
+ elif isinstance(image_data, Image.Image):
num_images = 1
+ elif is_list_of(image_data, Image.Image):
+ num_images = len(image_data)
+ else:
+ raise TypeError(f"Invalid image type: {type(image_data)}")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
@@ -767,11 +771,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]")
pixel_values = data
-
else:
transform = build_normalization_transform(image_size)
- # TODO - handle multiple image inputs once the API is solidified
- transformed_images = [transform(data)]
+ if not isinstance(data, (list, tuple)):
+ data = [data]
+ transformed_images = [transform(datum) for datum in data]
pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalInputs({"pixel_values": pixel_values})
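For reference, the new multi-image branch in input_mapper_for_qwen amounts to running every image through the same normalization transform and stacking the results into one pixel_values batch. A self-contained sketch under the assumption that the transform matches Qwen-VL's CLIP-style preprocessing; the mean/std constants below are the standard CLIP values and are not copied from qwen.py:

    # Editor's sketch of the multi-image mapping path; shapes follow the
    # constants used in tests/models/test_qwen.py.
    import torch
    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    IMG_SIZE = 448
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)  # assumed CLIP defaults
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)  # assumed CLIP defaults

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])

    data = [Image.new("RGB", (640, 480)), Image.new("RGB", (800, 600))]
    if not isinstance(data, (list, tuple)):  # same dispatch as the mapper
        data = [data]
    pixel_values = torch.stack([transform(datum) for datum in data], dim=0)
    assert pixel_values.shape == (2, 3, IMG_SIZE, IMG_SIZE)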