Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions docling/datamodel/object_detection_engine_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Engine option helpers for object-detection runtimes."""

from __future__ import annotations

from typing import List, Literal

from pydantic import Field

from docling.models.inference_engines.object_detection.base import (
BaseObjectDetectionEngineOptions,
ObjectDetectionEngineType,
)


class OnnxRuntimeObjectDetectionEngineOptions(BaseObjectDetectionEngineOptions):
"""Runtime configuration for ONNX Runtime based object-detection models.

Preprocessing parameters come from HuggingFace preprocessor configs,
not from these options.
"""

engine_type: Literal[ObjectDetectionEngineType.ONNXRUNTIME] = (
ObjectDetectionEngineType.ONNXRUNTIME
)

model_filename: str = Field(
default="model.onnx",
description="Filename of the ONNX export inside the model repository",
)

providers: List[str] = Field(
default_factory=lambda: ["CPUExecutionProvider"],
description="Ordered list of ONNX Runtime execution providers to try",
)


class TransformersObjectDetectionEngineOptions(BaseObjectDetectionEngineOptions):
"""Runtime configuration for Transformers-based object-detection models."""

engine_type: Literal[ObjectDetectionEngineType.TRANSFORMERS] = (
ObjectDetectionEngineType.TRANSFORMERS
)

torch_dtype: str | None = Field(
default=None,
description="PyTorch dtype for model inference (e.g., 'float32', 'float16', 'bfloat16')",
)
45 changes: 42 additions & 3 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
)
from typing_extensions import deprecated

from docling.datamodel import asr_model_specs, stage_model_specs, vlm_model_specs
from docling.datamodel import (
asr_model_specs,
stage_model_specs,
vlm_model_specs,
)

# Import the following for backwards compatibility
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
Expand All @@ -26,16 +30,19 @@
DOCLING_LAYOUT_V2,
LayoutModelConfig,
)
from docling.datamodel.pipeline_options_asr_model import (
InlineAsrOptions,
from docling.datamodel.object_detection_engine_options import (
BaseObjectDetectionEngineOptions,
)
from docling.datamodel.pipeline_options_asr_model import InlineAsrOptions
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
InlineVlmOptions,
ResponseFormat,
)
from docling.datamodel.stage_model_specs import (
ObjectDetectionModelSpec,
ObjectDetectionStagePresetMixin,
StagePresetMixin,
VlmModelSpec,
)
Expand Down Expand Up @@ -1094,6 +1101,38 @@ class LayoutOptions(BaseLayoutOptions):
] = DOCLING_LAYOUT_HERON


class LayoutObjectDetectionOptions(ObjectDetectionStagePresetMixin, BaseLayoutOptions):
"""Options for layout detection using object-detection runtimes."""

kind: ClassVar[str] = "layout_object_detection"

create_orphan_clusters: Annotated[
bool,
Field(
description=(
"Create clusters for orphaned elements not assigned to any structure. When True, isolated text or "
"elements are grouped into their own clusters. Recommended for complete document coverage."
)
),
] = False

model_spec: ObjectDetectionModelSpec = Field(
default_factory=lambda: stage_model_specs.OBJECT_DETECTION_LAYOUT_HERON.model_spec.model_copy(
deep=True
),
description="Object-detection model specification for layout analysis",
)

engine_options: BaseObjectDetectionEngineOptions = Field(
description="Runtime configuration for the object-detection engine",
)


LayoutObjectDetectionOptions.register_preset(
stage_model_specs.OBJECT_DETECTION_LAYOUT_HERON
)


class AsrPipelineOptions(PipelineOptions):
"""Configuration options for the Automatic Speech Recognition (ASR) pipeline.

Expand Down
204 changes: 203 additions & 1 deletion docling/datamodel/stage_model_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

import logging
from typing import Any, ClassVar, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set

from pydantic import BaseModel, Field

Expand All @@ -17,8 +17,16 @@
TransformersPromptStyle,
)
from docling.datamodel.vlm_engine_options import BaseVlmEngineOptions
from docling.models.inference_engines.object_detection.base import (
ObjectDetectionEngineType,
)
from docling.models.inference_engines.vlm.base import VlmEngineType

if TYPE_CHECKING:
from docling.datamodel.object_detection_engine_options import (
BaseObjectDetectionEngineOptions,
)

_log = logging.getLogger(__name__)


Expand Down Expand Up @@ -292,6 +300,75 @@ def has_explicit_engine_export(self, engine_type: VlmEngineType) -> bool:
return False


# =============================================================================
# OBJECT DETECTION MODEL SPECIFICATION
# =============================================================================


class ObjectDetectionModelSpec(BaseModel):
"""Specification for an object detection model.

Simpler than VlmModelSpec - no prompts, no preprocessing params.
Preprocessing comes from HuggingFace preprocessor configs.
Model files are assumed to be at the root of the HuggingFace repo.
"""

name: str = Field(description="Human-readable model name")

repo_id: str = Field(description="Default HuggingFace repository ID")

revision: str = Field(default="main", description="Default model revision")

engine_overrides: Dict["ObjectDetectionEngineType", EngineModelConfig] = Field(
default_factory=dict,
description="Engine-specific configuration overrides",
)

def get_engine_config(
self, engine_type: "ObjectDetectionEngineType"
) -> EngineModelConfig:
"""Get EngineModelConfig for a specific object-detection engine.

Args:
engine_type: The engine type being requested

Returns:
EngineModelConfig populated with repo/revision and engine overrides
"""
override = self.engine_overrides.get(engine_type)
if override is not None:
return override.merge_with(self.repo_id, self.revision)
return EngineModelConfig(repo_id=self.repo_id, revision=self.revision)

def get_repo_id(self, engine_type: "ObjectDetectionEngineType") -> str:
"""Get repository ID for specific engine.

Args:
engine_type: The engine type

Returns:
Repository ID (with engine override if applicable)
"""
override = self.engine_overrides.get(engine_type)
if override and override.repo_id:
return override.repo_id
return self.repo_id

def get_revision(self, engine_type: "ObjectDetectionEngineType") -> str:
"""Get revision for specific engine.

Args:
engine_type: The engine type

Returns:
Model revision (with engine override if applicable)
"""
override = self.engine_overrides.get(engine_type)
if override and override.revision:
return override.revision
return self.revision


# =============================================================================
# STAGE PRESET SYSTEM
# =============================================================================
Expand Down Expand Up @@ -502,6 +579,108 @@ def from_preset(
return instance


class ObjectDetectionStagePreset(BaseModel):
"""Preset definition for object detection-powered stages."""

preset_id: str = Field(description="Preset identifier")
name: str = Field(description="Human-readable preset name")
description: str = Field(description="Description of this preset")
model_spec: ObjectDetectionModelSpec = Field(
description="Object detection model specification"
)
default_engine_type: ObjectDetectionEngineType = Field(
default=ObjectDetectionEngineType.ONNXRUNTIME,
description="Default inference engine to use",
)
stage_options: Dict[str, Any] = Field(
default_factory=dict, description="Additional stage-specific defaults"
)


class ObjectDetectionStagePresetMixin:
"""Mixin to enable preset loading for object detection stages."""

_presets: ClassVar[Dict[str, ObjectDetectionStagePreset]]

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._presets = {}

@classmethod
def register_preset(cls, preset: ObjectDetectionStagePreset) -> None:
if preset.preset_id not in cls._presets:
cls._presets[preset.preset_id] = preset
else:
_log.error(
f"Preset '{preset.preset_id}' already registered for {cls.__name__}"
)

@classmethod
def get_preset(cls, preset_id: str) -> ObjectDetectionStagePreset:
if preset_id not in cls._presets:
raise KeyError(
f"Preset '{preset_id}' not found for {cls.__name__}. "
f"Available presets: {list(cls._presets.keys())}"
)
return cls._presets[preset_id]

@classmethod
def list_presets(cls) -> List[ObjectDetectionStagePreset]:
return list(cls._presets.values())

@classmethod
def list_preset_ids(cls) -> List[str]:
return list(cls._presets.keys())

@classmethod
def get_preset_info(cls) -> List[Dict[str, str]]:
return [
{
"preset_id": p.preset_id,
"name": p.name,
"description": p.description,
"model": p.model_spec.name,
"default_engine": p.default_engine_type.value,
}
for p in cls._presets.values()
]

@classmethod
def from_preset(
cls,
preset_id: str,
engine_options: Optional["BaseObjectDetectionEngineOptions"] = None,
**overrides: Any,
):
from docling.datamodel.object_detection_engine_options import (
OnnxRuntimeObjectDetectionEngineOptions,
TransformersObjectDetectionEngineOptions,
)

preset = cls.get_preset(preset_id)

if engine_options is None:
if preset.default_engine_type == ObjectDetectionEngineType.ONNXRUNTIME:
engine_options = OnnxRuntimeObjectDetectionEngineOptions()
elif preset.default_engine_type == ObjectDetectionEngineType.TRANSFORMERS:
engine_options = TransformersObjectDetectionEngineOptions()
else:
raise ValueError(
f"Unsupported engine type {preset.default_engine_type} for presets"
)

instance = cls( # type: ignore[call-arg]
model_spec=preset.model_spec,
engine_options=engine_options,
**preset.stage_options,
)

for key, value in overrides.items():
setattr(instance, key, value)

return instance


# =============================================================================
# PRESET DEFINITIONS
# =============================================================================
Expand Down Expand Up @@ -573,6 +752,29 @@ def from_preset(
},
}

# -----------------------------------------------------------------------------
# OBJECT DETECTION PRESETS
# -----------------------------------------------------------------------------

OBJECT_DETECTION_LAYOUT_HERON = ObjectDetectionStagePreset(
preset_id="layout_heron_default",
name="Layout Heron",
description="RT-DETR layout-heron model (ResNet50)",
model_spec=ObjectDetectionModelSpec(
name="layout_heron",
repo_id="docling-project/docling-layout-heron",
revision="main",
engine_overrides={
ObjectDetectionEngineType.ONNXRUNTIME: EngineModelConfig(
repo_id="docling-project/docling-layout-heron-onnx",
extra_config={"model_filename": "model.onnx"},
)
},
),
default_engine_type=ObjectDetectionEngineType.TRANSFORMERS,
)


# -----------------------------------------------------------------------------
# VLM_CONVERT PRESETS (for full page conversion)
# -----------------------------------------------------------------------------
Expand Down
21 changes: 21 additions & 0 deletions docling/models/inference_engines/object_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Object detection inference engines."""

from docling.models.inference_engines.object_detection.base import (
BaseObjectDetectionEngine,
BaseObjectDetectionEngineOptions,
ObjectDetectionEngineInput,
ObjectDetectionEngineOutput,
ObjectDetectionEngineType,
)
from docling.models.inference_engines.object_detection.factory import (
create_object_detection_engine,
)

__all__ = [
"BaseObjectDetectionEngine",
"BaseObjectDetectionEngineOptions",
"ObjectDetectionEngineInput",
"ObjectDetectionEngineOutput",
"ObjectDetectionEngineType",
"create_object_detection_engine",
]
Loading