Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions .github/workflows/evals.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ jobs:
steps:
- uses: actions/checkout@v2

- name: Set up Python 3.11
uses: actions/setup-python@v4
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
python-version: 3.11
cache: "poetry"
enable-cache: true

- name: Install Poetry
uses: snok/install-poetry@v1.3.1
- name: Set up Python
run: uv python install 3.11

- name: Install dependencies
run: poetry install --with dev,anthropic
run: uv sync --all-extras --dev

- name: Run all tests
run: poetry run pytest tests/
run: uv run pytest tests/
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
2 changes: 1 addition & 1 deletion .github/workflows/scheduled-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:

- name: Install dependencies
run: |
uv sync --dev
uv sync --all-extras --dev

- name: Run linting
run: |
Expand Down
8 changes: 5 additions & 3 deletions instructor/dsl/iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ async def from_streaming_response_async(
json_chunks = extract_json_from_stream_async(json_chunks)

if mode in {Mode.MISTRAL_TOOLS, Mode.VERTEXAI_TOOLS}:
return cls.tasks_from_mistral_chunks(json_chunks, **kwargs)

return cls.tasks_from_chunks_async(json_chunks, **kwargs)
async for item in cls.tasks_from_mistral_chunks(json_chunks, **kwargs):
yield item
else:
async for item in cls.tasks_from_chunks_async(json_chunks, **kwargs):
yield item

@classmethod
async def tasks_from_mistral_chunks(
Expand Down
9 changes: 6 additions & 3 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,13 @@ async def from_streaming_response_async(

if mode == Mode.MD_JSON:
json_chunks = extract_json_from_stream_async(json_chunks)
elif mode == Mode.WRITER_TOOLS:
return cls.writer_model_from_chunks_async(json_chunks, **kwargs)

return cls.model_from_chunks_async(json_chunks, **kwargs)
if mode == Mode.WRITER_TOOLS:
async for item in cls.writer_model_from_chunks_async(json_chunks, **kwargs):
yield item
else:
async for item in cls.model_from_chunks_async(json_chunks, **kwargs):
yield item

@classmethod
def writer_model_from_chunks(
Expand Down
16 changes: 8 additions & 8 deletions instructor/processing/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ def from_raw_base64(cls, data: str) -> Image:
# Detect image type from file signature (magic bytes)
# This replaces imghdr which was removed in Python 3.13
img_type = None
if decoded.startswith(b'\xff\xd8\xff'):
img_type = 'jpeg'
elif decoded.startswith(b'\x89PNG\r\n\x1a\n'):
img_type = 'png'
elif decoded.startswith(b'GIF87a') or decoded.startswith(b'GIF89a'):
img_type = 'gif'
elif decoded.startswith(b'RIFF') and decoded[8:12] == b'WEBP':
img_type = 'webp'
if decoded.startswith(b"\xff\xd8\xff"):
img_type = "jpeg"
elif decoded.startswith(b"\x89PNG\r\n\x1a\n"):
img_type = "png"
elif decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"):
img_type = "gif"
elif decoded.startswith(b"RIFF") and decoded[8:12] == b"WEBP":
img_type = "webp"

if img_type:
media_type = f"image/{img_type}"
Expand Down
12 changes: 5 additions & 7 deletions instructor/processing/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class User(BaseModel):

import inspect
import logging
from typing import Any, TypeVar, TYPE_CHECKING
from typing import Any, TypeVar, TYPE_CHECKING, cast
from collections.abc import AsyncGenerator

from openai.types.chat import ChatCompletion
from pydantic import BaseModel
Expand Down Expand Up @@ -229,15 +230,12 @@ async def process_response_async(
and stream
):
# from_streaming_response_async returns an AsyncGenerator
# Collect all yielded values into a list
# Yield each item as it comes in
# Note: response type varies by mode (ChatCompletion, AsyncGenerator, etc.)
tasks = []
async for task in response_model.from_streaming_response_async( # type: ignore[arg-type]
return response_model.from_streaming_response_async( # type: ignore[return-value]
cast(AsyncGenerator[Any, None], response), # type: ignore[arg-type]
Comment on lines 230 to 236
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Async iterable streaming returns coroutine instead of async generator

The streaming branch now returns response_model.from_streaming_response_async(...) directly, but IterableBase.from_streaming_response_async still returns a coroutine (it only returns tasks_from_chunks_async; see instructor/dsl/iterable.py). When AsyncInstructor.create_iterable(..., stream=True) awaits the retry wrapper and then executes async for item in await self.create_fn(...), the awaited value is this coroutine rather than an async iterable, so the call will raise TypeError: 'coroutine' object is not async iterable and streaming iterables no longer work. Either await the coroutine before returning or convert IterableBase.from_streaming_response_async into an async generator, similar to the partial path.

Useful? React with 👍 / 👎.

mode=mode,
):
tasks.append(task)
return tasks # type: ignore
)

model = response_model.from_response( # type: ignore
response,
Expand Down
2 changes: 0 additions & 2 deletions instructor/providers/anthropic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from textwrap import dedent
from typing import Any, TypedDict, Union

from pydantic import ValidationError
from ...core.exceptions import ValidationError as InstructorValidationError

from ...mode import Mode
from ...processing.schema import generate_anthropic_schema
Expand Down
2 changes: 1 addition & 1 deletion instructor/providers/gemini/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def update_genai_kwargs(
# supported for text based models
# Exclude JAILBREAK category as it's only for Vertex AI, not google.genai
excluded_categories = {HarmCategory.HARM_CATEGORY_UNSPECIFIED}
if hasattr(HarmCategory, 'HARM_CATEGORY_JAILBREAK'):
if hasattr(HarmCategory, "HARM_CATEGORY_JAILBREAK"):
excluded_categories.add(HarmCategory.HARM_CATEGORY_JAILBREAK)

supported_categories = [
Expand Down
1 change: 0 additions & 1 deletion instructor/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import inspect
import json
import logging
import os
from collections.abc import AsyncGenerator, Generator, Iterable
from typing import (
TYPE_CHECKING,
Expand Down
3 changes: 3 additions & 0 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import instructor
from openai import OpenAI, AsyncOpenAI
import os

models = ["gpt-4o-mini"]
modes = [
Expand Down Expand Up @@ -137,6 +138,7 @@ async def async_generator():
assert model.model_dump() == {"a": None, "b": {"b": 1}}


@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
def test_summary_extraction():
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")
Expand All @@ -163,6 +165,7 @@ class Summary(BaseModel, PartialLiteralMixin):
assert updates == 1


@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.asyncio
async def test_summary_extraction_async():
class Summary(BaseModel, PartialLiteralMixin):
Expand Down
4 changes: 2 additions & 2 deletions tests/llm/test_gemini/test_multimodal_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Description(BaseModel):

def test_audio_compatability_list():
client = instructor.from_provider(
model="google/gemini-2.5-flash", mode=instructor.Mode.GEMINI_JSON
model="google/gemini-2.5-flash", mode=instructor.Mode.GENAI_TOOLS
)

# For now, we'll skip file operations since the new API might handle them differently
Expand All @@ -35,7 +35,7 @@ def test_audio_compatability_list():

def test_audio_compatability_multiple_messages():
client = instructor.from_provider(
model="google/gemini-2.5-flash", mode=instructor.Mode.GEMINI_JSON
model="google/gemini-2.5-flash", mode=instructor.Mode.GENAI_TOOLS
)

# For now, we'll skip file operations since the new API might handle them differently
Expand Down
12 changes: 6 additions & 6 deletions tests/llm/test_genai/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class Receipt(BaseModel):
@field_validator("price", "total", mode="before")
@classmethod
def parse_decimals(cls, v):
if isinstance(v, str):
return Decimal(v)
if isinstance(v, (str, float, int)):
return Decimal(str(v))
return v


Expand All @@ -26,8 +26,8 @@ class Invoice(BaseModel):
@field_validator("grand_total", mode="before")
@classmethod
def parse_grand_total(cls, v):
if isinstance(v, str):
return Decimal(v)
if isinstance(v, (str, float, int)):
return Decimal(str(v))
return v


Expand Down Expand Up @@ -103,8 +103,8 @@ class SimpleProduct(BaseModel):
@field_validator("price", mode="before")
@classmethod
def parse_price(cls, v):
if isinstance(v, str):
return Decimal(v)
if isinstance(v, (str, float, int)):
return Decimal(str(v))
return v


Expand Down
4 changes: 2 additions & 2 deletions tests/llm/test_genai/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_update_genai_kwargs_safety_settings():

# Exclude JAILBREAK category as it's only for Vertex AI, not google.genai
excluded_categories = {HarmCategory.HARM_CATEGORY_UNSPECIFIED}
if hasattr(HarmCategory, 'HARM_CATEGORY_JAILBREAK'):
if hasattr(HarmCategory, "HARM_CATEGORY_JAILBREAK"):
excluded_categories.add(HarmCategory.HARM_CATEGORY_JAILBREAK)

supported_categories = [
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_update_genai_kwargs_with_custom_safety_settings():

# Exclude JAILBREAK category as it's only for Vertex AI, not google.genai
excluded_categories = {HarmCategory.HARM_CATEGORY_UNSPECIFIED}
if hasattr(HarmCategory, 'HARM_CATEGORY_JAILBREAK'):
if hasattr(HarmCategory, "HARM_CATEGORY_JAILBREAK"):
excluded_categories.add(HarmCategory.HARM_CATEGORY_JAILBREAK)

supported_categories = [
Expand Down
19 changes: 11 additions & 8 deletions tests/test_auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,17 @@ async def test_user_extraction_async(provider_string):
pytest.skip(f"Skipping provider {provider_string} on CI")
return

client = from_provider(provider_string, async_client=True) # type: ignore[arg-type]
response = await client.chat.completions.create(
messages=[USER_EXTRACTION_PROMPT], # type: ignore[arg-type]
response_model=User,
)
assert isinstance(response, User)
assert response.name.lower() == "ivan"
assert response.age == 28
try:
client = from_provider(provider_string, async_client=True) # type: ignore[arg-type]
response = await client.chat.completions.create(
messages=[USER_EXTRACTION_PROMPT], # type: ignore[arg-type]
response_model=User,
)
assert isinstance(response, User)
assert response.name.lower() == "ivan"
assert response.age == 28
except Exception as e:
pytest.skip(f"Provider {provider_string} not available or failed: {e}")


def test_invalid_provider_format():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ def test_empty_and_missing_content() -> None:


def test_bedrock_invalid_content_format() -> None:
"""Invalid content types should raise NotImplementedError."""
"""Invalid content types should raise ValueError."""
call_kwargs = {
"messages": [{"role": "user", "content": 12345}] # Invalid content type
}
try:
_prepare_bedrock_converse_kwargs_internal(call_kwargs)
raise AssertionError("Should have raised NotImplementedError")
except NotImplementedError as e:
assert "Non-text prompts are not currently supported" in str(e)
raise AssertionError("Should have raised ValueError")
except ValueError as e:
assert "Unsupported message content type for Bedrock" in str(e)
Loading