
Commit 9da25a8

[MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) (#8029)
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 8685ba1 commit 9da25a8

8 files changed, 1110 insertions(+), 208 deletions(-)

docs/source/models/supported_models.rst

Lines changed: 5 additions & 0 deletions
@@ -242,6 +242,11 @@ Multimodal Language Models
     - Image\ :sup:`+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
+  * - :code:`QWenLMHeadModel`
+    - Qwen
+    - Image
+    - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
+    -
   * - :code:`UltravoxModel`
     - Ultravox
     - Audio\ :sup:`E+`
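As a quick orientation for the new table row, a minimal sketch of loading one of the listed checkpoints with vLLM's offline LLM entrypoint; the model name comes from the table, everything else here is illustrative. trust_remote_code=True is needed because the Qwen-VL repositories ship custom tokenizer and visual code on the Hugging Face Hub.

from vllm import LLM

# Minimal sketch: load the newly documented checkpoint for offline inference.
llm = LLM(model="Qwen/Qwen-VL-Chat", trust_remote_code=True)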

examples/offline_inference_vision_language.py

Lines changed: 15 additions & 0 deletions
@@ -159,6 +159,20 @@ def run_blip2(question):
     return llm, prompt, stop_token_ids
 
 
+# Qwen
+def run_qwen_vl(question):
+
+    llm = LLM(
+        model="Qwen/Qwen-VL",
+        trust_remote_code=True,
+        max_num_seqs=5,
+    )
+
+    prompt = f"{question}Picture 1: <img></img>\n"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -169,6 +183,7 @@ def run_blip2(question):
     "minicpmv": run_minicpmv,
     "blip-2": run_blip2,
     "internvl_chat": run_internvl,
+    "qwen_vl": run_qwen_vl,
 }
 
 
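For context, a self-contained sketch (not part of the diff) of how the prompt format and LLM settings from run_qwen_vl above would be used end to end; the question text, image path, and sampling settings are illustrative assumptions.

from PIL import Image

from vllm import LLM, SamplingParams

# Mirrors run_qwen_vl above: the question precedes the "Picture 1" tag, and
# the <img></img> pair marks where the image features are inserted.
question = "What is in this picture?"
llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True, max_num_seqs=5)
prompt = f"{question}Picture 1: <img></img>\n"

image = Image.open("example.jpg")  # hypothetical local image path
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)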
tests/models/test_qwen.py

Lines changed: 142 additions & 25 deletions
@@ -1,48 +1,165 @@
-from typing import Type
+import pathlib
+from typing import List, Optional, Type
 
 import pytest
 
-from ..conftest import HfRunner, VllmRunner
+from vllm.multimodal.utils import rescale_image_size
+
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_logprobs_close
 
-models = ["qwen/qwen-vl"]
+pytestmark = pytest.mark.vlm
 
+text_only_models = [
+    "Qwen/Qwen-7B-Chat"  # Has no visual component
+]
 
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("model", models)
-def test_text_only_qwen_model(
+multimodal_models = ["Qwen/Qwen-VL"]
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "Picture 1: <img></img>\nWhat's the content of the image?: ",
+    "cherry_blossom":
+    "Picture 1: <img></img>\nWhat is the season?: ",
+})
+
+
+### Tests for multimodal Qwen models
+def run_test(
+    tmp_path: pathlib.PosixPath,
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
-    example_prompts,
+    image_assets: _ImageAssets,
     model: str,
     *,
+    size_factors: List[float],
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
 ):
-    # This test checks language inputs only, since the visual component
-    # for qwen-vl is still unsupported in VLLM. In the near-future, the
-    # implementation and this test will be extended to consider
-    # visual inputs as well.
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding MultiModalConfig as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    # Export the images to a tempdir and substitute it into the hf prompt;
+    # the contents between <img>/</img> will be ignored by VLLM, but the
+    # transformers implementation for the visual transformer parses this to
+    # reload it in the forward call; the contents are treated as a URL or a
+    # local path.
+    for idx, asset in enumerate(image_assets):
+        image_tmp_path = tmp_path / f"{asset.name}.jpg"
+        asset.pil_image.save(image_tmp_path)
+        HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
+            "<img></img>", f"<img>{image_tmp_path}</img>")
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    # Qwen encodes images into a fixed content size of 256
+    with vllm_runner(model,
+                     max_model_len=300,
+                     max_num_seqs=1,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
     with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy_logprobs_limit(
-            example_prompts,
-            max_tokens,
-            num_logprobs=num_logprobs,
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
         )
 
+
+@pytest.mark.parametrize("model", multimodal_models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
+                           model, size_factors, dtype, max_tokens,
+                           num_logprobs) -> None:
+    run_test(
+        tmp_path,
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
+# Ensure that a text-only Qwen model can still be loaded and
+# used for inference in VLLM without throwing.
+@pytest.mark.parametrize("model", text_only_models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_text_only_qwen_model_can_be_loaded_and_run(
+    vllm_runner: Type[VllmRunner],
+    example_prompts,
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+):
     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_model.generate_greedy_logprobs(
             example_prompts,
             max_tokens,
             num_logprobs=num_logprobs,
         )
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
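A small illustration (not from the diff) of what rescale_image_size contributes to the size_factors parametrization above: it resizes a PIL image by the given factor, so the multi-scale case runs the same prompt against several resolutions of the same image. The image dimensions below are placeholders.

from PIL import Image

from vllm.multimodal.utils import rescale_image_size

image = Image.new("RGB", (640, 480))  # placeholder image
for factor in [0.25, 0.5, 1.0]:
    scaled = rescale_image_size(image, factor)
    print(factor, scaled.size)  # expected (160, 120), (320, 240), (640, 480)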

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ def _placeholder_str(self, modality: ModalityStr,
         if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
             # These models do not use image tokens in the prompt
             return None
+        if model_type == "qwen":
+            return f"Picture {current_count}: <img></img>"
         if model_type.startswith("llava"):
             return self._cached_token_str(self._tokenizer,
                                           hf_config.image_token_index)
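Illustrative only: the new branch means that, for qwen-type models, the chat entrypoint represents the n-th image of a request with the placeholder string below, while the pixel data is passed separately as multi_modal_data. A tiny standalone sketch of the string it produces:

def qwen_image_placeholder(current_count: int) -> str:
    # Mirrors the string returned by the new "qwen" branch in _placeholder_str.
    return f"Picture {current_count}: <img></img>"

print(qwen_image_placeholder(1))  # Picture 1: <img></img>
print(qwen_image_placeholder(2))  # Picture 2: <img></img>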
