Skip to content

Commit 88006ec

Browse files
committed
Add tests for updated extract content functionality and provider specific extraction and id mapping
1 parent 9013b01 commit 88006ec

File tree

1 file changed

+150
-2
lines changed

1 file changed

+150
-2
lines changed

tests/_server/ai/test_providers.py

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for the LLM providers in marimo._server.ai.providers."""
22

3-
from unittest.mock import AsyncMock, patch
3+
from unittest.mock import AsyncMock, MagicMock, patch
44

55
import pytest
66

@@ -56,7 +56,6 @@ def test_anyprovider_for_model(model_name: str, provider_name: str) -> None:
5656
if provider_name != "bedrock":
5757
assert config.api_key == f"{provider_name}-key"
5858
else:
59-
# bedrock overloads the api_key for profile name
6059
assert config.api_key == "profile:aws-profile"
6160

6261

@@ -172,3 +171,152 @@ async def test_azure_openai_provider() -> None:
172171
assert api_version == "2023-05-15"
173172
assert deployment_name == "gpt-4-1"
174173
assert endpoint == "https://unknown_domain.openai"
174+
175+
176+
@pytest.mark.parametrize(
177+
"provider_type",
178+
[
179+
pytest.param(OpenAIProvider, id="openai"),
180+
pytest.param(BedrockProvider, id="bedrock"),
181+
],
182+
)
183+
def test_extract_content_with_none_tool_call_ids(
184+
provider_type: type,
185+
) -> None:
186+
"""Test extract_content handles None tool_call_ids without errors."""
187+
config = AnyProviderConfig(api_key="test-key", base_url="http://test")
188+
provider = provider_type("test-model", config)
189+
190+
mock_response = MagicMock()
191+
mock_delta = MagicMock()
192+
mock_delta.content = "Hello"
193+
mock_delta.tool_calls = None
194+
mock_choice = MagicMock()
195+
mock_choice.delta = mock_delta
196+
mock_response.choices = [mock_choice]
197+
198+
result = provider.extract_content(mock_response, None)
199+
assert result == [("Hello", "text")]
200+
201+
202+
def test_google_extract_content_with_none_tool_call_ids() -> None:
203+
"""Test Google extract_content handles None tool_call_ids without errors."""
204+
config = AnyProviderConfig(api_key="test-key", base_url="http://test")
205+
provider = GoogleProvider("gemini-1.5-flash", config)
206+
207+
mock_response = MagicMock()
208+
mock_candidate = MagicMock()
209+
mock_content = MagicMock()
210+
mock_part = MagicMock()
211+
mock_part.text = "Hello"
212+
mock_part.thought = False
213+
mock_part.function_call = None
214+
mock_content.parts = [mock_part]
215+
mock_candidate.content = mock_content
216+
mock_response.candidates = [mock_candidate]
217+
218+
result = provider.extract_content(mock_response, None)
219+
assert result == [("Hello", "text")]
220+
221+
222+
def test_openai_extract_content_multiple_tool_calls() -> None:
223+
"""Test OpenAI extracts multiple tool calls correctly."""
224+
config = AnyProviderConfig(api_key="test-key", base_url="http://test")
225+
provider = OpenAIProvider("gpt-4", config)
226+
227+
mock_response = MagicMock()
228+
mock_delta = MagicMock()
229+
mock_delta.content = None
230+
231+
mock_tool_1 = MagicMock()
232+
mock_tool_1.index = 0
233+
mock_tool_1.id = "call_1"
234+
mock_tool_1.function = MagicMock()
235+
mock_tool_1.function.name = "get_weather"
236+
mock_tool_1.function.arguments = None
237+
238+
mock_tool_2 = MagicMock()
239+
mock_tool_2.index = 1
240+
mock_tool_2.id = "call_2"
241+
mock_tool_2.function = MagicMock()
242+
mock_tool_2.function.name = "get_time"
243+
mock_tool_2.function.arguments = None
244+
245+
mock_delta.tool_calls = [mock_tool_1, mock_tool_2]
246+
mock_choice = MagicMock()
247+
mock_choice.delta = mock_delta
248+
mock_response.choices = [mock_choice]
249+
250+
result = provider.extract_content(mock_response, None)
251+
assert result is not None
252+
assert len(result) == 2
253+
tool_data_0, _ = result[0]
254+
tool_data_1, _ = result[1]
255+
assert isinstance(tool_data_0, dict)
256+
assert isinstance(tool_data_1, dict)
257+
assert tool_data_0["toolName"] == "get_weather"
258+
assert tool_data_1["toolName"] == "get_time"
259+
260+
261+
def test_google_extract_content_id_rectification() -> None:
262+
"""Test Google uses provided tool_call_ids for ID rectification."""
263+
config = AnyProviderConfig(api_key="test-key", base_url="http://test")
264+
provider = GoogleProvider("gemini-1.5-flash", config)
265+
266+
mock_response = MagicMock()
267+
mock_candidate = MagicMock()
268+
mock_content = MagicMock()
269+
mock_func_call = MagicMock()
270+
mock_func_call.name = "get_weather"
271+
mock_func_call.args = {"location": "SF"}
272+
mock_func_call.id = None
273+
mock_part = MagicMock()
274+
mock_part.text = None
275+
mock_part.function_call = mock_func_call
276+
mock_content.parts = [mock_part]
277+
mock_candidate.content = mock_content
278+
mock_response.candidates = [mock_candidate]
279+
280+
result = provider.extract_content(mock_response, ["stable_id"])
281+
assert result is not None
282+
tool_data, _ = result[0]
283+
assert isinstance(tool_data, dict)
284+
assert tool_data["toolCallId"] == "stable_id"
285+
286+
287+
def test_anthropic_extract_content_tool_call_id_mapping() -> None:
288+
"""Test Anthropic maps tool call IDs via block index."""
289+
try:
290+
from anthropic.types import (
291+
InputJSONDelta,
292+
RawContentBlockDeltaEvent,
293+
RawContentBlockStartEvent,
294+
ToolUseBlock,
295+
)
296+
except ImportError:
297+
pytest.skip("Anthropic not installed")
298+
299+
config = AnyProviderConfig(api_key="test-key", base_url="http://test")
300+
provider = AnthropicProvider("claude-3-opus-20240229", config)
301+
302+
start_event = RawContentBlockStartEvent(
303+
type="content_block_start",
304+
index=0,
305+
content_block=ToolUseBlock(
306+
type="tool_use", id="toolu_123", name="get_weather", input={}
307+
),
308+
)
309+
provider.extract_content(start_event, None)
310+
311+
delta_event = RawContentBlockDeltaEvent(
312+
type="content_block_delta",
313+
index=0,
314+
delta=InputJSONDelta(
315+
type="input_json_delta", partial_json='{"location": "SF"}'
316+
),
317+
)
318+
result = provider.extract_content(delta_event, None)
319+
assert result is not None
320+
tool_data, _ = result[0]
321+
assert isinstance(tool_data, dict)
322+
assert tool_data["toolCallId"] == "toolu_123"

0 commit comments

Comments
 (0)