Skip to content

Commit 2b3530e

Browse files
committed
feat: add programmatic Python API with custom prompt support
Add a type-safe, composable Python API for running the OlmoCR pipeline
programmatically, eliminating the need for subprocess calls.

Key features:
- New PipelineConfig dataclass with 25+ configuration options
- run_pipeline() async function as the main programmatic entry point
- Custom prompt support via the custom_prompt parameter
- Full backward compatibility — CLI unchanged, delegates to shared impl
- All existing features maintained (retries, batching, server management)

Example usage:
```python
import asyncio
from olmocr import run_pipeline, PipelineConfig

config = PipelineConfig(
    workspace="./workspace",
    pdfs=["doc1.pdf", "doc2.pdf"],
    custom_prompt="Extract text from this legal document...",
    markdown=True,
    workers=10,
)
asyncio.run(run_pipeline(config))
```

Changes:
- Created olmocr/config.py with the PipelineConfig dataclass
- Extracted _main_impl() from main() to share logic between CLI and API
- Added run_pipeline() as the programmatic entry point
- Added _config_to_args() helper to convert config to argparse.Namespace
- Added custom_prompt parameter to build_page_query()
- Threaded the custom prompt through the process_page() call stack
- Updated __init__.py to export PipelineConfig and run_pipeline
- Updated test mocks to accept the custom_prompt parameter

Backward compatibility:
- CLI interface unchanged — main() delegates to _main_impl()
- All default behaviors preserved
- All existing flags and options work identically
- Custom prompt optional — defaults to the original prompt if not provided
1 parent f3198d2 commit 2b3530e

File tree

4 files changed

+283
-68
lines changed

4 files changed

+283
-68
lines changed

olmocr/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
"""Public API for the olmocr package."""

from .config import PipelineConfig
from .pipeline import run_pipeline
from .version import VERSION, VERSION_SHORT


# VERSION and VERSION_SHORT were star-importable before __all__ was introduced;
# keep them listed so `from olmocr import *` remains backward compatible.
__all__ = ["PipelineConfig", "run_pipeline", "VERSION", "VERSION_SHORT"]

olmocr/config.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""Configuration dataclass for the OlmoCR pipeline."""
2+
3+
from dataclasses import dataclass, field
4+
from typing import Optional
5+
6+
7+
@dataclass
8+
class PipelineConfig:
9+
"""
10+
Configuration for the OlmoCR pipeline.
11+
12+
This provides a type-safe, programmatic interface to configure the pipeline
13+
instead of using CLI arguments and subprocess calls.
14+
15+
Example:
16+
>>> import asyncio
17+
>>> from olmocr import run_pipeline, PipelineConfig
18+
>>>
19+
>>> config = PipelineConfig(
20+
... workspace="./workspace",
21+
... pdfs=["doc1.pdf", "doc2.pdf"],
22+
... custom_prompt="Extract text from this legal document...",
23+
... markdown=True
24+
... )
25+
>>> asyncio.run(run_pipeline(config))
26+
"""
27+
28+
# Required arguments
29+
workspace: str
30+
"""Filesystem path where work will be stored (local folder or s3://bucket/prefix/)"""
31+
32+
# PDF inputs
33+
pdfs: Optional[list[str]] = None
34+
"""List of PDF paths (can include globs like s3://bucket/*.pdf)"""
35+
36+
# Model configuration
37+
model: str = "allenai/olmOCR-2-7B-1025-FP8"
38+
"""Path to model (local, s3, or HuggingFace)"""
39+
40+
custom_prompt: Optional[str] = None
41+
"""
42+
Custom prompt to use for OCR extraction. If None, uses the default prompt.
43+
The prompt should instruct the model how to extract text from document pages.
44+
"""
45+
46+
# Server configuration
47+
server: Optional[str] = None
48+
"""URL of external vLLM server (e.g., http://hostname:port/v1). If None, spawns local server."""
49+
50+
api_key: Optional[str] = None
51+
"""API key for authenticated remote servers (e.g., DeepInfra)"""
52+
53+
# Processing options
54+
workers: int = 20
55+
"""Number of concurrent workers"""
56+
57+
max_page_retries: int = 8
58+
"""Max number of times to retry rendering a page"""
59+
60+
max_page_error_rate: float = 0.004
61+
"""Rate of allowable failed pages in a document (default: 1/250)"""
62+
63+
target_longest_image_dim: int = 1288
64+
"""Dimension on longest side for rendering PDF pages"""
65+
66+
target_anchor_text_len: int = -1
67+
"""Maximum amount of anchor text to use (characters), -1 for new models"""
68+
69+
apply_filter: bool = False
70+
"""Apply basic filtering to English PDFs (not forms, not SEO spam)"""
71+
72+
markdown: bool = False
73+
"""Also write natural text to markdown files preserving folder structure"""
74+
75+
guided_decoding: bool = False
76+
"""Enable guided decoding for model YAML type outputs"""
77+
78+
stats: bool = False
79+
"""Instead of running pipeline, report statistics about workspace"""
80+
81+
# S3 profiles
82+
workspace_profile: Optional[str] = None
83+
"""S3 configuration profile for accessing the workspace"""
84+
85+
pdf_profile: Optional[str] = None
86+
"""S3 configuration profile for accessing raw PDF documents"""
87+
88+
pages_per_group: Optional[int] = None
89+
"""Number of PDF pages per work item group (auto-calculated if None)"""
90+
91+
# VLLM configuration
92+
gpu_memory_utilization: Optional[float] = None
93+
"""Fraction of VRAM vLLM may pre-allocate for KV-cache"""
94+
95+
max_model_len: int = 16384
96+
"""Upper bound (tokens) vLLM will allocate KV-cache for"""
97+
98+
tensor_parallel_size: int = 1
99+
"""Tensor parallel size for vLLM"""
100+
101+
data_parallel_size: int = 1
102+
"""Data parallel size for vLLM"""
103+
104+
port: int = 30024
105+
"""Port to use for the VLLM server"""
106+
107+
# Beaker/cluster execution (usually not needed for programmatic use)
108+
beaker: bool = False
109+
"""Submit this job to Beaker instead of running locally"""
110+
111+
beaker_workspace: str = "ai2/olmocr"
112+
"""Beaker workspace to submit to"""
113+
114+
beaker_cluster: list[str] = field(
115+
default_factory=lambda: ["ai2/jupiter", "ai2/ceres", "ai2/neptune", "ai2/saturn"]
116+
)
117+
"""Beaker clusters to run on"""
118+
119+
beaker_gpus: int = 1
120+
"""Number of GPU replicas to run"""
121+
122+
beaker_priority: str = "normal"
123+
"""Beaker priority level for the job"""
124+
125+
def __post_init__(self):
126+
"""Validate configuration after initialization."""
127+
# Auto-calculate pages_per_group if not set
128+
if self.pages_per_group is None:
129+
# Use smaller groups for external APIs to avoid wasting money
130+
self.pages_per_group = 50 if self.api_key is not None else 500
131+
132+
# Validate workspace
133+
if not self.workspace:
134+
raise ValueError("workspace is required")
135+
136+
# Validate that we have PDFs to process (unless running stats)
137+
if not self.stats and not self.pdfs:
138+
raise ValueError("pdfs list is required when not running stats")

0 commit comments

Comments
 (0)