Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 165 additions & 10 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, TypeVar, Union)
TypedDict, TypeVar, Union, cast)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -35,6 +35,7 @@
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
Expand All @@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""


class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
Expand All @@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]

Expand Down Expand Up @@ -317,14 +356,40 @@ def _call_hf_processor(
for p, (h, w) in zip(pixel_values, image_sizes)
]

hf_config = self.info.get_hf_config()

tile_sizes = [
get_pixtral_hf_image_feature_grid_size(
hf_config.vision_config,
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2])
for pixel_value in processed_outputs["pixel_values"]
]
num_crops = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks.
embed_is_patch = torch.tensor(
flatten_2d_lists([([True] * ncols + [False]) * nrows
for ncols, nrows in tile_sizes]))
processed_outputs["num_crops"] = num_crops
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch

return processed_outputs

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
Expand Down Expand Up @@ -562,6 +627,23 @@ def _parse_and_validate_image_input(
if pixel_values is None and image_embeds is None:
return None

feat_is_patch = kwargs.pop("feat_is_patch", None)
if feat_is_patch is not None and not isinstance(
feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")

embed_is_patch = kwargs.pop("embed_is_patch", None)
if embed_is_patch is not None and not isinstance(
embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")

num_crops = kwargs.pop("num_crops", None)
if num_crops is not None and not isinstance(num_crops, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
Expand All @@ -571,12 +653,18 @@ def _parse_and_validate_image_input(
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

if image_embeds is not None:
Expand All @@ -587,6 +675,9 @@ def _parse_and_validate_image_input(
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

raise AssertionError("This line should be unreachable.")
Expand Down Expand Up @@ -633,16 +724,74 @@ def _process_image_input(self,

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)

def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
if isinstance(image_features, torch.Tensor):
return self.multi_modal_projector(image_features)

feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]

image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.

Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""

# Insert columns of nan values according to `feat_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1)
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features

num_crops_per_image = num_crops.tolist()
feats_per_image = expanded_embedding.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

embed_dim = expanded_embedding.shape[-1]
num_embeds = embed_is_patch.shape[0]

embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)

return embeds_in_batch

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
if kwargs.get("v0_path", False):
return vision_embeddings
else:
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)

def get_input_embeddings(
self,
Expand All @@ -651,8 +800,15 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
input_ids, inputs_embeds, cast(NestedTensors,
patch_embeddings),
self.config.image_token_index)
return inputs_embeds

Expand Down Expand Up @@ -705,6 +861,7 @@ def forward(
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
Expand Down Expand Up @@ -767,7 +924,6 @@ def apply(
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
Expand All @@ -778,8 +934,7 @@ def apply(
image_height=-1,
)

result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,8 +1484,8 @@ def _parse_and_validate_image_input(

img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
raise ValueError("Incorrect type of img_patch_id. "
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()

return MolmoImageInputs(
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,9 +1042,13 @@ def forward(
for img in pixel_values
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)

# positional embeddings
Expand Down Expand Up @@ -1075,6 +1079,8 @@ def forward(
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)

# squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes)
return out

# (TODO) Add prefix argument for filtering out weights to be loaded
Expand Down