11 changes: 8 additions & 3 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -237,13 +237,18 @@
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=identity,
img_idx_to_prompt=lambda idx: "",
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<|begin_of_image|><|endoftext|><|end_of_image|>What is the season?", # noqa: E501
}),
max_model_len=2048,
max_num_seqs=2,
dtype="bfloat16",
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
patch_hf_runner=model_utils.glm_patch_hf_runner,
patch_hf_runner=model_utils.glm4v_patch_hf_runner,
max_tokens=8,
num_logprobs=10,
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
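Note: a quick sketch (not part of the diff) of how the new prompt_formatter and single_image_prompts combine into the final GLM-4V prompt; the expected output is shown in comments:

    prompt_formatter = lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>"
    image_prompt = ("<|begin_of_image|><|endoftext|><|end_of_image|>"
                    "What's the content in the center of the image?")
    print(prompt_formatter(image_prompt))
    # <|user|>
    # <|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?<|assistant|>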
@@ -61,7 +61,9 @@ def run_test(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).

-    vllm_runner_kwargs_: dict[str, Any] = {}
+    vllm_runner_kwargs_: dict[str, Any] = {
+        "disable_mm_preprocessor_cache": True,
+    }
     if model_info.tokenizer:
         vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
     if model_info.tokenizer_mode:
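Note: the runner defaults above now disable the multimodal preprocessor cache for every test model, with per-model fields layered on top. A standalone sketch of that merge pattern (build_kwargs and this ModelInfo are illustrative stand-ins, not the suite's real types):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ModelInfo:
        # Hypothetical stand-in for the suite's per-model metadata.
        tokenizer: Optional[str] = None
        tokenizer_mode: Optional[str] = None

    def build_kwargs(model_info: ModelInfo) -> dict[str, Any]:
        # Defaults applied to every vLLM runner in the suite.
        kwargs: dict[str, Any] = {"disable_mm_preprocessor_cache": True}
        # Per-model overrides win over the defaults when declared.
        if model_info.tokenizer:
            kwargs["tokenizer"] = model_info.tokenizer
        if model_info.tokenizer_mode:
            kwargs["tokenizer_mode"] = model_info.tokenizer_mode
        return kwargs

    assert build_kwargs(ModelInfo())["disable_mm_preprocessor_cache"] is True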
@@ -304,21 +304,29 @@ def processor(*args, text="", images=None, **kwargs):
     return hf_model


-def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
-    """Patches and returns an instance of the HfRunner to use for GLM4."""
+def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for GLM4V."""
     hf_processor = hf_model.processor
     patch_padding_side(hf_processor)

     def processor(*args, text="", images=None, **kwargs):
         if images is None:
             return hf_processor(*args, **kwargs)

+        images = [images] if isinstance(images, Image) else images
+
+        contents = re.findall(
+            r"<\|begin_of_image\|><\|endoftext\|><\|end_of_image\|>(.*?)<\|assistant\|>",
+            text,
+        )
+        assert len(contents) == len(images)
+
         return hf_processor.apply_chat_template(
             [{
                 "role": "user",
-                "image": images,
-                "content": text
-            }],
+                "image": image,
+                "content": content
+            } for image, content in zip(images, contents)],
             add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
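Note: the new regex splits the already-formatted prompt back into per-image questions so that each becomes its own chat message. A self-contained example of that extraction:

    import re

    text = ("<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
            "What is the season?<|assistant|>")
    contents = re.findall(
        r"<\|begin_of_image\|><\|endoftext\|><\|end_of_image\|>(.*?)<\|assistant\|>",
        text,
    )
    print(contents)  # ['What is the season?']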
12 changes: 8 additions & 4 deletions vllm/config.py
@@ -286,14 +286,18 @@ def __init__(
         if rope_scaling is not None:
             hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
             hf_overrides_kw.update(hf_override)
-            msg = ("`--rope-scaling` will be removed in a future release. "
-                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
+            hf_overrides_str = json.dumps(hf_overrides)
+            msg = (
+                "`--rope-scaling` will be removed in a future release. "
+                f"Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
         if rope_theta is not None:
             hf_override = {"rope_theta": rope_theta}
             hf_overrides_kw.update(hf_override)
-            msg = ("`--rope-theta` will be removed in a future release. "
-                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
+            hf_overrides_str = json.dumps(hf_overrides)
+            msg = (
+                "`--rope-theta` will be removed in a future release. "
+                f"Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)

         self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
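Note: the move from !r to json.dumps matters because a dict's repr uses single quotes, which the JSON parser behind `--hf-overrides` rejects, while json.dumps yields a string the user can paste verbatim. A small illustration (the rope_scaling value is made up):

    import json

    hf_override = {"rope_scaling": {"rope_type": "dynamic", "factor": 2.0}}
    print(f"{hf_override!r}")       # {'rope_scaling': {'rope_type': 'dynamic', 'factor': 2.0}}
    print(json.dumps(hf_override))  # {"rope_scaling": {"rope_type": "dynamic", "factor": 2.0}}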
7 changes: 4 additions & 3 deletions vllm/entrypoints/chat_utils.py
@@ -403,16 +403,17 @@ def _placeholder_str(self, modality: ModalityStr,
         hf_config = self._model_config.hf_config
         model_type = hf_config.model_type

-        if modality in ["image", "image_embeds"]:
+        if modality in ("image", "image_embeds"):
+            if model_type == "chatglm":
+                return "<|begin_of_image|><|endoftext|><|end_of_image|>"
             if model_type == "phi3_v":
                 # Workaround since this token is not defined in the tokenizer
                 return f"<|image_{current_count}|>"
             if model_type == "phi4mm":
                 return "<|endoftext10|>"  # 200010 (see vocab.json in hf model)
             if model_type in ("minicpmo", "minicpmv"):
                 return "(<image>./</image>)"
-            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
-                              "pixtral"):
+            if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"):
                 # These models do not use image tokens in the prompt
                 return None
             if model_type == "qwen":
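Note: with the chatglm branch added above, prompt construction for GLM-4V now injects explicit image tokens instead of falling through to the token-less branch. A reduced, self-contained sketch of the lookup (resolve_placeholder is a hypothetical free function; the real logic is the _placeholder_str method above):

    from typing import Optional

    def resolve_placeholder(model_type: str, modality: str) -> Optional[str]:
        # Reduced from _placeholder_str; only the branches relevant here.
        if modality in ("image", "image_embeds"):
            if model_type == "chatglm":
                return "<|begin_of_image|><|endoftext|><|end_of_image|>"
            if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"):
                return None  # no image tokens in the prompt for these models
        return None

    assert resolve_placeholder("chatglm", "image") is not None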
3 changes: 2 additions & 1 deletion vllm/model_executor/models/chatglm.py
@@ -2,6 +2,7 @@
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
+import json
 from typing import Iterable, Optional, Set, Tuple, Union

 import torch
@@ -463,7 +464,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"The configuration of this model indicates that it supports "
"vision inputs, but you instantiated the text-only version "
"of this model. Please use the vision model by setting "
f"`--hf-overrides {hf_overrides!r}`")
f"`--hf-overrides '{json.dumps(hf_overrides)}'`")

super().__init__(vllm_config=vllm_config, prefix=prefix)

4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen.py
@@ -5,7 +5,7 @@
 # Copyright (c) Alibaba Cloud.
 # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
 """Inference-only QWen model compatible with HuggingFace weights."""
-
+import json
 from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

 import torch
@@ -354,7 +354,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"The configuration of this model indicates that it supports "
"vision inputs, but you instantiated the text-only version "
"of this model. Please use the vision model by setting "
f"`--hf-overrides {hf_overrides!r}`")
f"`--hf-overrides '{json.dumps(hf_overrides)}'`")

super().__init__(vllm_config=vllm_config, prefix=prefix)
