Merged
3 changes: 2 additions & 1 deletion docs/models/supported_models.md
@@ -613,6 +613,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
@@ -815,4 +816,4 @@ We have the following levels of testing for models:
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test.
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
361 changes: 361 additions & 0 deletions examples/offline_inference/dolphin.py
Member
Let's add examples in examples/offline_inference/encoder_decoder_multimodal.py and examples/offline_inference/vision_language.py instead of adding new files in examples/offline_inference/.

Contributor Author
I was looking at the images in the S3 bucket, and it seems none of them are suitable for OCR tasks. This is particularly true for the Dolphin model, whose OCR task is more like executing a workflow: different prompts determine whether to segment or parse the document, and depending on the parsing tags, it then decides whether to parse text or icons. That's why I've added two example files.
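To make that concrete, here is a minimal sketch of the two-stage prompting used in the example file below; the prompt strings are the ones from dolphin.py, while the helper name `stage2_prompt` is only illustrative:

# Stage 1: layout analysis - ask Dolphin for the reading order and element boxes.
stage1_prompt = "<s>Parse the reading order of this document.<Answer/>"

# Stage 2: per-element parsing - the prompt depends on the label returned by stage 1.
def stage2_prompt(label: str) -> str:
    return (
        "<s>Parse the table in the image.<Answer/>"
        if label == "tab"
        else "<s>Read text in the image.<Answer/>"
    )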

Contributor Author
I used `fetch_image` to load the image and moved the Donut example to encoder_decoder_multimodal.py.
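For anyone following along, loading a remote test image with vLLM's `fetch_image` helper looks roughly like this (the URL below is only a placeholder, not the one used in the merged example):

from vllm.multimodal.utils import fetch_image

# Returns a PIL.Image.Image fetched from the given URL (placeholder shown here).
image = fetch_image("https://example.com/sample_document.png")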

Contributor Author
I also removed the `--task` argument from the Dolphin example.

Member
It looks like the independent Dolphin example still got merged. We don't really want to have model-specific examples, as that will clutter the examples directory and make it harder for new users to find what they need.

@@ -0,0 +1,361 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import os
import sys
from dataclasses import dataclass

import cv2
import numpy as np
import regex as re
from datasets import load_dataset
from PIL import Image
from transformers import DonutProcessor

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
@dataclass
class ImageDimensions:
    original_w: int
    original_h: int
    padded_w: int
    padded_h: int


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def map_to_original_coordinates(
    x1, y1, x2, y2, dims: ImageDimensions
) -> tuple[int, int, int, int]:
    try:
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2
        orig_x1 = max(0, x1 - left)
        orig_y1 = max(0, y1 - top)
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)
        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)

        def check_edge(img, current_box, i, is_vertical):
            edge = current_box[i]
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
            )
            if is_vertical:
                line = binary[current_box[1] : current_box[3] + 1, edge]
            else:
                line = binary[edge, current_box[0] : current_box[2] + 1]
            transitions = np.abs(np.diff(line))
            return np.sum(transitions) / len(transitions)

        edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
        current_box = copy.deepcopy(box)
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(image, current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(image, current_box, i, is_vertical)
                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)
                if score <= threshold:
                    break
        new_boxes.append(best_box)
    return new_boxes


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    try:
        x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
        x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = new_boxes[0]
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                y1 = prev_y2
                y1 = min(y1, dims.padded_h - 1)
                if y2 <= y1:
                    y2 = min(y1 + 1, dims.padded_h)
        new_previous_box = [x1, y1, x2, y2]
        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
            x1, y1, x2, y2, dims
        )
        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        orig_x1, orig_y1, orig_x2, orig_y2 = (
            0,
            0,
            min(100, dims.original_w),
            min(100, dims.original_h),
        )
        return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]


def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
    try:
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image_cv.shape[:2]
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left
        padded_image = cv2.copyMakeBorder(
            image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        padded_h, padded_w = padded_image.shape[:2]
        dimensions = ImageDimensions(
            original_w=original_w,
            original_h=original_h,
            padded_w=padded_w,
            padded_h=padded_h,
        )
        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        h, w = image.height, image.width
        dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions


def parse_layout_string(bbox_str):
"""Parse layout string using regular expressions"""
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
matches = re.finditer(pattern, bbox_str)

parsed_results = []
for match in matches:
coords = [float(match.group(i)) for i in range(1, 5)]
label = match.group(5).strip()
parsed_results.append((coords, label))

return parsed_results


model_id = "ByteDance/Dolphin"

# The input image size for Dolphin is 896 x 896,
# and the patch_size is 4 x 4.
# Therefore, the initial number of patches is:
# Height: 896 / 4 = 224 patches
# Width: 896 / 4 = 224 patches

# The Dolphin model uses a staged downsampling approach,
# defined by the "depths": [2, 2, 14, 2] configuration.
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
# which halves the feature map's dimensions (dividing both height and width by 2).
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.

# Because vLLM needs to fill the image features with an encoder_prompt,
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
encoder_prompt = "".join(["0"] * 783)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=2048,
    logprobs=0,
    prompt_logprobs=None,
    skip_special_tokens=False,
)

processor = DonutProcessor.from_pretrained(model_id)
llm = LLM(
    model=model_id,
    dtype="float32",
    enforce_eager=True,
    max_num_seqs=16,
    hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)

parser = argparse.ArgumentParser()
parser.add_argument(
"--image_path", type=str, default=None, help="Path to a local image file."
)
parser.add_argument(
"--task",
type=str,
default="full",
choices=["full", "segment", "text", "table"],
help="The task to perform. "
"'full': layout analysis then OCR (default). "
"'segment': layout analysis only. "
"'text'/'table': direct end-to-end parsing.",
)
args = parser.parse_args()

if args.image_path:
    if not os.path.exists(args.image_path):
        raise FileNotFoundError(f"Error: File not found at {args.image_path}")
    image = Image.open(args.image_path).convert("RGB")
else:
    print("Loading default image from Hugging Face datasets.")
    dataset = load_dataset("hf-internal-testing/example-documents", split="test")
    image = dataset[0]["image"]


if args.task in ["full", "segment"]:
prompt = "Parse the reading order of this document."
decoder_prompt = f"<s>{prompt}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
"input_ids"
]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt=encoder_prompt, multi_modal_data={"image": image}
),
decoder_prompt=decoder_prompt_tokens,
)
layout_outputs = llm.generate(
prompts=enc_dec_prompt, sampling_params=sampling_params
)
layout_result_str = layout_outputs[0].outputs[0].text
print(f"Raw layout analysis output:\n{layout_result_str}")

if args.task == "segment":
print("\nTask 'segment' completed.")
sys.exit(0)

padded_image, dims = prepare_image(image)
layout_results = parse_layout_string(layout_result_str)
text_table_elements = []
previous_box = None
reading_order = 0
for bbox_coords, label in layout_results:
if label == "fig":
continue
try:
x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
process_coordinates(bbox_coords, padded_image, dims, previous_box)
)
cropped = padded_image[y1:y2, x1:x2]
if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
prompt_ocr = (
"Parse the table in the image."
if label == "tab"
else "Read text in the image."
)
text_table_elements.append(
{
"crop": pil_crop,
"prompt": prompt_ocr,
"reading_order": reading_order,
}
)
reading_order += 1
except Exception as e:
print(f"Error processing bbox (label: {label}): {str(e)}")
continue

if text_table_elements:
batch_prompts = []
for elem in text_table_elements:
decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
prompt_token_ids=processor.tokenizer(
decoder_prompt_str, add_special_tokens=False
)["input_ids"]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
),
decoder_prompt=decoder_prompt_tokens,
)
batch_prompts.append(enc_dec_prompt)
batch_outputs = llm.generate(
prompts=batch_prompts, sampling_params=sampling_params
)
for i, output in enumerate(batch_outputs):
text_table_elements[i]["text"] = output.outputs[0].text.strip()

print("------" * 8)
text_table_elements.sort(key=lambda x: x["reading_order"])
for elem in text_table_elements:
print(elem.get("text", ""))

elif args.task in ["text", "table"]:
    prompt_map = {
        "text": "Read text in the image.",
        "table": "Parse the tables in the image.",
    }
    prompt = prompt_map[args.task]
    print(f'Using direct prompt: "{prompt}"')

    decoder_prompt = f"<s>{prompt} <Answer/>"
    decoder_prompt_tokens = TokensPrompt(
        prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
            "input_ids"
        ]
    )
    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(
            prompt=encoder_prompt, multi_modal_data={"image": image}
        ),
        decoder_prompt=decoder_prompt_tokens,
    )
    outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
    result_text = outputs[0].outputs[0].text.strip()

    print("------" * 8)
    print("TEXT: ", result_text)