|
1 | 1 | """Tests for the LLM providers in marimo._server.ai.providers.""" |
2 | 2 |
|
3 | | -from unittest.mock import AsyncMock, patch |
| 3 | +from unittest.mock import AsyncMock, MagicMock, patch |
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 |
|
@@ -56,7 +56,6 @@ def test_anyprovider_for_model(model_name: str, provider_name: str) -> None: |
56 | 56 | if provider_name != "bedrock": |
57 | 57 | assert config.api_key == f"{provider_name}-key" |
58 | 58 | else: |
59 | | - # bedrock overloads the api_key for profile name |
60 | 59 | assert config.api_key == "profile:aws-profile" |
61 | 60 |
|
62 | 61 |
|
@@ -172,3 +171,152 @@ async def test_azure_openai_provider() -> None: |
172 | 171 | assert api_version == "2023-05-15" |
173 | 172 | assert deployment_name == "gpt-4-1" |
174 | 173 | assert endpoint == "https://unknown_domain.openai" |
| 174 | + |
| 175 | + |
| 176 | +@pytest.mark.parametrize( |
| 177 | + "provider_type", |
| 178 | + [ |
| 179 | + pytest.param(OpenAIProvider, id="openai"), |
| 180 | + pytest.param(BedrockProvider, id="bedrock"), |
| 181 | + ], |
| 182 | +) |
| 183 | +def test_extract_content_with_none_tool_call_ids( |
| 184 | + provider_type: type, |
| 185 | +) -> None: |
| 186 | + """Test extract_content handles None tool_call_ids without errors.""" |
| 187 | + config = AnyProviderConfig(api_key="test-key", base_url="http://test") |
| 188 | + provider = provider_type("test-model", config) |
| 189 | + |
| 190 | + mock_response = MagicMock() |
| 191 | + mock_delta = MagicMock() |
| 192 | + mock_delta.content = "Hello" |
| 193 | + mock_delta.tool_calls = None |
| 194 | + mock_choice = MagicMock() |
| 195 | + mock_choice.delta = mock_delta |
| 196 | + mock_response.choices = [mock_choice] |
| 197 | + |
| 198 | + result = provider.extract_content(mock_response, None) |
| 199 | + assert result == [("Hello", "text")] |
| 200 | + |
| 201 | + |
| 202 | +def test_google_extract_content_with_none_tool_call_ids() -> None: |
| 203 | + """Test Google extract_content handles None tool_call_ids without errors.""" |
| 204 | + config = AnyProviderConfig(api_key="test-key", base_url="http://test") |
| 205 | + provider = GoogleProvider("gemini-1.5-flash", config) |
| 206 | + |
| 207 | + mock_response = MagicMock() |
| 208 | + mock_candidate = MagicMock() |
| 209 | + mock_content = MagicMock() |
| 210 | + mock_part = MagicMock() |
| 211 | + mock_part.text = "Hello" |
| 212 | + mock_part.thought = False |
| 213 | + mock_part.function_call = None |
| 214 | + mock_content.parts = [mock_part] |
| 215 | + mock_candidate.content = mock_content |
| 216 | + mock_response.candidates = [mock_candidate] |
| 217 | + |
| 218 | + result = provider.extract_content(mock_response, None) |
| 219 | + assert result == [("Hello", "text")] |
| 220 | + |
| 221 | + |
| 222 | +def test_openai_extract_content_multiple_tool_calls() -> None: |
| 223 | + """Test OpenAI extracts multiple tool calls correctly.""" |
| 224 | + config = AnyProviderConfig(api_key="test-key", base_url="http://test") |
| 225 | + provider = OpenAIProvider("gpt-4", config) |
| 226 | + |
| 227 | + mock_response = MagicMock() |
| 228 | + mock_delta = MagicMock() |
| 229 | + mock_delta.content = None |
| 230 | + |
| 231 | + mock_tool_1 = MagicMock() |
| 232 | + mock_tool_1.index = 0 |
| 233 | + mock_tool_1.id = "call_1" |
| 234 | + mock_tool_1.function = MagicMock() |
| 235 | + mock_tool_1.function.name = "get_weather" |
| 236 | + mock_tool_1.function.arguments = None |
| 237 | + |
| 238 | + mock_tool_2 = MagicMock() |
| 239 | + mock_tool_2.index = 1 |
| 240 | + mock_tool_2.id = "call_2" |
| 241 | + mock_tool_2.function = MagicMock() |
| 242 | + mock_tool_2.function.name = "get_time" |
| 243 | + mock_tool_2.function.arguments = None |
| 244 | + |
| 245 | + mock_delta.tool_calls = [mock_tool_1, mock_tool_2] |
| 246 | + mock_choice = MagicMock() |
| 247 | + mock_choice.delta = mock_delta |
| 248 | + mock_response.choices = [mock_choice] |
| 249 | + |
| 250 | + result = provider.extract_content(mock_response, None) |
| 251 | + assert result is not None |
| 252 | + assert len(result) == 2 |
| 253 | + tool_data_0, _ = result[0] |
| 254 | + tool_data_1, _ = result[1] |
| 255 | + assert isinstance(tool_data_0, dict) |
| 256 | + assert isinstance(tool_data_1, dict) |
| 257 | + assert tool_data_0["toolName"] == "get_weather" |
| 258 | + assert tool_data_1["toolName"] == "get_time" |
| 259 | + |
| 260 | + |
| 261 | +def test_google_extract_content_id_rectification() -> None: |
| 262 | + """Test Google uses provided tool_call_ids for ID rectification.""" |
| 263 | + config = AnyProviderConfig(api_key="test-key", base_url="http://test") |
| 264 | + provider = GoogleProvider("gemini-1.5-flash", config) |
| 265 | + |
| 266 | + mock_response = MagicMock() |
| 267 | + mock_candidate = MagicMock() |
| 268 | + mock_content = MagicMock() |
| 269 | + mock_func_call = MagicMock() |
| 270 | + mock_func_call.name = "get_weather" |
| 271 | + mock_func_call.args = {"location": "SF"} |
| 272 | + mock_func_call.id = None |
| 273 | + mock_part = MagicMock() |
| 274 | + mock_part.text = None |
| 275 | + mock_part.function_call = mock_func_call |
| 276 | + mock_content.parts = [mock_part] |
| 277 | + mock_candidate.content = mock_content |
| 278 | + mock_response.candidates = [mock_candidate] |
| 279 | + |
| 280 | + result = provider.extract_content(mock_response, ["stable_id"]) |
| 281 | + assert result is not None |
| 282 | + tool_data, _ = result[0] |
| 283 | + assert isinstance(tool_data, dict) |
| 284 | + assert tool_data["toolCallId"] == "stable_id" |
| 285 | + |
| 286 | + |
| 287 | +def test_anthropic_extract_content_tool_call_id_mapping() -> None: |
| 288 | + """Test Anthropic maps tool call IDs via block index.""" |
| 289 | + try: |
| 290 | + from anthropic.types import ( |
| 291 | + InputJSONDelta, |
| 292 | + RawContentBlockDeltaEvent, |
| 293 | + RawContentBlockStartEvent, |
| 294 | + ToolUseBlock, |
| 295 | + ) |
| 296 | + except ImportError: |
| 297 | + pytest.skip("Anthropic not installed") |
| 298 | + |
| 299 | + config = AnyProviderConfig(api_key="test-key", base_url="http://test") |
| 300 | + provider = AnthropicProvider("claude-3-opus-20240229", config) |
| 301 | + |
| 302 | + start_event = RawContentBlockStartEvent( |
| 303 | + type="content_block_start", |
| 304 | + index=0, |
| 305 | + content_block=ToolUseBlock( |
| 306 | + type="tool_use", id="toolu_123", name="get_weather", input={} |
| 307 | + ), |
| 308 | + ) |
| 309 | + provider.extract_content(start_event, None) |
| 310 | + |
| 311 | + delta_event = RawContentBlockDeltaEvent( |
| 312 | + type="content_block_delta", |
| 313 | + index=0, |
| 314 | + delta=InputJSONDelta( |
| 315 | + type="input_json_delta", partial_json='{"location": "SF"}' |
| 316 | + ), |
| 317 | + ) |
| 318 | + result = provider.extract_content(delta_event, None) |
| 319 | + assert result is not None |
| 320 | + tool_data, _ = result[0] |
| 321 | + assert isinstance(tool_data, dict) |
| 322 | + assert tool_data["toolCallId"] == "toolu_123" |
0 commit comments