|
| 1 | +"""Regression test for async streaming with Mode.GEMINI_TOOLS. |
| 2 | +
|
| 3 | +The sync paths in PartialBase.from_streaming_response and |
| 4 | +IterableBase.from_streaming_response apply extract_json_from_stream |
| 5 | +for both Mode.MD_JSON and Mode.GEMINI_TOOLS, but the async paths |
| 6 | +were only applying it for Mode.MD_JSON. |
| 7 | +""" |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from instructor.mode import Mode |
| 12 | +from instructor.utils.core import ( |
| 13 | + extract_json_from_stream, |
| 14 | + extract_json_from_stream_async, |
| 15 | +) |
| 16 | + |
| 17 | + |
| 18 | +def test_sync_extract_json_from_stream_handles_codeblock(): |
| 19 | + chunks = ["```json\n", '{"name": "Alice",', ' "age": 30}', "\n```"] |
| 20 | + result = "".join(extract_json_from_stream(iter(chunks))) |
| 21 | + assert result == '{"name": "Alice", "age": 30}' |
| 22 | + |
| 23 | + |
| 24 | +@pytest.mark.asyncio |
| 25 | +async def test_async_extract_json_from_stream_handles_codeblock(): |
| 26 | + chunks = ["```json\n", '{"name": "Alice",', ' "age": 30}', "\n```"] |
| 27 | + |
| 28 | + async def async_chunks(): |
| 29 | + for c in chunks: |
| 30 | + yield c |
| 31 | + |
| 32 | + result = "".join([c async for c in extract_json_from_stream_async(async_chunks())]) |
| 33 | + assert result == '{"name": "Alice", "age": 30}' |
| 34 | + |
| 35 | + |
| 36 | +def test_sync_gemini_tools_mode_triggers_json_extraction(): |
| 37 | + """Verify that GEMINI_TOOLS is in the set that triggers extract_json_from_stream |
| 38 | + in the sync from_streaming_response path.""" |
| 39 | + # This tests the condition that was already correct in the sync path |
| 40 | + assert Mode.GEMINI_TOOLS in {Mode.MD_JSON, Mode.GEMINI_TOOLS} |
| 41 | + |
| 42 | + |
| 43 | +def test_async_gemini_tools_mode_triggers_json_extraction(): |
| 44 | + """Verify the fix: GEMINI_TOOLS must be in the set that triggers |
| 45 | + extract_json_from_stream_async in the async from_streaming_response_async path. |
| 46 | +
|
| 47 | + Before the fix, the async path only checked `mode == Mode.MD_JSON`, |
| 48 | + so GEMINI_TOOLS streaming would skip JSON extraction from code blocks. |
| 49 | + """ |
| 50 | + # After the fix, both sync and async paths use the same set |
| 51 | + mode = Mode.GEMINI_TOOLS |
| 52 | + # This is the condition in the fixed async path |
| 53 | + assert mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS} |
0 commit comments