[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support #27361
Changes from 3 commits
```diff
@@ -4,6 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal

 import torch
 import torch.nn as nn
```
```diff
@@ -53,6 +54,7 @@
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
```
```diff
@@ -65,6 +67,28 @@
 _IMAGE_TOKEN = "<image>"


+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - base_size: Base size of the processor
+        - image_size: Image size of the processor
+    """
+
+    type: Literal["pixel_values"]
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
+    ]
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
```
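For illustration (not part of the diff), here is a minimal sketch of how such a `TensorSchema` behaves, assuming `vllm.utils.tensor_schema` validates field shapes and cross-field symbolic dimensions at construction time; the toy class and all shapes below are made up:

```python
from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class _ToyInputs(TensorSchema):
    type: Literal["pixel_values"]
    # "bn" is symbolic: it must take the same value in every field.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "size", "size")]
    sizes: Annotated[torch.Tensor, TensorShape("bn", 2)]


# Two images: both leading dims are 2, so validation passes.
_ToyInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 64, 64),
    sizes=torch.ones(2, 2),
)
# sizes=torch.ones(3, 2) would fail validation, since "bn" would
# resolve to 2 in `data` but 3 in `sizes`.
```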
```diff
@@ -260,10 +284,14 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        patches_per_image = images_spatial_crop.prod(dim=-1)
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )

     def _get_prompt_updates(
```
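Roughly, `flat_from_sizes` declares that `images_crop` arrives as one flat tensor along the patch axis and records how many patches belong to each image. The slicing this implies can be sketched with plain `torch` (grid and crop sizes here are made up):

```python
import torch

# Two images with 2x2 and 1x3 crop grids -> 4 and 3 local patches each.
images_spatial_crop = torch.tensor([[2, 2], [1, 3]])
patches_per_image = images_spatial_crop.prod(dim=-1)  # tensor([4, 3])

# The processor emits a single flat (7, 3, H, W) tensor; per-image views
# are recovered by splitting on the recorded sizes, which is what
# _pixel_values_to_embedding does further down with Tensor.split().
images_crop = torch.zeros(7, 3, 640, 640)
per_image = images_crop.split(patches_per_image.tolist())
assert [t.shape[0] for t in per_image] == [4, 3]
```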
```diff
@@ -302,42 +330,15 @@ def get_replacement_deepseek_vl2(item_idx: int):
             )
         ]

-    # TODO(Isotr0py): Check if we still need this workaround for
-    # deepseek-ocr processor.
-    # def _cached_apply_hf_processor(
-    #     self,
-    #     prompt: str | list[int],
-    #     mm_data_items: MultiModalDataItems,
-    #     hf_processor_mm_kwargs: Mapping[str, object],
-    #     tokenization_kwargs: Mapping[str, object],
-    #     mm_uuids: MultiModalUUIDDict | None = None,
-    # ) -> tuple[list[int], MultiModalKwargs, bool]:
-    #     # The processor logic is different for len(images) <= 2 vs > 2
-    #     # Since the processing cache assumes that the processor output is
-    #     # invariant of how many images are passed per prompt, we only
-    #     # perform caching for the most common case
-    #     if mm_data_items.get_count("image", strict=False) > 2:
-    #         # This code path corresponds to the cache being disabled
-    #         return self._apply_hf_processor_main(
-    #             prompt=prompt,
-    #             mm_items=mm_data_items,
-    #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #             enable_hf_prompt_update=True,
-    #         )
-
-    #     return super()._cached_apply_hf_processor(
-    #         prompt=prompt,
-    #         mm_data_items=mm_data_items,
-    #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #     )
-

 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
     info=DeepseekOCRProcessingInfo,
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
```
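With `merge_by_field_config = True`, the model no longer receives per-request lists of tensors; vLLM merges the multimodal kwargs according to the field config above before calling the model. A hypothetical merged batch, purely for illustration (all sizes made up):

```python
import torch

# Request A contributes 2 images (2x2 and 1x3 grids), request B one (1x1).
merged_kwargs = {
    # batched("image") fields are stacked along the image axis: bn = 3.
    "pixel_values": torch.zeros(3, 3, 1024, 1024),
    "images_spatial_crop": torch.tensor([[2, 2], [1, 3], [1, 1]]),
    # The flat_from_sizes("image", ...) field is one flat tensor holding
    # 4 + 3 + 1 = 8 patches in image order.
    "images_crop": torch.zeros(8, 3, 640, 640),
}
```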
```diff
@@ -389,6 +390,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"),
         )

         self.projector = MlpProjector(self.projector_config)
```
```diff
@@ -426,7 +428,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
```
```diff
@@ -435,23 +439,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
-                )
-
-            return [pixel_values, images_crop, images_spatial_crop]
+            base_size = self.vision_config.image_size
+            return DeepseekOCRImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_crop=images_crop,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "base_size": base_size,
+                },
+            )

         raise AssertionError("This line should be unreachable.")
```
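Here `resolve_bindings` pins the symbolic `base_size` dimension to the processor's configured value, so mismatched inputs fail schema validation instead of reaching the vision tower. A sketch with illustrative numbers, where 1024 stands in for `self.vision_config.image_size`:

```python
import torch

# Assuming the DeepseekOCRImagePixelInputs schema defined earlier:
# with base_size bound to 1024, (bn, 3, 1024, 1024) pixel values validate.
inputs = DeepseekOCRImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 1024, 1024),
    images_crop=torch.zeros(5, 3, 640, 640),             # 4 + 1 patches
    images_spatial_crop=torch.tensor([[2, 2], [1, 1]]),
    resolve_bindings={"base_size": 1024},
)
# data=torch.zeros(2, 3, 512, 512) would raise a shape-validation error.
```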
```diff
@@ -518,10 +515,11 @@ def _pixel_values_to_embedding(
     ) -> NestedTensors:
         images_in_this_batch = []

+        images_crop = images_crop.split(images_spatial_crop.prod(dim=-1).tolist())
         for jdx in range(images_spatial_crop.size(0)):
-            patches = images_crop[jdx][0].to(torch.bfloat16)
-            image_ori = pixel_values[jdx]
-            crop_shape = images_spatial_crop[jdx][0]
+            patches = images_crop[jdx]
```
**Review comment on lines 518 to 523:** The new logic assumes […]
```diff
+            image_ori = pixel_values[[jdx]]
+            crop_shape = images_spatial_crop[jdx]

             global_features = self._encode_global_features(image_ori)
             local_features = self._encode_local_features(patches, crop_shape)
```
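The indexing changes are the heart of the multi-image fix: the removed `[jdx][0]` pattern apparently assumed the old nested per-request layout, while with merged (stacked) inputs, plain integer indexing would drop the batch dimension the encoders expect; list indexing keeps it. A quick check with plain tensors (shapes illustrative):

```python
import torch

pixel_values = torch.zeros(3, 3, 1024, 1024)  # 3 images stacked

assert pixel_values[0].shape == (3, 1024, 1024)       # batch dim dropped
assert pixel_values[[0]].shape == (1, 3, 1024, 1024)  # batch dim kept
```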
```diff
@@ -540,10 +538,12 @@ def _pixel_values_to_embedding(
         return images_in_this_batch

-    def _process_image_input(self, image_input) -> torch.Tensor:
-        pixel_values = image_input[0].to(torch.bfloat16)
-        images_crop = image_input[1]
-        images_spatial_crop = image_input[2].to(dtype=torch.long)
+    def _process_image_input(
+        self, image_input: DeepseekOCRImagePixelInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input.data
+        images_crop = image_input.images_crop
+        images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)

         vision_features = self._pixel_values_to_embedding(
             pixel_values=pixel_values,
```