diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index fc363585b0e7..6cae2971ab26 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I+
- * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc.
+ * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
*
* ✅︎
* ✅︎
@@ -930,10 +930,6 @@ For more details, please see:
Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release.
:::
-:::{note}
-`mistral-community/pixtral-12b` does not support V1 yet.
-:::
-
:::{note}
To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 66b79f809bc9..e83dfd320bb6 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
- TypedDict, TypeVar, Union)
+ TypedDict, TypeVar, Union, cast)
import torch
import torch.nn as nn
@@ -35,6 +35,7 @@
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
+ feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+ """
+ A boolean mask indicating which image features correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_crops, num_patch)`
+ """
+
+ embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_embeds)`
+ """
+
+ num_crops: torch.Tensor
+ """Shape: `(batch_size, num_images)`"""
+
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
+ feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+ """
+ A boolean mask indicating which image features correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_crops, num_patch)`
+ """
+
+ embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+ """
+ A boolean mask indicating which image embeddings correspond
+ to patch tokens.
+
+ Shape: `(batch_size, num_embeds)`
+ """
+
+ num_crops: torch.Tensor
+ """Shape: `(batch_size, num_images)`"""
+
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
@@ -317,6 +356,26 @@ def _call_hf_processor(
for p, (h, w) in zip(pixel_values, image_sizes)
]
+ hf_config = self.info.get_hf_config()
+
+ tile_sizes = [
+ get_pixtral_hf_image_feature_grid_size(
+ hf_config.vision_config,
+ image_width=pixel_value.shape[-1],
+ image_height=pixel_value.shape[-2])
+ for pixel_value in processed_outputs["pixel_values"]
+ ]
+ num_crops = torch.tensor([(ncols + 1) * nrows
+ for ncols, nrows in tile_sizes])
+ # Each image may result to masks of different sizes, so we need to
+ # flatten the list and later use `num_crops` to get per-image masks.
+ embed_is_patch = torch.tensor(
+ flatten_2d_lists([([True] * ncols + [False]) * nrows
+ for ncols, nrows in tile_sizes]))
+ processed_outputs["num_crops"] = num_crops
+ processed_outputs["embed_is_patch"] = embed_is_patch
+ processed_outputs["feat_is_patch"] = embed_is_patch
+
return processed_outputs
def _get_mm_fields_config(
@@ -324,7 +383,13 @@ def _get_mm_fields_config(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
+ num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict(
+ feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_crops),
+ embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_crops),
+ num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -562,6 +627,23 @@ def _parse_and_validate_image_input(
if pixel_values is None and image_embeds is None:
return None
+ feat_is_patch = kwargs.pop("feat_is_patch", None)
+ if feat_is_patch is not None and not isinstance(
+ feat_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of feat_is_patch. "
+ f"Got type: {type(feat_is_patch)}")
+
+ embed_is_patch = kwargs.pop("embed_is_patch", None)
+ if embed_is_patch is not None and not isinstance(
+ embed_is_patch, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of embed_is_patch. "
+ f"Got type: {type(embed_is_patch)}")
+
+ num_crops = kwargs.pop("num_crops", None)
+ if num_crops is not None and not isinstance(num_crops, torch.Tensor):
+ raise ValueError("Incorrect type of num_crops. "
+ f"Got type: {type(num_crops)}")
+
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@ def _parse_and_validate_image_input(
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
+ feat_is_patch=feat_is_patch,
+ embed_is_patch=embed_is_patch,
+ num_crops=num_crops,
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
+ feat_is_patch=feat_is_patch,
+ embed_is_patch=embed_is_patch,
+ num_crops=num_crops,
)
if image_embeds is not None:
@@ -587,6 +675,9 @@ def _parse_and_validate_image_input(
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
+ feat_is_patch=feat_is_patch,
+ embed_is_patch=embed_is_patch,
+ num_crops=num_crops,
)
raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ def _process_image_input(self,
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
- return self.multi_modal_projector(image_features)
- def get_multimodal_embeddings(
- self, **kwargs
- ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+ if isinstance(image_features, torch.Tensor):
+ return self.multi_modal_projector(image_features)
+
+ feature_sizes = [
+ image_feature.shape[0] for image_feature in image_features
+ ]
+
+ image_embeds = self.multi_modal_projector(torch.cat(image_features))
+ image_embeds = torch.split(image_embeds, feature_sizes)
+ return image_embeds
+
+ def _get_mm_embeds(
+ self,
+ features: torch.Tensor, # Shape: (num_crop, num_patch, d)
+ feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
+ num_crops: torch.Tensor, # Shape: (num_images,)
+ embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
+ ) -> list[torch.Tensor]:
+ """Scatter the patch features into a contiguous tensor that corresponds
+ to the embedding tokens defined by the multimodal processor.
+
+ Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
+ """
+
+ # Insert columns of nan values according to `feat_is_patch`. This work
+ # ideally should be done in `_process_image_input`, but
+ # `_process_image_input` is used in both V0 and V1 path. It's safer to
+ # put the logic here.
+ # FIXME: Move this logic to `_process_image_input` when v0 is
+ # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+ feat_is_patch = feat_is_patch.view(-1)
+ embed_is_patch = embed_is_patch.view(-1)
+ expanded_embedding = torch.full(
+ (sum(num_crops), *features.shape[1:]),
+ torch.nan,
+ dtype=features.dtype).to(features.device)
+ expanded_embedding[feat_is_patch] = features
+
+ num_crops_per_image = num_crops.tolist()
+ feats_per_image = expanded_embedding.split(num_crops_per_image)
+ f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+
+ embed_dim = expanded_embedding.shape[-1]
+ num_embeds = embed_is_patch.shape[0]
+
+ embeds_in_batch = list[torch.Tensor]()
+ for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+ embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+ embeds[embed_is_patch] = feats[f_is_patch]
+ embeds_in_batch.append(embeds)
+
+ return embeds_in_batch
+
+ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
- return vision_embeddings
+ if kwargs.get("v0_path", False):
+ return vision_embeddings
+ else:
+ nested_emb = [
+ self._get_mm_embeds(*args) for args in zip(
+ vision_embeddings, image_input["feat_is_patch"],
+ image_input["num_crops"], image_input["embed_is_patch"])
+ ]
+ return flatten_2d_lists(nested_emb)
def get_input_embeddings(
self,
@@ -651,8 +800,15 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
+ # Extract the patch tokens
+ patch_embeddings = json_map_leaves(
+ lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+ cast(JSONTree[torch.Tensor], multimodal_embeddings),
+ )
+
inputs_embeds = merge_multimodal_embeddings(
- input_ids, inputs_embeds, multimodal_embeddings,
+ input_ids, inputs_embeds, cast(NestedTensors,
+ patch_embeddings),
self.config.image_token_index)
return inputs_embeds
@@ -705,6 +861,7 @@ def forward(
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
+ kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index cc571bc24bac..554080533059 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1484,8 +1484,8 @@ def _parse_and_validate_image_input(
img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
- raise ValueError("Incorrect type of num_crops. "
- f"Got type: {type(num_crops)}")
+ raise ValueError("Incorrect type of img_patch_id. "
+ f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
return MolmoImageInputs(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index d2388dda3f4a..8acc07ac353a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1042,9 +1042,13 @@ def forward(
for img in pixel_values
]
+ patch_embeds = [
+ p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
+ ]
+ embed_sizes = [p.shape[1] for p in patch_embeds]
+
# flatten to a single sequence
- patch_embeds = torch.cat(
- [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+ patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
@@ -1075,6 +1079,8 @@ def forward(
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)
+ # squeeze dim 0 and split into separate tensors for each image
+ out = torch.split(torch.squeeze(out), embed_sizes)
return out
# (TODO) Add prefix argument for filtering out weights to be loaded