Skip to content
266 changes: 195 additions & 71 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
from torch import nn
from transformers import PaliGemmaConfig
from transformers import BatchFeature, PaliGemmaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
BoundPromptReplacement,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptReplacementDetails,
decode_tokens, encode_tokens,
find_text_matches, find_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

Expand All @@ -47,95 +58,208 @@
PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
class PaliGemmaMultiModalProjector(nn.Module):

return get_max_siglip_image_tokens(vision_config)
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()

self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear(image_features)
return hidden_states

seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)

mm_data = dummy_image_for_siglip(vision_config, num_images)
return DummyData(seq_data, mm_data, ranges)
class PaliGemmaProcessingInfo(BaseProcessingInfo):

def get_hf_config(self):
    """Fetch the HF config from the processing context.

    The lookup is delegated to ``self.ctx.get_hf_config`` with
    ``PaliGemmaConfig`` passed as the requested config class.
    """
    ctx = self.ctx
    return ctx.get_hf_config(PaliGemmaConfig)

def input_processor_for_paligemma(ctx: InputContext,
inputs: DecoderOnlyInputs):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality item limits: a single ``"image"`` entry of 1."""
    limits = dict(image=1)
    return limits

"""
The correct prompt format needs to be:
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Return the token count reported for one image item.

    Args:
        seq_len: Not consulted by this implementation.
        mm_counts: Not consulted by this implementation.

    Returns:
        A mapping with one key, ``"image"``, whose value is whatever
        ``self.get_num_image_tokens()`` returns.
    """
    tokens_per_image = self.get_num_image_tokens()
    return dict(image=tokens_per_image)

multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
def get_num_image_tokens(self) -> int:
    """Return the image token count derived from the vision config.

    The value is computed by ``get_max_siglip_image_tokens`` from the
    ``vision_config`` attribute of the model's HF config.
    """
    return get_max_siglip_image_tokens(self.get_hf_config().vision_config)

model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig)

tokenizer = cached_tokenizer_from_config(model_config)
image_feature_size = hf_config.text_config.num_image_tokens
image_token_str = tokenizer.decode(hf_config.image_token_index)
bos_token = tokenizer.decode(hf_config.bos_token_id)
image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
class PaliGemmaDummyInputsBuilder(
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):

orig_prompt = inputs.get("prompt")
orig_prompt_ids = inputs.get("prompt_token_ids")
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    """Build dummy processor inputs: an empty prompt plus square images.

    Args:
        seq_len: Not consulted by this implementation.
        mm_counts: Per-modality item counts; only ``"image"`` is read,
            defaulting to 0 when absent.

    Returns:
        ``ProcessorInputs`` with ``prompt_text=""`` and an ``"image"`` entry
        produced by ``self._get_dummy_images``.
    """
    # Width and height both come from the vision config's `image_size`.
    side = self.info.get_hf_config().vision_config.image_size

    dummy_images = self._get_dummy_images(
        width=side,
        height=side,
        num_images=mm_counts.get("image", 0),
    )

    return ProcessorInputs(prompt_text="", mm_data={"image": dummy_images})

if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
"The image token '%s' was detected in the prompt and "
"will be removed. Please follow the proper prompt format"
" documented on HuggingFace.", image_token_str)
orig_prompt = orig_prompt.replace(image_token_str, "")
orig_prompt_ids.remove(hf_config.image_token_index)

new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
class PaliGemmaMultiModalProcessor(
BaseMultiModalProcessor[PaliGemmaProcessingInfo]):

# The PaliGemma 2 tokenizer does not include a starting BOS token
if orig_prompt_ids[0] != hf_config.bos_token_id:
orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor, with a tokenizer-only path for text prompts.

    Args:
        prompt: The prompt text.
        mm_data: Multi-modal data for the HF processor; when empty, the HF
            processor is bypassed entirely.
        mm_kwargs: Extra keyword arguments forwarded to the HF processor.

    Returns:
        A ``BatchFeature``. On the text-only path it holds just
        ``input_ids`` (a batch of one), guaranteed to start with BOS.
    """
    if not mm_data:
        # HF processor always adds placeholders even when there's no image
        tokenizer = self.info.get_tokenizer()
        prompt_ids = tokenizer.encode(prompt)

        # Take BOS from the tokenizer instead of hard-coding its id, so this
        # path agrees with the BOS id that the prompt replacement targets.
        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        # PaliGemma 2 does not prepend <bos> itself; add it so that the
        # BOS-anchored prompt replacement can still find its target.
        if not prompt_ids or prompt_ids[0] != bos_token_id:
            prompt_ids = [bos_token_id, *prompt_ids]

        return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

    return super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
    )

new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Declare ``pixel_values`` as a batched field of the image modality.

    Neither argument is consulted; the configuration is fixed.
    """
    pixel_values_field = MultiModalFieldConfig.batched("image")
    return {"pixel_values": pixel_values_field}

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
def _get_prompt_replacements(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
    """Describe how the image placeholder run is spliced into the prompt.

    A single replacement is returned for the ``"image"`` modality. Its
    target is the BOS token, and its full replacement is the run of image
    tokens followed by BOS; only the image tokens are marked as features.
    None of the three arguments is consulted.
    """
    info = self.info
    placeholder_id = info.get_hf_config().image_token_index

    tokenizer = info.get_tokenizer()
    placeholder_run = [placeholder_id] * info.get_num_image_tokens()

    bos_id = tokenizer.bos_token_id
    assert isinstance(bos_id, int)

    image_repl = PromptReplacement(
        modality="image",
        target=[bos_id],
        replacement=PromptReplacementDetails(
            full=[*placeholder_run, bos_id],
            features=placeholder_run,
        ),
    )
    return [image_repl]

def _apply_prompt_replacements(
    self,
    token_ids: list[int],
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
    """Apply prompt replacements, then force a trailing newline.

    Args:
        token_ids: Prompt token ids to rewrite.
        mm_prompt_repls: Bound prompt replacements, keyed by modality.
        mm_item_counts: Number of multi-modal items, keyed by modality.

    Returns:
        A tuple of the rewritten token ids, the matching prompt text, and
        the placeholder info found in the rewritten token ids.
    """
    tokenizer = self.info.get_tokenizer()

    mm_token_matches = {
        modality: find_token_matches(token_ids, prompt_repls)
        for modality, prompt_repls in mm_prompt_repls.items()
    }
    mm_match_counts = {
        modality: len(matches)
        for modality, matches in mm_token_matches.items()
    }

    # If the search text does not represent a special token,
    # it may have different token IDs in the prompt, because
    # the tokens may go across the boundaries of the search text.
    # ----
    # e.g. when searching for "foo" in "food", if "food" itself makes
    # up a token, then the token ID of "foo" will not appear at all
    # ----
    # Since it is inefficient to search for all possible tokenizations
    # of the search text in the prompt, we instead perform string
    # replacement on the decoded token IDs, then encode them back.
    if all(
        mm_match_counts.get(modality, 0) >= item_count
        for modality, item_count in mm_item_counts.items()
    ):  # yapf: disable
        # Enough token-level matches for every item: replace on token ids.
        token_ids = replace_token_matches(
            token_ids,
            mm_token_matches,
            mm_item_counts,
        )

        text = decode_tokens(tokenizer, token_ids)
        matched_repls = {
            modality: [match.prompt_repl for match in token_matches]
            for modality, token_matches in mm_token_matches.items()
        }
    else:
        # Fall back to replacing on the decoded text, then re-encode.
        text = decode_tokens(tokenizer, token_ids)

        mm_text_matches = {
            modality: find_text_matches(text, prompt_repls)
            for modality, prompt_repls in mm_prompt_repls.items()
        }
        text = replace_text_matches(
            text,
            mm_text_matches,
            mm_item_counts,
        )

        token_ids = encode_tokens(tokenizer,
                                  text,
                                  add_special_tokens=False)
        matched_repls = {
            modality: [match.prompt_repl for match in text_matches]
            for modality, text_matches in mm_text_matches.items()
        }

    placeholders = self._find_mm_placeholders(
        matched_repls,
        token_ids,
        mm_item_counts,
    )

    # Force a newline at the end of the prompt due to PaliGemma's format.
    # The id is looked up from the tokenizer rather than hard-coded, so the
    # appended token always corresponds to the "\n" appended to the text.
    newline = "\n"
    newline_token_id = encode_tokens(tokenizer,
                                     newline,
                                     add_special_tokens=False)[-1]
    if token_ids and token_ids[-1] != newline_token_id:
        token_ids.append(newline_token_id)
        text += newline

    return token_ids, text, placeholders


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
@MULTIMODAL_REGISTRY.register_processor(
PaliGemmaMultiModalProcessor,
info=PaliGemmaProcessingInfo,
dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
Expand Down
Loading