From 954427a93c605bbe7b826582e69697496e1d818a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Wed, 20 Aug 2025 10:50:31 +0800 Subject: [PATCH 01/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- docs/models/supported_models.md | 3 +- examples/offline_inference/dolphin.py | 361 ++++++++++ examples/offline_inference/donut.py | 81 +++ tests/models/registry.py | 4 +- vllm/engine/llm_engine.py | 2 +- vllm/model_executor/models/donut.py | 391 +++++++++++ vllm/model_executor/models/registry.py | 3 +- vllm/model_executor/models/swin.py | 909 +++++++++++++++++++++++++ vllm/multimodal/profiling.py | 2 +- vllm/v1/engine/processor.py | 2 +- 10 files changed, 1752 insertions(+), 6 deletions(-) create mode 100644 examples/offline_inference/dolphin.py create mode 100644 examples/offline_inference/donut.py create mode 100644 vllm/model_executor/models/donut.py create mode 100644 vllm/model_executor/models/swin.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 1d165fa6f16b..9896be316275 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -613,6 +613,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | @@ -815,4 +816,4 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. \ No newline at end of file diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 000000000000..c2b6aecdfeda --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +import sys +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from datasets import load_dataset +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, + logprobs=0, + prompt_logprobs=None, + skip_special_tokens=False, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float32", + enforce_eager=True, + max_num_seqs=16, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +parser.add_argument( + "--task", + type=str, + default="full", + choices=["full", "segment", "text", "table"], + help="The task to perform. " + "'full': layout analysis then OCR (default). " + "'segment': layout analysis only. " + "'text'/'table': direct end-to-end parsing.", +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + print("Loading default image from Hugging Face datasets.") + dataset = load_dataset("hf-internal-testing/example-documents", split="test") + image = dataset[0]["image"] + + +if args.task in ["full", "segment"]: + prompt = "Parse the reading order of this document." + decoder_prompt = f"{prompt}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": image} + ), + decoder_prompt=decoder_prompt_tokens, + ) + layout_outputs = llm.generate( + prompts=enc_dec_prompt, sampling_params=sampling_params + ) + layout_result_str = layout_outputs[0].outputs[0].text + print(f"Raw layout analysis output:\n{layout_result_str}") + + if args.task == "segment": + print("\nTask 'segment' completed.") + sys.exit(0) + + padded_image, dims = prepare_image(image) + layout_results = parse_layout_string(layout_result_str) + text_table_elements = [] + previous_box = None + reading_order = 0 + for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image." + if label == "tab" + else "Read text in the image." + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + + if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"{elem['prompt']}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate( + prompts=batch_prompts, sampling_params=sampling_params + ) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + + print("------" * 8) + text_table_elements.sort(key=lambda x: x["reading_order"]) + for elem in text_table_elements: + print(elem.get("text", "")) + +elif args.task in ["text", "table"]: + prompt_map = { + "text": "Read text in the image.", + "table": "Parse the tables in the image.", + } + prompt = prompt_map[args.task] + print(f'Using direct prompt: "{prompt}"') + + decoder_prompt = f"{prompt} " + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": image} + ), + decoder_prompt=decoder_prompt_tokens, + ) + outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) + result_text = outputs[0].outputs[0].text.strip() + + print("------" * 8) + print("TEXT: ", result_text) diff --git a/examples/offline_inference/donut.py b/examples/offline_inference/donut.py new file mode 100644 index 000000000000..34c6e0dd60a7 --- /dev/null +++ b/examples/offline_inference/donut.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from datasets import load_dataset +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt + +model_id = "naver-clova-ix/donut-base-finetuned-docvqa" + +processor = DonutProcessor.from_pretrained(model_id) + +# The input image size for donut-base-finetuned-docvqa is 2560 x 1920, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 1920 / 4 = 480 patches +# Width: 2560 / 4 = 640 patches + +# The Swin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. +# Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. +# Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. +encoder_prompt = ["$"] * 4799 +encoder_prompt = "".join(encoder_prompt) + +dataset = load_dataset("hf-internal-testing/example-documents", split="test") +questions = [ + "What time is the coffee break?", + "What's the brand name?", + "What's the total cost?", +] +enc_dec_prompt = [] +for i in range(3): + image = dataset[i]["image"] + question = questions[i] + decoder_prompt = f"{question}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] + ) + enc_dec_prompt.append( + ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": image} + ), + decoder_prompt=decoder_prompt_tokens, + ) + ) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +llm = LLM( + model=model_id, + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +# Batch Inference +outputs = llm.generate( + prompts=enc_dec_prompt, + sampling_params=sampling_params, +) + +print("------" * 8) + +for i in range(3): + print(f"Decoder prompt: {questions[i]}") + print(f"Generated text: {outputs[i].outputs[0].text}") + + print("------" * 8) diff --git a/tests/models/registry.py b/tests/models/registry.py index cbdc9edbbc9d..2f4051a1d666 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -517,6 +517,8 @@ def check_available_online( trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}), # noqa: E501 # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } @@ -606,4 +608,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) -AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) +AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) \ No newline at end of file diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351e87..dbf8d3ba5014 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1822,7 +1822,7 @@ def _validate_model_input( assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 000000000000..55af51bb85af --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _validate_pixel_values( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + # size = self.processor_config["size"] + h, w = self.config.encoder.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = tuple(*map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return DonutImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = _flatten_embeddings(vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8728684d8e68..f41b93caa30b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -249,6 +249,7 @@ "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 @@ -856,4 +857,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 000000000000..2810be344b45 --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,909 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.pytorch_utils import meshgrid +from transformers.utils import torch_int + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view(batch_size, height // window_size, + window_size, width // window_size, + window_size, num_channels) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view( + -1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, + window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, height, width, + num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings +class SwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + Optionally, also the mask token. + """ + + def __init__( + self, + config: SwinConfig, + ) -> None: + super().__init__() + + self.patch_embeddings = SwinPatchEmbeddings(config) + self.patch_grid = self.patch_embeddings.grid_size + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit + # .modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, + to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure + # the exported model works for dynamic input shapes + if not torch.jit.is_tracing( + ) and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, + sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + ) -> tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings +class SwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape + `(batch_size, num_channels, height, width)` + into the initial `hidden_states` (patch embeddings) of shape + `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance( + image_size, Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance( + patch_size, Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // + patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], + image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, + hidden_size, + kernel_size=patch_size, + stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, + self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward( + self, pixel_values: Optional[torch.FloatTensor] + ) -> tuple[torch.Tensor, tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class SwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, + input_resolution: tuple[int], + dim: int, + norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, + input_dimensions: tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, + num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([ + input_feature_0, input_feature_1, input_feature_2, input_feature_3 + ], -1) + input_feature = input_feature.view( + batch_size, -1, + 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, + drop_prob: float = 0.0, + training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the + DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a + different form of dropout in a separate paper... See discussion: + https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 + I've opted for changing the layer and argument names to 'drop path' + rather than mix DropConnect as a layer name and use 'survival rate' + as the argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0], ) + (1, ) * ( + input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class SwinDropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + self.fused_attn = True + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + if self.fused_attn: + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + else: + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores * self.scale + attention_scores = attention_scores + self._get_rel_pos_bias() + + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + attention_mask.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, dim, dim) + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SwinLayer(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.drop_path = SwinDropPath( + drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, + # we don't partition windows + self.shift_size = torch_int(0) + self.window_size = (torch.min(torch.tensor(input_resolution)) + if torch.jit.is_tracing() else + min(input_resolution)) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), + dtype=dtype, + device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + (-100.0)).masked_fill( + attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - + width % self.window_size) % self.window_size + pad_bottom = (self.window_size - + height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, + width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, + shifts=(-self.shift_size, + -self.shift_size), + dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, + self.window_size) + hidden_states_windows = hidden_states_windows.view( + -1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, + width_pad, + dtype=hidden_states.dtype, + device=hidden_states_windows.device) + + attention_outputs = self.attention(hidden_states_windows, + attn_mask, + head_mask, + output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, + self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, + height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, + shifts=(self.shift_size, + self.shift_size), + dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, : + width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, + channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, + attention_outputs[1]) if output_attentions else ( + layer_output, ) + return layer_outputs + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params \ No newline at end of file diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2da9b4c72189..ea2efbdd8b52 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -209,7 +209,7 @@ def get_encoder_dummy_data( if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 97d79c2ae093..2b84851bb26f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -397,7 +397,7 @@ def _validate_model_input( assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( From 61dd59399a3aaf19f2151793b64f447042b9db36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Wed, 20 Aug 2025 11:07:35 +0800 Subject: [PATCH 02/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- vllm/model_executor/models/donut.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py index 55af51bb85af..ba859ffdbdad 100644 --- a/vllm/model_executor/models/donut.py +++ b/vllm/model_executor/models/donut.py @@ -287,10 +287,9 @@ def _validate_shape(d: torch.Tensor): actual_dims = tuple(d.shape) if actual_dims != expected_dims: - expected_expr = tuple(*map(str, expected_dims)) raise ValueError( "The expected shape of pixel values per batch " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") + f"is {expected_dims}. You supplied {actual_dims}.") for d in data: _validate_shape(d) From 9f6602f7fecfc9e5299e1f6c83e32a9c0c5423c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Wed, 20 Aug 2025 14:23:59 +0800 Subject: [PATCH 03/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- docs/models/supported_models.md | 2 +- tests/models/registry.py | 2 +- vllm/model_executor/models/registry.py | 2 +- vllm/model_executor/models/swin.py | 278 +------------------------ 4 files changed, 9 insertions(+), 275 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9896be316275..490e87189280 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -816,4 +816,4 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. \ No newline at end of file +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/tests/models/registry.py b/tests/models/registry.py index 2f4051a1d666..6dba95cd1337 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -608,4 +608,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) -AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) \ No newline at end of file +AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f41b93caa30b..9f1b095860e4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -857,4 +857,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index 2810be344b45..74e79f425f9b 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -7,6 +7,11 @@ import torch import torch.nn as nn from transformers import SwinConfig +from transformers.models.swin.modeling_swin import (SwinDropPath, + SwinEmbeddings, + SwinPatchMerging, + window_partition, + window_reverse) from transformers.pytorch_utils import meshgrid from transformers.utils import torch_int @@ -18,277 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader -# Copied from transformers.models.swin.modeling_swin.window_partition -def window_partition(input_feature, window_size): - """ - Partitions the given input into windows. - """ - batch_size, height, width, num_channels = input_feature.shape - input_feature = input_feature.view(batch_size, height // window_size, - window_size, width // window_size, - window_size, num_channels) - windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view( - -1, window_size, window_size, num_channels) - return windows - - -# Copied from transformers.models.swin.modeling_swin.window_reverse -def window_reverse(windows, window_size, height, width): - """ - Merges windows to produce higher resolution features. - """ - num_channels = windows.shape[-1] - windows = windows.view(-1, height // window_size, width // window_size, - window_size, window_size, num_channels) - windows = windows.permute(0, 1, 3, 2, 4, - 5).contiguous().view(-1, height, width, - num_channels) - return windows - - -# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings -class SwinEmbeddings(nn.Module): - """ - Construct the patch and position embeddings. - Optionally, also the mask token. - """ - - def __init__( - self, - config: SwinConfig, - ) -> None: - super().__init__() - - self.patch_embeddings = SwinPatchEmbeddings(config) - self.patch_grid = self.patch_embeddings.grid_size - - self.norm = nn.LayerNorm(config.embed_dim) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.patch_size = config.patch_size - self.config = config - - # Copied from transformers.models.vit - # .modeling_vit.ViTEmbeddings.interpolate_pos_encoding - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, - width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, - to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing. - """ - - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 - - # always interpolate when tracing to ensure - # the exported model works for dynamic input shapes - if not torch.jit.is_tracing( - ) and num_patches == num_positions and height == width: - return self.position_embeddings - - class_pos_embed = self.position_embeddings[:, :1] - patch_pos_embed = self.position_embeddings[:, 1:] - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, - sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - - return torch.cat((class_pos_embed, patch_pos_embed), dim=1) - - def forward( - self, - pixel_values: Optional[torch.FloatTensor], - ) -> tuple[torch.Tensor]: - _, num_channels, height, width = pixel_values.shape - embeddings, output_dimensions = self.patch_embeddings(pixel_values) - embeddings = self.norm(embeddings) - embeddings = self.dropout(embeddings) - - return embeddings, output_dimensions - - -# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings -class SwinPatchEmbeddings(nn.Module): - """ - This class turns `pixel_values` of shape - `(batch_size, num_channels, height, width)` - into the initial `hidden_states` (patch embeddings) of shape - `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.embed_dim - image_size = image_size if isinstance( - image_size, Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance( - patch_size, Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // - patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - self.grid_size = (image_size[0] // patch_size[0], - image_size[1] // patch_size[1]) - - self.projection = nn.Conv2d(num_channels, - hidden_size, - kernel_size=patch_size, - stride=patch_size) - - def maybe_pad(self, pixel_values, height, width): - if width % self.patch_size[1] != 0: - pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) - pixel_values = nn.functional.pad(pixel_values, pad_values) - if height % self.patch_size[0] != 0: - pad_values = (0, 0, 0, - self.patch_size[0] - height % self.patch_size[0]) - pixel_values = nn.functional.pad(pixel_values, pad_values) - return pixel_values - - def forward( - self, pixel_values: Optional[torch.FloatTensor] - ) -> tuple[torch.Tensor, tuple[int]]: - _, num_channels, height, width = pixel_values.shape - # pad the input to be divisible by self.patch_size, if needed - pixel_values = self.maybe_pad(pixel_values, height, width) - embeddings = self.projection(pixel_values) - _, _, height, width = embeddings.shape - output_dimensions = (height, width) - embeddings = embeddings.flatten(2).transpose(1, 2) - - return embeddings, output_dimensions - - -# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging -class SwinPatchMerging(nn.Module): - """ - Patch Merging Layer. - - Args: - input_resolution (`tuple[int]`): - Resolution of input feature. - dim (`int`): - Number of input channels. - norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): - Normalization layer class. - """ - - def __init__(self, - input_resolution: tuple[int], - dim: int, - norm_layer: nn.Module = nn.LayerNorm) -> None: - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def maybe_pad(self, input_feature, height, width): - should_pad = (height % 2 == 1) or (width % 2 == 1) - if should_pad: - pad_values = (0, 0, 0, width % 2, 0, height % 2) - input_feature = nn.functional.pad(input_feature, pad_values) - - return input_feature - - def forward(self, input_feature: torch.Tensor, - input_dimensions: tuple[int, int]) -> torch.Tensor: - height, width = input_dimensions - # `dim` is height * width - batch_size, dim, num_channels = input_feature.shape - - input_feature = input_feature.view(batch_size, height, width, - num_channels) - # pad input to be disible by width and height, if needed - input_feature = self.maybe_pad(input_feature, height, width) - # [batch_size, height/2, width/2, num_channels] - input_feature_0 = input_feature[:, 0::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_1 = input_feature[:, 1::2, 0::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_2 = input_feature[:, 0::2, 1::2, :] - # [batch_size, height/2, width/2, num_channels] - input_feature_3 = input_feature[:, 1::2, 1::2, :] - # batch_size height/2 width/2 4*num_channels - input_feature = torch.cat([ - input_feature_0, input_feature_1, input_feature_2, input_feature_3 - ], -1) - input_feature = input_feature.view( - batch_size, -1, - 4 * num_channels) # batch_size height/2*width/2 4*C - - input_feature = self.norm(input_feature) - input_feature = self.reduction(input_feature) - - return input_feature - - -# Copied from transformers.models.beit.modeling_beit.drop_path -def drop_path(input: torch.Tensor, - drop_prob: float = 0.0, - training: bool = False) -> torch.Tensor: - """ - Drop paths (Stochastic Depth) per sample - (when applied in main path of residual blocks). - - Comment by Ross Wightman: This is the same as the - DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a - different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 - I've opted for changing the layer and argument names to 'drop path' - rather than mix DropConnect as a layer name and use 'survival rate' - as the argument. - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0], ) + (1, ) * ( - input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand( - shape, dtype=input.dtype, device=input.device) - random_tensor.floor_() # binarize - output = input.div(keep_prob) * random_tensor - return output - - -# Copied from transformers.models.swin.modeling_swin.SwinDropPath -class SwinDropPath(nn.Module): - """ - Drop paths (Stochastic Depth) per sample - (when applied in main path of residual blocks). - """ - - def __init__(self, drop_prob: Optional[float] = None) -> None: - super().__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return drop_path(hidden_states, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return "p={}".format(self.drop_prob) - - class SwinSelfAttention(nn.Module): def __init__( @@ -906,4 +640,4 @@ def load_weights(self, weights: Iterable[tuple[str, default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) - return loaded_params \ No newline at end of file + return loaded_params From 1fd5d8a2ae1ade2016b61b943ee0b75b1b129402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Wed, 20 Aug 2025 15:45:15 +0800 Subject: [PATCH 04/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- tests/models/registry.py | 3 +- vllm/model_executor/models/donut.py | 12 +- vllm/model_executor/models/swin.py | 171 +++------------------------- 3 files changed, 25 insertions(+), 161 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6dba95cd1337..26987d504db8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -518,7 +518,8 @@ def check_available_online( "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 - hf_overrides={"architectures": ["DonutForConditionalGeneration"]}), # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py index ba859ffdbdad..b1f6a0af6b3d 100644 --- a/vllm/model_executor/models/donut.py +++ b/vllm/model_executor/models/donut.py @@ -344,6 +344,13 @@ def get_multimodal_embeddings( vision_embeddings = self._process_image_input(image_input) return vision_embeddings + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + def forward( self, input_ids: torch.Tensor, @@ -370,7 +377,8 @@ def forward( inputs_embeds = None if encoder_input_ids.numel() > 0: vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = _flatten_embeddings(vision_embeddings) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) hidden_states = self.decoder(input_ids, positions, @@ -387,4 +395,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index 74e79f425f9b..5a7c71f34a2c 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -7,13 +7,10 @@ import torch import torch.nn as nn from transformers import SwinConfig -from transformers.models.swin.modeling_swin import (SwinDropPath, - SwinEmbeddings, - SwinPatchMerging, - window_partition, - window_reverse) +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging from transformers.pytorch_utils import meshgrid -from transformers.utils import torch_int from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -78,8 +75,6 @@ def __init__( prefix=f"{prefix}.qkv", ) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -151,7 +146,6 @@ def forward( -1, self.num_attention_heads, dim, dim) attention_probs = nn.functional.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask @@ -185,12 +179,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.dense", ) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) return hidden_states @@ -263,15 +255,13 @@ def __init__(self, dim, quant_config=quant_config, prefix=f"{prefix}.dense") - self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) return hidden_states -class SwinLayer(nn.Module): +class SwinLayer(HFSwinLayer): def __init__( self, @@ -284,21 +274,21 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.shift_size = shift_size - self.window_size = config.window_size - self.input_resolution = input_resolution - self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size, quant_config=quant_config, prefix=f"{prefix}.attention") - self.drop_path = SwinDropPath( - drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() - self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.intermediate = SwinIntermediate(config, dim, quant_config=quant_config, @@ -308,141 +298,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.output") - def set_shift_and_window_size(self, input_resolution): - if min(input_resolution) <= self.window_size: - # if window size is larger than input resolution, - # we don't partition windows - self.shift_size = torch_int(0) - self.window_size = (torch.min(torch.tensor(input_resolution)) - if torch.jit.is_tracing() else - min(input_resolution)) - - def get_attn_mask(self, height, width, dtype, device): - if self.shift_size > 0: - # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), - dtype=dtype, - device=device) - height_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - width_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - count = 0 - for height_slice in height_slices: - for width_slice in width_slices: - img_mask[:, height_slice, width_slice, :] = count - count += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view( - -1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, - (-100.0)).masked_fill( - attn_mask == 0, 0.0) - else: - attn_mask = None - return attn_mask - - def maybe_pad(self, hidden_states, height, width): - pad_right = (self.window_size - - width % self.window_size) % self.window_size - pad_bottom = (self.window_size - - height % self.window_size) % self.window_size - pad_values = (0, 0, 0, pad_right, 0, pad_bottom) - hidden_states = nn.functional.pad(hidden_states, pad_values) - return hidden_states, pad_values - - def forward( - self, - hidden_states: torch.Tensor, - input_dimensions: tuple[int, int], - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, - ) -> tuple[torch.Tensor, torch.Tensor]: - if not always_partition: - self.set_shift_and_window_size(input_dimensions) - else: - pass - height, width = input_dimensions - batch_size, _, channels = hidden_states.size() - shortcut = hidden_states - - hidden_states = self.layernorm_before(hidden_states) - - hidden_states = hidden_states.view(batch_size, height, width, channels) - - # pad hidden_states to multiples of window size - hidden_states, pad_values = self.maybe_pad(hidden_states, height, - width) - - _, height_pad, width_pad, _ = hidden_states.shape - # cyclic shift - if self.shift_size > 0: - shifted_hidden_states = torch.roll(hidden_states, - shifts=(-self.shift_size, - -self.shift_size), - dims=(1, 2)) - else: - shifted_hidden_states = hidden_states - - # partition windows - hidden_states_windows = window_partition(shifted_hidden_states, - self.window_size) - hidden_states_windows = hidden_states_windows.view( - -1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask(height_pad, - width_pad, - dtype=hidden_states.dtype, - device=hidden_states_windows.device) - - attention_outputs = self.attention(hidden_states_windows, - attn_mask, - head_mask, - output_attentions=output_attentions) - - attention_output = attention_outputs[0] - - attention_windows = attention_output.view(-1, self.window_size, - self.window_size, channels) - shifted_windows = window_reverse(attention_windows, self.window_size, - height_pad, width_pad) - - # reverse cyclic shift - if self.shift_size > 0: - attention_windows = torch.roll(shifted_windows, - shifts=(self.shift_size, - self.shift_size), - dims=(1, 2)) - else: - attention_windows = shifted_windows - - was_padded = pad_values[3] > 0 or pad_values[5] > 0 - if was_padded: - attention_windows = attention_windows[:, :height, : - width, :].contiguous() - - attention_windows = attention_windows.view(batch_size, height * width, - channels) - - hidden_states = shortcut + self.drop_path(attention_windows) - - layer_output = self.layernorm_after(hidden_states) - layer_output = self.intermediate(layer_output) - layer_output = hidden_states + self.output(layer_output) - - layer_outputs = (layer_output, - attention_outputs[1]) if output_attentions else ( - layer_output, ) - return layer_outputs - class SwinStage(nn.Module): From 7da83d72e9e7d65f1d99a11b1ef4d47a6354336e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Thu, 21 Aug 2025 10:53:17 +0800 Subject: [PATCH 05/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/dolphin.py | 193 ++++++++++---------------- vllm/model_executor/models/swin.py | 67 +++------ 2 files changed, 93 insertions(+), 167 deletions(-) diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py index c2b6aecdfeda..f1d70380e335 100644 --- a/examples/offline_inference/dolphin.py +++ b/examples/offline_inference/dolphin.py @@ -4,7 +4,6 @@ import argparse import copy import os -import sys from dataclasses import dataclass import cv2 @@ -147,6 +146,7 @@ def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_bo return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: try: image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) @@ -174,6 +174,7 @@ def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: return np.zeros((h, w, 3), dtype=np.uint8), dimensions +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py def parse_layout_string(bbox_str): """Parse layout string using regular expressions""" pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" @@ -211,17 +212,13 @@ def parse_layout_string(bbox_str): sampling_params = SamplingParams( temperature=0.0, max_tokens=2048, - logprobs=0, - prompt_logprobs=None, - skip_special_tokens=False, ) processor = DonutProcessor.from_pretrained(model_id) llm = LLM( model=model_id, - dtype="float32", - enforce_eager=True, - max_num_seqs=16, + dtype="float16", + max_num_seqs=8, hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, ) @@ -229,16 +226,6 @@ def parse_layout_string(bbox_str): parser.add_argument( "--image_path", type=str, default=None, help="Path to a local image file." ) -parser.add_argument( - "--task", - type=str, - default="full", - choices=["full", "segment", "text", "table"], - help="The task to perform. " - "'full': layout analysis then OCR (default). " - "'segment': layout analysis only. " - "'text'/'table': direct end-to-end parsing.", -) args = parser.parse_args() if args.image_path: @@ -246,116 +233,78 @@ def parse_layout_string(bbox_str): raise FileNotFoundError(f"Error: File not found at {args.image_path}") image = Image.open(args.image_path).convert("RGB") else: - print("Loading default image from Hugging Face datasets.") dataset = load_dataset("hf-internal-testing/example-documents", split="test") image = dataset[0]["image"] -if args.task in ["full", "segment"]: - prompt = "Parse the reading order of this document." - decoder_prompt = f"{prompt}" - decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ - "input_ids" - ] - ) - enc_dec_prompt = ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=encoder_prompt, multi_modal_data={"image": image} - ), - decoder_prompt=decoder_prompt_tokens, - ) - layout_outputs = llm.generate( - prompts=enc_dec_prompt, sampling_params=sampling_params - ) - layout_result_str = layout_outputs[0].outputs[0].text - print(f"Raw layout analysis output:\n{layout_result_str}") - - if args.task == "segment": - print("\nTask 'segment' completed.") - sys.exit(0) - - padded_image, dims = prepare_image(image) - layout_results = parse_layout_string(layout_result_str) - text_table_elements = [] - previous_box = None - reading_order = 0 - for bbox_coords, label in layout_results: - if label == "fig": - continue - try: - x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( - process_coordinates(bbox_coords, padded_image, dims, previous_box) - ) - cropped = padded_image[y1:y2, x1:x2] - if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: - pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) - prompt_ocr = ( - "Parse the table in the image." - if label == "tab" - else "Read text in the image." - ) - text_table_elements.append( - { - "crop": pil_crop, - "prompt": prompt_ocr, - "reading_order": reading_order, - } - ) - reading_order += 1 - except Exception as e: - print(f"Error processing bbox (label: {label}): {str(e)}") - continue - - if text_table_elements: - batch_prompts = [] - for elem in text_table_elements: - decoder_prompt_str = f"{elem['prompt']}" - decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer( - decoder_prompt_str, add_special_tokens=False - )["input_ids"] +prompt = "Parse the reading order of this document. " +decoder_prompt = f"{prompt}" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. " ) - enc_dec_prompt = ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} - ), - decoder_prompt=decoder_prompt_tokens, + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } ) - batch_prompts.append(enc_dec_prompt) - batch_outputs = llm.generate( - prompts=batch_prompts, sampling_params=sampling_params - ) - for i, output in enumerate(batch_outputs): - text_table_elements[i]["text"] = output.outputs[0].text.strip() + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue - print("------" * 8) - text_table_elements.sort(key=lambda x: x["reading_order"]) +if text_table_elements: + batch_prompts = [] for elem in text_table_elements: - print(elem.get("text", "")) - -elif args.task in ["text", "table"]: - prompt_map = { - "text": "Read text in the image.", - "table": "Parse the tables in the image.", - } - prompt = prompt_map[args.task] - print(f'Using direct prompt: "{prompt}"') - - decoder_prompt = f"{prompt} " - decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ - "input_ids" - ] - ) - enc_dec_prompt = ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=encoder_prompt, multi_modal_data={"image": image} - ), - decoder_prompt=decoder_prompt_tokens, - ) - outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) - result_text = outputs[0].outputs[0].text.strip() - - print("------" * 8) - print("TEXT: ", result_text) + decoder_prompt_str = f"{elem['prompt']}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index 5a7c71f34a2c..30b441f5b4df 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -43,7 +43,6 @@ def __init__( self.window_size = (window_size if isinstance(window_size, Iterable) else (window_size, window_size)) self.scale = self.attention_head_size**-0.5 - self.fused_attn = True self.relative_position_bias_table = nn.Parameter( torch.zeros( @@ -107,50 +106,28 @@ def forward( value_layer = self.transpose_for_scores(value_layer) query_layer = self.transpose_for_scores(query_layer) - if self.fused_attn: - attention_scores = self._get_rel_pos_bias() - if attention_mask is not None: - mask_shape = attention_mask.shape[0] - attention_mask_expanded = attention_mask.view( - 1, mask_shape, 1, dim, - dim).expand(batch_size // mask_shape, mask_shape, - self.num_attention_heads, dim, dim) - attention_scores = attention_scores + \ - attention_mask_expanded.unsqueeze( - 1).unsqueeze(0) - attention_scores = attention_scores.view( - -1, self.num_attention_heads, dim, dim) - - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_scores, - dropout_p=0., - ) - attention_probs = None - else: - attention_scores = torch.matmul(query_layer, - key_layer.transpose(-1, -2)) - attention_scores = attention_scores * self.scale - attention_scores = attention_scores + self._get_rel_pos_bias() - - if attention_mask is not None: - mask_shape = attention_mask.shape[0] - attention_scores = attention_scores.view( - batch_size // mask_shape, mask_shape, - self.num_attention_heads, dim, dim) - attention_scores = attention_scores + attention_mask.unsqueeze( - 1).unsqueeze(0) - attention_scores = attention_scores.view( - -1, self.num_attention_heads, dim, dim) - - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + ( From 7818022a325ee84fde59084d9d30bae886461aef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 09:38:20 +0800 Subject: [PATCH 06/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- .../encoder_decoder_multimodal.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e..63dbd0966266 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,49 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_model_len=2048, + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + ) + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. + prompts = [ + { + "encoder_prompt": { + "prompt": ["$"] * 4799, + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/images/sf-districts.png" + ) # noqa: E501 + }, + }, + "decoder_prompt": "What time is the coffee break?", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -121,6 +165,7 @@ def run_whisper(): "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, + "donut": run_donut, } From ec1d90780273d960372ec8336562d5cce7c7d0ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 10:53:26 +0800 Subject: [PATCH 07/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/dolphin.py | 7 +- examples/offline_inference/donut.py | 81 ------------------- .../encoder_decoder_multimodal.py | 7 +- 3 files changed, 8 insertions(+), 87 deletions(-) delete mode 100644 examples/offline_inference/donut.py diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py index f1d70380e335..d2ba27cd1e02 100644 --- a/examples/offline_inference/dolphin.py +++ b/examples/offline_inference/dolphin.py @@ -9,12 +9,12 @@ import cv2 import numpy as np import regex as re -from datasets import load_dataset from PIL import Image from transformers import DonutProcessor from vllm import LLM, SamplingParams from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image # Copied from https://github.com/bytedance/Dolphin/utils/utils.py @@ -233,8 +233,9 @@ def parse_layout_string(bbox_str): raise FileNotFoundError(f"Error: File not found at {args.image_path}") image = Image.open(args.image_path).convert("RGB") else: - dataset = load_dataset("hf-internal-testing/example-documents", split="test") - image = dataset[0]["image"] + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) prompt = "Parse the reading order of this document. " diff --git a/examples/offline_inference/donut.py b/examples/offline_inference/donut.py deleted file mode 100644 index 34c6e0dd60a7..000000000000 --- a/examples/offline_inference/donut.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from datasets import load_dataset -from transformers import DonutProcessor - -from vllm import LLM, SamplingParams -from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt - -model_id = "naver-clova-ix/donut-base-finetuned-docvqa" - -processor = DonutProcessor.from_pretrained(model_id) - -# The input image size for donut-base-finetuned-docvqa is 2560 x 1920, -# and the patch_size is 4 x 4. -# Therefore, the initial number of patches is: -# Height: 1920 / 4 = 480 patches -# Width: 2560 / 4 = 640 patches - -# The Swin model uses a staged downsampling approach, -# defined by the "depths": [2, 2, 14, 2] configuration. -# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, -# which halves the feature map's dimensions (dividing both height and width by 2). -# Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. -# Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. -# Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. - -# Because vLLM needs to fill the image features with an encoder_prompt, -# and the encoder_prompt will have `` tokens added when tokenized, -# we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. -encoder_prompt = ["$"] * 4799 -encoder_prompt = "".join(encoder_prompt) - -dataset = load_dataset("hf-internal-testing/example-documents", split="test") -questions = [ - "What time is the coffee break?", - "What's the brand name?", - "What's the total cost?", -] -enc_dec_prompt = [] -for i in range(3): - image = dataset[i]["image"] - question = questions[i] - decoder_prompt = f"{question}" - decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ - "input_ids" - ] - ) - enc_dec_prompt.append( - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=encoder_prompt, multi_modal_data={"image": image} - ), - decoder_prompt=decoder_prompt_tokens, - ) - ) -sampling_params = SamplingParams( - temperature=0.0, - max_tokens=2048, -) - -llm = LLM( - model=model_id, - max_num_seqs=8, - hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, -) - -# Batch Inference -outputs = llm.generate( - prompts=enc_dec_prompt, - sampling_params=sampling_params, -) - -print("------" * 8) - -for i in range(3): - print(f"Decoder prompt: {questions[i]}") - print(f"Generated text: {outputs[i].outputs[0].text}") - - print("------" * 8) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 63dbd0966266..915936719b17 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -25,11 +25,12 @@ class ModelRequestData(NamedTuple): def run_donut(): engine_args = EngineArgs( model="naver-clova-ix/donut-base-finetuned-docvqa", - max_model_len=2048, max_num_seqs=2, limit_mm_per_prompt={"image": 1}, dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, ) + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, # and the patch_size is 4 x 4. # Therefore, the initial number of patches is: @@ -48,10 +49,10 @@ def run_donut(): prompts = [ { "encoder_prompt": { - "prompt": ["$"] * 4799, + "prompt": "".join(["$"] * 4799), "multi_modal_data": { "image": fetch_image( - "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/images/sf-districts.png" + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" ) # noqa: E501 }, }, From f9d6c94226653ee4375251f0ed8cbdf0a6b6698a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 14:21:34 +0800 Subject: [PATCH 08/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/dolphin.py | 6 ++---- examples/offline_inference/encoder_decoder_multimodal.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py index d2ba27cd1e02..28103fbc07ed 100644 --- a/examples/offline_inference/dolphin.py +++ b/examples/offline_inference/dolphin.py @@ -13,8 +13,8 @@ from transformers import DonutProcessor from vllm import LLM, SamplingParams +from vllm.assets.base import get_vllm_public_assets from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt -from vllm.multimodal.utils import fetch_image # Copied from https://github.com/bytedance/Dolphin/utils/utils.py @@ -233,9 +233,7 @@ def parse_layout_string(bbox_str): raise FileNotFoundError(f"Error: File not found at {args.image_path}") image = Image.open(args.image_path).convert("RGB") else: - image = fetch_image( - "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" - ) + image = get_vllm_public_assets("ocr_test_images", "schedule").pil_image prompt = "Parse the reading order of this document. " diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 915936719b17..3b29db621705 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -12,8 +12,8 @@ from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset +from vllm.assets.base import get_vllm_public_assets from vllm.assets.image import ImageAsset -from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -51,9 +51,9 @@ def run_donut(): "encoder_prompt": { "prompt": "".join(["$"] * 4799), "multi_modal_data": { - "image": fetch_image( - "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" - ) # noqa: E501 + "image": get_vllm_public_assets( + "ocr_test_images", "schedule" + ).pil_image, }, }, "decoder_prompt": "What time is the coffee break?", # noqa: E501 From 2b4d120ae419d5cec2689ad14722fc6082219c90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 14:24:37 +0800 Subject: [PATCH 09/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/dolphin.py | 2 +- examples/offline_inference/encoder_decoder_multimodal.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py index 28103fbc07ed..8630f5f3927d 100644 --- a/examples/offline_inference/dolphin.py +++ b/examples/offline_inference/dolphin.py @@ -233,7 +233,7 @@ def parse_layout_string(bbox_str): raise FileNotFoundError(f"Error: File not found at {args.image_path}") image = Image.open(args.image_path).convert("RGB") else: - image = get_vllm_public_assets("ocr_test_images", "schedule").pil_image + image = get_vllm_public_assets("schedule.jpg", "ocr_test_images").pil_image prompt = "Parse the reading order of this document. " diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 3b29db621705..e8922abb9956 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -52,7 +52,7 @@ def run_donut(): "prompt": "".join(["$"] * 4799), "multi_modal_data": { "image": get_vllm_public_assets( - "ocr_test_images", "schedule" + "schedule.jpg", "ocr_test_images" ).pil_image, }, }, From 31a0d5e85cf157405ccefb838a410a393a4af300 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 14:27:40 +0800 Subject: [PATCH 10/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/dolphin.py | 6 ++++-- examples/offline_inference/encoder_decoder_multimodal.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py index 8630f5f3927d..d2ba27cd1e02 100644 --- a/examples/offline_inference/dolphin.py +++ b/examples/offline_inference/dolphin.py @@ -13,8 +13,8 @@ from transformers import DonutProcessor from vllm import LLM, SamplingParams -from vllm.assets.base import get_vllm_public_assets from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image # Copied from https://github.com/bytedance/Dolphin/utils/utils.py @@ -233,7 +233,9 @@ def parse_layout_string(bbox_str): raise FileNotFoundError(f"Error: File not found at {args.image_path}") image = Image.open(args.image_path).convert("RGB") else: - image = get_vllm_public_assets("schedule.jpg", "ocr_test_images").pil_image + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) prompt = "Parse the reading order of this document. " diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index e8922abb9956..915936719b17 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -12,8 +12,8 @@ from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset -from vllm.assets.base import get_vllm_public_assets from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -51,9 +51,9 @@ def run_donut(): "encoder_prompt": { "prompt": "".join(["$"] * 4799), "multi_modal_data": { - "image": get_vllm_public_assets( - "schedule.jpg", "ocr_test_images" - ).pil_image, + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 }, }, "decoder_prompt": "What time is the coffee break?", # noqa: E501 From 120a9d4baccca558b1ae6f45db781d6b32c76ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Fri, 22 Aug 2025 14:45:26 +0800 Subject: [PATCH 11/14] xxx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- tests/models/multimodal/processing/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0fdc182b9ee9..ebb100b75dfa 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -319,6 +319,7 @@ def _test_processing_correctness_one( "openai/whisper-large-v3", "omni-research/Tarsier-7b", "omni-research/Tarsier2-Recap-7b", + "naver-clova-ix/donut-base-finetuned-docvqa", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) From 59b15845fcdb73da045b6d8ca5465d6518072714 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Sun, 24 Aug 2025 15:21:17 +0800 Subject: [PATCH 12/14] fix unit test error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- tests/models/multimodal/processing/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index b10a88ce824e..aab1b0fffab2 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -166,6 +166,7 @@ def _test_processing_correctness( "paligemma": False, "ultravox": False, "whisper": False, + "vision-encoder-decoder": False, } _IGNORE_MM_KEYS = { From 828fe796c29341abea57289e27884c5b2a389fd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Sun, 24 Aug 2025 15:29:43 +0800 Subject: [PATCH 13/14] fix unit test error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- tests/models/multimodal/processing/test_common.py | 2 +- tests/models/registry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index aab1b0fffab2..5b8d9d96d0be 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -166,7 +166,7 @@ def _test_processing_correctness( "paligemma": False, "ultravox": False, "whisper": False, - "vision-encoder-decoder": False, + "donut": False, } _IGNORE_MM_KEYS = { diff --git a/tests/models/registry.py b/tests/models/registry.py index 3ac701445c8a..06a977d88f11 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -523,7 +523,7 @@ def check_available_online( "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 - hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 From 9f4627adba75c3f8217cd9558d58eb007c74e38c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Sun, 24 Aug 2025 18:36:30 +0800 Subject: [PATCH 14/14] keep the placement in alphabetical order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- examples/offline_inference/encoder_decoder_multimodal.py | 2 +- tests/models/multimodal/processing/test_common.py | 4 ++-- tests/models/registry.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 915936719b17..655f9f3fce7a 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -163,10 +163,10 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, - "donut": run_donut, } diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 5b8d9d96d0be..a604d11f0e76 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -160,13 +160,13 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, "ovis2_5": False, "paligemma": False, "ultravox": False, "whisper": False, - "donut": False, } _IGNORE_MM_KEYS = { @@ -271,6 +271,7 @@ def _test_processing_correctness_one( "facebook/chameleon-7b", "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", @@ -326,7 +327,6 @@ def _test_processing_correctness_one( "omni-research/Tarsier-7b", "omni-research/Tarsier2-Recap-7b", "mistralai/Voxtral-Mini-3B-2507", - "naver-clova-ix/donut-base-finetuned-docvqa", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 06a977d88f11..f10c552b894b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -515,6 +515,9 @@ def check_available_online( is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 @@ -522,9 +525,6 @@ def check_available_online( trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 - "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 - hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 - extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 }