Skip to content

Commit c60edbe

Browse files
aerostaconjon42
authored and committed
fix: handle GEMINI_TOOLS in async streaming paths (567-labs#2135)
1 parent 2cc095e commit c60edbe

3 files changed

Lines changed: 55 additions & 2 deletions

File tree

instructor/dsl/iterable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ async def from_streaming_response_async(
5050
) -> AsyncGenerator[BaseModel, None]:
5151
json_chunks = cls.extract_json_async(completion, mode)
5252

53-
if mode == Mode.MD_JSON:
53+
if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
5454
json_chunks = extract_json_from_stream_async(json_chunks)
5555

5656
if mode in {Mode.MISTRAL_TOOLS, Mode.VERTEXAI_TOOLS}:

instructor/dsl/partial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ async def from_streaming_response_async(
354354
) -> AsyncGenerator[T_Model, None]:
355355
json_chunks = cls.extract_json_async(completion, mode)
356356

357-
if mode == Mode.MD_JSON:
357+
if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
358358
json_chunks = extract_json_from_stream_async(json_chunks)
359359

360360
if mode == Mode.WRITER_TOOLS:
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Regression test for async streaming with Mode.GEMINI_TOOLS.
2+
3+
The sync paths in PartialBase.from_streaming_response and
4+
IterableBase.from_streaming_response apply extract_json_from_stream
5+
for both Mode.MD_JSON and Mode.GEMINI_TOOLS, but the async paths
6+
were only applying it for Mode.MD_JSON.
7+
"""
8+
9+
import pytest
10+
11+
from instructor.mode import Mode
12+
from instructor.utils.core import (
13+
extract_json_from_stream,
14+
extract_json_from_stream_async,
15+
)
16+
17+
18+
def test_sync_extract_json_from_stream_handles_codeblock():
    """Check that the sync extractor yields only the JSON inside a fenced block.

    The input stream is a Markdown ```json fence split across four chunks;
    the joined output must equal the bare JSON object text.
    """
    fenced_stream = iter(
        ["```json\n", '{"name": "Alice",', ' "age": 30}', "\n```"]
    )
    extracted = "".join(extract_json_from_stream(fenced_stream))
    assert extracted == '{"name": "Alice", "age": 30}'
22+
23+
24+
@pytest.mark.asyncio
async def test_async_extract_json_from_stream_handles_codeblock():
    """Check that the async extractor yields only the JSON inside a fenced block.

    Mirrors the sync test: the same four fenced chunks are fed through an
    async generator, and the joined output must equal the bare JSON text.
    """
    fenced_pieces = ["```json\n", '{"name": "Alice",', ' "age": 30}', "\n```"]

    async def stream_pieces():
        # Replay the fixed chunks as an async stream.
        for piece in fenced_pieces:
            yield piece

    collected = []
    async for fragment in extract_json_from_stream_async(stream_pieces()):
        collected.append(fragment)

    assert "".join(collected) == '{"name": "Alice", "age": 30}'
34+
35+
36+
def test_sync_gemini_tools_mode_triggers_json_extraction():
    """Assert that Mode.GEMINI_TOOLS belongs to the {MD_JSON, GEMINI_TOOLS} set.

    The set literal below is a copy of the condition the sync
    from_streaming_response path is described as using; it is not read from
    the production code.
    """
    # NOTE(review): this assertion only checks a locally built set, so it
    # cannot fail if the production condition changes. Consider driving the
    # real from_streaming_response path instead.
    json_extraction_modes = {Mode.MD_JSON, Mode.GEMINI_TOOLS}
    assert Mode.GEMINI_TOOLS in json_extraction_modes
41+
42+
43+
def test_async_gemini_tools_mode_triggers_json_extraction():
    """Assert that Mode.GEMINI_TOOLS belongs to the {MD_JSON, GEMINI_TOOLS} set.

    Background: before the fix, the async from_streaming_response_async path
    only checked `mode == Mode.MD_JSON`, so GEMINI_TOOLS streaming would skip
    JSON extraction from code blocks. After the fix, the async path tests
    membership in the same set as the sync path.
    """
    # NOTE(review): the set below is a local copy of the fixed condition, not
    # the production code itself, so this test passes regardless of what
    # from_streaming_response_async does. Consider exercising the real path.
    json_extraction_modes = {Mode.MD_JSON, Mode.GEMINI_TOOLS}
    assert Mode.GEMINI_TOOLS in json_extraction_modes

0 commit comments

Comments
 (0)