From 7596b42707fdc362b71632e9653c5c0f7b8950f6 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Wed, 5 Mar 2025 10:53:39 +0000 Subject: [PATCH 1/3] Pixtral-HF on V1 Signed-off-by: Linkun Chen --- vllm/model_executor/models/llava.py | 172 ++++++++++++++++++++++++-- vllm/model_executor/models/pixtral.py | 10 +- 2 files changed, 170 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 66b79f809bc9..a46b16006fd2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, - TypedDict, TypeVar, Union) + TypedDict, TypeVar, Union, cast) import torch import torch.nn as nn @@ -35,6 +35,7 @@ PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors +from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP @@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict): in which case the data is passed as a list instead of a batched tensor. """ + feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image features correspond + to patch tokens. + + Shape: `(batch_size, num_crops, num_patch)` + """ + + embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size, num_embeds)` + """ + + num_crops: torch.Tensor + """Shape: `(batch_size, num_images)`""" + class LlavaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict): `hidden_size` must match the hidden size of language model backbone. 
""" + feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image features correspond + to patch tokens. + + Shape: `(batch_size, num_crops, num_patch)` + """ + + embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size, num_embeds)` + """ + + num_crops: torch.Tensor + """Shape: `(batch_size, num_images)`""" + LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs] @@ -317,6 +356,23 @@ def _call_hf_processor( for p, (h, w) in zip(pixel_values, image_sizes) ] + hf_config = self.info.get_hf_config() + + tile_sizes = [ + get_pixtral_hf_image_feature_grid_size( + hf_config.vision_config, + image_width=pixel_value.shape[-1], + image_height=pixel_value.shape[-2]) + for pixel_value in processed_outputs["pixel_values"] + ] + num_crops = torch.tensor([(ncols + 1) * nrows + for ncols, nrows in tile_sizes]) + embed_is_patch = [([True] * ncols + [False]) * nrows + for ncols, nrows in tile_sizes] + processed_outputs["num_crops"] = num_crops + processed_outputs["embed_is_patch"] = embed_is_patch + processed_outputs["feat_is_patch"] = embed_is_patch + return processed_outputs def _get_mm_fields_config( @@ -324,7 +380,14 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + num_images = len(num_crops) + return dict( + feat_is_patch=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), + embed_is_patch=MultiModalFieldConfig.shared("image", num_images), + num_crops=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -562,6 +625,23 @@ def _parse_and_validate_image_input( if pixel_values is None and image_embeds is None: return None + feat_is_patch = 
kwargs.pop("feat_is_patch", None) + if feat_is_patch is not None and not isinstance( + feat_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of feat_is_patch. " + f"Got type: {type(feat_is_patch)}") + + embed_is_patch = kwargs.pop("embed_is_patch", None) + if embed_is_patch is not None and not isinstance( + embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of embed_is_patch. " + f"Got type: {type(embed_is_patch)}") + + num_crops = kwargs.pop("num_crops", None) + if num_crops is not None and not isinstance(num_crops, torch.Tensor): + raise ValueError("Incorrect type of num_crops. " + f"Got type: {type(num_crops)}") + if pixel_values is not None: if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " @@ -571,12 +651,18 @@ def _parse_and_validate_image_input( return LlavaImagePixelInputs( type="pixel_values", data=flatten_bn(pixel_values), + feat_is_patch=feat_is_patch, + embed_is_patch=embed_is_patch, + num_crops=num_crops, ) return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( flatten_bn(pixel_values, concat=True)), + feat_is_patch=feat_is_patch, + embed_is_patch=embed_is_patch, + num_crops=num_crops, ) if image_embeds is not None: @@ -587,6 +673,9 @@ def _parse_and_validate_image_input( return LlavaImageEmbeddingInputs( type="image_embeds", data=flatten_bn(image_embeds, concat=True), + feat_is_patch=feat_is_patch, + embed_is_patch=embed_is_patch, + num_crops=num_crops, ) raise AssertionError("This line should be unreachable.") @@ -633,16 +722,74 @@ def _process_image_input(self, assert self.vision_tower is not None image_features = self._process_image_pixels(image_input) - return self.multi_modal_projector(image_features) - def get_multimodal_embeddings( - self, **kwargs - ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]: + if isinstance(image_features, torch.Tensor): + return self.multi_modal_projector(image_features) 
+ + feature_sizes = [ + image_feature.shape[0] for image_feature in image_features + ] + + image_embeds = self.multi_modal_projector(torch.cat(image_features)) + image_embeds = torch.split(image_embeds, feature_sizes) + return image_embeds + + def _get_mm_embeds( + self, + features: torch.Tensor, # Shape: (num_crop, num_patch, d) + feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) + num_crops: torch.Tensor, # Shape: (num_images,) + embed_is_patch: torch.Tensor, # Shape: (num_embeds,) + ) -> list[torch.Tensor]: + """Scatter the patch features into a contiguous tensor that corresponds + to the embedding tokens defined by the multimodal processor. + + Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment. + """ + + # Insert columns of nan values according to `feat_is_patch`. This work + # ideally should be done in `_process_image_input`, but + # `_process_image_input` is used in both V0 and V1 path. It's safer to + # put the logic here. + # FIXME: Move this logic to `_process_image_input` when v0 is + # deprecated. Merge this function with `Molmo._get_mm_embeds`. 
+ feat_is_patch = feat_is_patch.view(-1) + embed_is_patch = embed_is_patch.view(-1) + expanded_embedding = torch.full( + (sum(num_crops), *features.shape[1:]), + torch.nan, + dtype=features.dtype).to(features.device) + expanded_embedding[feat_is_patch] = features + + num_crops_per_image = num_crops.tolist() + feats_per_image = expanded_embedding.split(num_crops_per_image) + f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) + + embed_dim = expanded_embedding.shape[-1] + num_embeds = embed_is_patch.shape[0] + + embeds_in_batch = list[torch.Tensor]() + for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image): + embeds = feats.new_full((num_embeds, embed_dim), torch.nan) + embeds[embed_is_patch] = feats[f_is_patch] + embeds_in_batch.append(embeds) + + return embeds_in_batch + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + if image_input["feat_is_patch"] is None: + return vision_embeddings + else: + nested_emb = [ + self._get_mm_embeds(*args) for args in zip( + vision_embeddings, image_input["feat_is_patch"], + image_input["num_crops"], image_input["embed_is_patch"]) + ] + return flatten_2d_lists(nested_emb) def get_input_embeddings( self, @@ -651,8 +798,15 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: + # Extract the patch tokens + patch_embeddings = json_map_leaves( + lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), + cast(JSONTree[torch.Tensor], multimodal_embeddings), + ) + inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, + input_ids, inputs_embeds, cast(NestedTensors, + patch_embeddings), self.config.image_token_index) return inputs_embeds @@ -767,7 +921,6 @@ def apply( prompt: 
Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -778,8 +931,7 @@ def apply( image_height=-1, ) - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) + result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d2388dda3f4a..8acc07ac353a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1042,9 +1042,13 @@ def forward( for img in pixel_values ] + patch_embeds = [ + p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list + ] + embed_sizes = [p.shape[1] for p in patch_embeds] + # flatten to a single sequence - patch_embeds = torch.cat( - [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = torch.cat(patch_embeds, dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings @@ -1075,6 +1079,8 @@ def forward( out = resolve_visual_encoder_outputs(out, feature_sample_layers, None, self.config.num_hidden_layers) + # squeeze dim 0 and split into separate tensors for each image + out = torch.split(torch.squeeze(out), embed_sizes) return out # (TODO) Add prefix argument for filtering out weights to be loaded From 4dc3a19580469b31a4b4871b6382ce6e9b62d020 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Wed, 5 Mar 2025 21:20:28 +0000 Subject: [PATCH 2/3] fix for multi_image case * use "v0_path" kwargs to separate v0/v1 * flatten `*_is_patch` mask Signed-off-by: Linkun Chen --- vllm/model_executor/models/llava.py | 17 ++++++++++------- vllm/model_executor/models/molmo.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/llava.py 
b/vllm/model_executor/models/llava.py index a46b16006fd2..10796f9f2a6e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -367,8 +367,11 @@ def _call_hf_processor( ] num_crops = torch.tensor([(ncols + 1) * nrows for ncols, nrows in tile_sizes]) - embed_is_patch = [([True] * ncols + [False]) * nrows - for ncols, nrows in tile_sizes] + # Each image may result in masks of different sizes, so we need to + # flatten the list and later use `num_crops` to get per-image masks. + embed_is_patch = torch.tensor( + flatten_2d_lists([([True] * ncols + [False]) * nrows + for ncols, nrows in tile_sizes])) processed_outputs["num_crops"] = num_crops processed_outputs["embed_is_patch"] = embed_is_patch processed_outputs["feat_is_patch"] = embed_is_patch @@ -380,13 +383,12 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_crops = hf_inputs.get("num_crops", torch.empty(0)) - num_images = len(num_crops) - + num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1) return dict( feat_is_patch=MultiModalFieldConfig.flat_from_sizes( "image", num_crops), - embed_is_patch=MultiModalFieldConfig.shared("image", num_images), + embed_is_patch=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops), num_crops=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), @@ -781,7 +783,7 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: if image_input is None: return None vision_embeddings = self._process_image_input(image_input) - if image_input["feat_is_patch"] is None: + if kwargs.get("v0_path", False): return vision_embeddings else: nested_emb = [ @@ -859,6 +861,7 @@ def forward( # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: + kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index cc571bc24bac..554080533059 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1484,8 +1484,8 @@ def _parse_and_validate_image_input( img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): - raise ValueError("Incorrect type of num_crops. " - f"Got type: {type(num_crops)}") + raise ValueError("Incorrect type of img_patch_id. " + f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() return MolmoImageInputs( From 9c213c5b7aeda4f44cd8624f8c1805709fa73955 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Thu, 6 Mar 2025 06:49:05 +0000 Subject: [PATCH 3/3] revert unexpected change, update doc Signed-off-by: Linkun Chen --- docs/source/models/supported_models.md | 6 +----- vllm/model_executor/models/llava.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index fc363585b0e7..6cae2971ab26 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generativ - * `PixtralForConditionalGeneration` * Pixtral * T + I+ - * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc. + * `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc. * * ✅︎ * ✅︎ @@ -930,10 +930,6 @@ For more details, please see: Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release. ::: -:::{note} -`mistral-community/pixtral-12b` does not support V1 yet. 
-::: - :::{note} To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`. ::: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 10796f9f2a6e..e83dfd320bb6 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -924,6 +924,7 @@ def apply( prompt: Union[str, list[int]], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -934,7 +935,8 @@ def apply( image_height=-1, ) - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) + result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, + return_mm_hashes) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts()