From 54f96d016d71427277dd1ae03bde036fcfca595b Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 3 Oct 2025 15:35:53 +0200 Subject: [PATCH 1/5] Add multi-tool support to marimo backend --- marimo/_server/ai/providers.py | 488 ++++++++++++++++++++------------- 1 file changed, 303 insertions(+), 185 deletions(-) diff --git a/marimo/_server/ai/providers.py b/marimo/_server/ai/providers.py index de542038efd..992510af7e7 100644 --- a/marimo/_server/ai/providers.py +++ b/marimo/_server/ai/providers.py @@ -79,10 +79,16 @@ # Types for extract_content method return DictContent = tuple[ dict[str, Any], - Literal["tool_call_start", "tool_call_end", "reasoning_signature"], + Literal[ + "tool_call_start", + "tool_call_end", + "reasoning_signature", + "tool_call_delta", + ], ] -TextContent = tuple[str, Literal["text", "reasoning", "tool_call_delta"]] +TextContent = tuple[str, Literal["text", "reasoning"]] ExtractedContent = Union[TextContent, DictContent] +ExtractedContentList = list[ExtractedContent] # Types for format_stream method parameter FinishContent = tuple[FinishReason, Literal["finish_reason"]] @@ -108,6 +114,13 @@ class StreamOptions: format_stream: bool = False +@dataclass +class ActiveToolCall: + tool_call_id: str + tool_call_name: str + tool_call_args: str + + class CompletionProvider(Generic[ResponseT, StreamT], ABC): """Base class for AI completion providers.""" @@ -128,8 +141,8 @@ async def stream_completion( @abstractmethod def extract_content( - self, response: ResponseT, tool_call_id: Optional[str] = None - ) -> Optional[ExtractedContent]: + self, response: ResponseT, tool_call_ids: list[str] = [] + ) -> Optional[ExtractedContentList]: """Extract content from a response chunk.""" pass @@ -233,10 +246,9 @@ async def as_stream_response( options = options or StreamOptions() # Tool info collected from the first chunk - tool_call_id: Optional[str] = None - tool_call_name: Optional[str] = None - # Tool args collected from the tool_call_delta chunks - tool_call_args: str = "" + tool_calls: dict[str, ActiveToolCall] = {} + tool_calls_order: list[str] = [] + # Finish reason collected from the last chunk finish_reason: Optional[FinishReason] = None @@ -252,99 +264,128 @@ async def as_stream_response( # If we check content first, these chunks get skipped and finish reason is never detected finish_reason = self.get_finish_reason(chunk) or finish_reason - content = self.extract_content(chunk, tool_call_id) + content = self.extract_content(chunk, tool_calls_order) if not content: continue - content_data, content_type = content + # Loop through all content chunks + for content_data, content_type in content: + if options.text_only and content_type != "text": + continue - if options.text_only and content_type != "text": - continue - - # Handle text content with start/delta/end pattern - if ( - content_type == "text" - and isinstance(content_data, str) - and options.format_stream - ): - if not has_text_started: - # Emit text-start event - current_text_id = f"text_{uuid.uuid4().hex}" + # Handle text content with start/delta/end pattern + if ( + content_type == "text" + and isinstance(content_data, str) + and options.format_stream + ): + if not has_text_started: + # Emit text-start event + current_text_id = f"text_{uuid.uuid4().hex}" + yield convert_to_ai_sdk_messages( + "", "text_start", current_text_id + ) + has_text_started = True + + # Emit text-delta event with the actual content yield convert_to_ai_sdk_messages( - "", "text_start", current_text_id + content_data, "text", current_text_id ) 
- has_text_started = True + continue - # Emit text-delta event with the actual content - yield convert_to_ai_sdk_messages( - content_data, "text", current_text_id - ) - continue - - # Handle reasoning content with start/delta/end pattern - elif ( - content_type == "reasoning" - and isinstance(content_data, str) - and options.format_stream - ): - if not has_reasoning_started: - # Emit reasoning-start event - current_reasoning_id = f"reasoning_{uuid.uuid4().hex}" + # Handle reasoning content with start/delta/end pattern + elif ( + content_type == "reasoning" + and isinstance(content_data, str) + and options.format_stream + ): + if not has_reasoning_started: + # Emit reasoning-start event + current_reasoning_id = f"reasoning_{uuid.uuid4().hex}" + yield convert_to_ai_sdk_messages( + "", "reasoning_start", current_reasoning_id + ) + has_reasoning_started = True + + # Emit reasoning-delta event with the actual content yield convert_to_ai_sdk_messages( - "", "reasoning_start", current_reasoning_id + content_data, "reasoning", current_reasoning_id ) - has_reasoning_started = True + continue - # Emit reasoning-delta event with the actual content - yield convert_to_ai_sdk_messages( - content_data, "reasoning", current_reasoning_id - ) - continue + # Tool handling + if content_type == "tool_call_start" and isinstance( + content_data, dict + ): + tool_call_id: Optional[str] = content_data.get( + "toolCallId", None + ) + tool_call_name: Optional[str] = content_data.get( + "toolName", None + ) + # Sometimes GoogleProvider emits the args in the tool_call_start chunk + tool_call_args: str = "" + if content_data.get("args"): + # don't yield args in tool_call_start chunk + # it will throw an error in ai-sdk-ui + tool_call_args = content_data.pop("args") + + if tool_call_id and tool_call_name: + # Add new tool calls to the list for tracking + tool_calls_order.append(tool_call_id) + tool_calls[tool_call_id] = ActiveToolCall( + tool_call_id=tool_call_id, + tool_call_name=tool_call_name, + tool_call_args=tool_call_args, + ) + + if content_type == "tool_call_delta" and isinstance( + content_data, dict + ): + tool_call_delta_id = content_data.get("toolCallId", None) + tool_call_delta: str = content_data.get( + "inputTextDelta", "" + ) - # Tool handling - if content_type == "tool_call_start" and isinstance( - content_data, dict - ): - tool_call_id = content_data.get("toolCallId", None) - tool_call_name = content_data.get("toolName", None) - # Sometimes GoogleProvider emits the args in the tool_call_start chunk - if content_data.get("args"): - # don't yield args in tool_call_start chunk - # it will throw an error in ai-sdk-ui - tool_call_args = content_data.pop("args") - - if content_type == "tool_call_delta" and isinstance( - content_data, str - ): - if isinstance(self, GoogleProvider): - # For GoogleProvider, each chunk contains the full (possibly updated) args dict as a JSON string. - # Example: first chunk: {"location": "San Francisco"} - # second chunk: {"location": "San Francisco", "zip": "94107"} - # We overwrite tool_call_args with the latest chunk. - tool_call_args = content_data - else: - # For other providers, tool_call_args is built up incrementally from deltas. 
-                    tool_call_args += content_data
-                    # update tool_call_delta to ai-sdk-ui structure
-                    # based on https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#tool-call-delta-part
-                    content_data = {
-                        "toolCallId": tool_call_id,
-                        "inputTextDelta": content_data,
-                    }
+                    if not tool_call_delta_id:
+                        LOGGER.error(
+                            f"Tool call id not found for tool call delta: {content_data}"
+                        )
+                        continue
+                    tool_call = tool_calls.get(tool_call_delta_id, None)
+                    if not tool_call:
+                        continue
+
+                    if isinstance(self, GoogleProvider):
+                        # For GoogleProvider, each chunk contains the full (possibly updated) args dict as a JSON string.
+                        # Example: first chunk: {"location": "San Francisco"}
+                        # second chunk: {"location": "San Francisco", "zip": "94107"}
+                        # We overwrite tool_call_args with the latest chunk.
+                        tool_call.tool_call_args = tool_call_delta
+                    else:
+                        # For other providers, tool_call_args is built up incrementally from deltas.
+                        tool_call.tool_call_args += tool_call_delta
+                    # update tool_call_delta to ai-sdk-ui structure
+                    # based on https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#tool-call-delta-part
+                    content_data = {
+                        "toolCallId": tool_call.tool_call_id,
+                        "inputTextDelta": tool_call.tool_call_args,
+                    }

-            content_str = self._content_to_string(content_data)
+                content_str = self._content_to_string(content_data)

-            if options.format_stream:
-                stream_content = self._create_stream_content(
-                    content_data, content_type
-                )
-                content_str = self.format_stream(stream_content)
+                if options.format_stream:
+                    stream_content = self._create_stream_content(
+                        content_data, content_type
+                    )
+                    content_str = self.format_stream(stream_content)

-            buffer += content_str
-            original_content += content_str
+                buffer += content_str
+                original_content += content_str

-            yield buffer
-            buffer = ""
+                yield buffer
+                buffer = ""

         # Emit text-end event if we started a text block
         if has_text_started and current_text_id and options.format_stream:
@@ -361,19 +402,20 @@ async def as_stream_response(
             )

         # Handle tool call end after the stream is complete
-        if tool_call_id and tool_call_name and not options.text_only:
-            content_data = {
-                "toolCallId": tool_call_id,
-                "toolName": tool_call_name,
-                "input": self.validate_tool_call_args(tool_call_args)
-                or {},  # empty object if tool doesnt have args
-            }
-            content_type = "tool_call_end"
-            yield self.format_stream((content_data, content_type))
-            # Reset tool call state for next stream just in case
-            tool_call_id = None
-            tool_call_name = None
-            tool_call_args = ""
+        if len(tool_calls_order) > 0 and not options.text_only:
+            for tool_call_id in tool_calls_order:
+                tool_call = tool_calls.get(tool_call_id, None)
+                if not tool_call:
+                    continue
+                content_data = {
+                    "toolCallId": tool_call_id,
+                    "toolName": tool_call.tool_call_name,
+                    "input": self.validate_tool_call_args(
+                        tool_call.tool_call_args
+                    )
+                    or {},  # empty object if tool doesn't have args
+                }
+                content_type = "tool_call_end"
+                yield self.format_stream((content_data, content_type))

         # Add a final finish reason chunk
         if finish_reason and not options.text_only:
@@ -513,9 +555,8 @@ async def stream_completion(
     def extract_content(
         self,
         response: ChatCompletionChunk,
-        tool_call_id: Optional[str] = None,
-    ) -> Optional[ExtractedContent]:
-        del tool_call_id
+        tool_call_ids: list[str] = [],
+    ) -> Optional[ExtractedContentList]:
         if (
             hasattr(response, "choices")
             and response.choices
@@ -526,29 +567,42 @@ def extract_content(
             # Text content
             content = delta.content
             if content:
-                return (content, "text")
+                return [(content, "text")]

             # Tool call:
             if delta.tool_calls:
-                tool_calls = delta.tool_calls[0]
-
-                # Start of tool call
-                # id is only present for the first tool call chunk
-                if (
-                    tool_calls.id
-                    and tool_calls.function
-                    and tool_calls.function.name
-                ):
-                    tool_info = {
-                        "toolCallId": tool_calls.id,
-                        "toolName": tool_calls.function.name,
-                    }
-                    return (tool_info, "tool_call_start")
-
-                # Delta of tool call
-                # arguments is only present second chunk onwards
-                if tool_calls.function and tool_calls.function.arguments:
-                    return (tool_calls.function.arguments, "tool_call_delta")
+                tool_content: ExtractedContentList = []
+                for tool_call in delta.tool_calls:
+                    tool_index = tool_call.index
+
+                    # Start of tool call
+                    # id is only present for the first tool call chunk
+                    if (
+                        tool_call.id
+                        and tool_call.function
+                        and tool_call.function.name
+                    ):
+                        tool_info = {
+                            "toolCallId": tool_call.id,
+                            "toolName": tool_call.function.name,
+                        }
+                        tool_content.append((tool_info, "tool_call_start"))

+                    # Delta of tool call
+                    # arguments is only present second chunk onwards
+                    # Guard the index lookup so a delta for an unknown tool
+                    # call can't raise an IndexError
+                    if (
+                        tool_call.function
+                        and tool_call.function.arguments
+                        and tool_index < len(tool_call_ids)
+                        and tool_call_ids[tool_index]
+                    ):
+                        tool_delta = {
+                            "toolCallId": tool_call_ids[tool_index],
+                            "inputTextDelta": tool_call.function.arguments,
+                        }
+                        tool_content.append((tool_delta, "tool_call_delta"))
+
+                # return the tool content
+                return tool_content

         return None

@@ -671,6 +725,12 @@ class AnthropicProvider(
     # 1024 tokens is the minimum budget for extended thinking
     DEFAULT_EXTENDED_THINKING_BUDGET_TOKENS = 1024

+    def __init__(self, model: str, config: AnyProviderConfig) -> None:
+        super().__init__(model, config)
+        # Map of block index to tool call id for tool call delta chunks.
+        # Kept per instance so concurrent providers don't share state.
+        self.block_index_to_tool_call_id_map: dict[int, str] = {}
+
     def is_extended_thinking_model(self, model: str) -> bool:
         return any(
             model.startswith(prefix)
@@ -692,6 +749,9 @@ def get_client(self, config: AnyProviderConfig) -> AsyncClient:

         return AsyncClient(api_key=config.api_key)

+    def maybe_get_tool_call_id(self, block_index: int) -> Optional[str]:
+        return self.block_index_to_tool_call_id_map.get(block_index, None)
+
     async def stream_completion(
         self,
         messages: list[ChatMessage],
@@ -725,12 +785,12 @@ async def stream_completion(
             await client.messages.create(**create_params),
         )

     def extract_content(
         self,
         response: RawMessageStreamEvent,
-        tool_call_id: Optional[str] = None,
-    ) -> Optional[ExtractedContent]:
-        del tool_call_id
+        tool_call_ids: list[str] = [],
+    ) -> Optional[ExtractedContentList]:
+        del tool_call_ids
         from anthropic.types import (
             InputJSONDelta,
             RawContentBlockDeltaEvent,
@@ -744,25 +807,46 @@ def extract_content(
         # For streaming content
         if isinstance(response, RawContentBlockDeltaEvent):
             if isinstance(response.delta, TextDelta):
-                return (response.delta.text, "text")
+                return [(response.delta.text, "text")]
             if isinstance(response.delta, ThinkingDelta):
-                return (response.delta.thinking, "reasoning")
+                return [(response.delta.thinking, "reasoning")]
             if isinstance(response.delta, InputJSONDelta):
-                return (response.delta.partial_json, "tool_call_delta")
+                block_index = response.index
+                tool_call_id = self.maybe_get_tool_call_id(block_index)
+                if not tool_call_id:
+                    LOGGER.error(
+                        f"Tool call id not found for block index: {response.index}"
+                    )
+                    return None
+                delta_json = response.delta.partial_json
+                tool_delta = {
+                    "toolCallId": tool_call_id,
+                    "inputTextDelta": delta_json,
+                }
+                return [(tool_delta, "tool_call_delta")]
             if isinstance(response.delta, SignatureDelta):
-                return (
-                    {"signature": response.delta.signature},
-                    "reasoning_signature",
-                )
+                return [
+                    (
+                        {"signature": response.delta.signature},
+                        "reasoning_signature",
+                    )
+                ]

         # For the beginning of a tool use block
         if isinstance(response, RawContentBlockStartEvent):
             if isinstance(response.content_block, ToolUseBlock):
+                tool_call_id = response.content_block.id
+                tool_call_name = response.content_block.name
+                block_index = response.index
+                # Store the tool call id for the block index
+                self.block_index_to_tool_call_id_map[block_index] = (
+                    tool_call_id
+                )
                 tool_info = {
-                    "toolCallId": response.content_block.id,
-                    "toolName": response.content_block.name,
+                    "toolCallId": tool_call_id,
+                    "toolName": tool_call_name,
                 }
-                return (tool_info, "tool_call_start")
+                return [(tool_info, "tool_call_start")]

         return None

@@ -880,7 +964,7 @@ async def stream_completion(
             ),
         )

-    def _get_tool_call_id(self, tool_call_id: Optional[str]) -> Optional[str]:
+    def _get_tool_call_id(self, tool_call_id: Optional[str]) -> str:
         # Custom tools don't have an id, so we have to generate a random uuid
         # https://ai.google.dev/gemini-api/docs/function-calling?example=meeting
         if not tool_call_id:
@@ -891,8 +975,8 @@ def _get_tool_call_id(self, tool_call_id: Optional[str]) -> str:
     def extract_content(
         self,
         response: GenerateContentResponse,
-        tool_call_id: Optional[str] = None,
-    ) -> Optional[ExtractedContent]:
+        tool_call_ids: list[str] = [],
+    ) -> Optional[ExtractedContentList]:
         if not response.candidates:
             return None

@@ -903,33 +987,56 @@ def extract_content(
         if not candidate.content.parts:
             return None

+        # Build events by first scanning parts and rectifying tool calls by position
+        content: ExtractedContentList = []
+        function_call_index = -1
+        seen_in_frame: set[int] = set()
+
         for part in candidate.content.parts:
-            # Start of tool call
-            # GoogleProvider may emit the function_call object in every chunk, not just the first.
-            # We use tool_call_id to ensure we only emit one tool_call_start event per tool call.
-            if part.function_call and not tool_call_id:
-                tool_info = {
-                    "toolCallId": self._get_tool_call_id(
-                        part.function_call.id
-                    ),
-                    "toolName": part.function_call.name,
-                    "args": json.dumps(part.function_call.args),
-                }
-                return (tool_info, "tool_call_start")
-            # Tool call args (not delta)
-            elif part.function_call and part.function_call.args:
-                return (json.dumps(part.function_call.args), "tool_call_delta")
-
-            # Skip non-text content
-            elif part.text:
-                # Reasoning content
+            # Handle function calls (may appear multiple times per chunk)
+            if part.function_call:
+                function_call_index += 1
+                # Resolve a stable id by position if provided from the caller; else synthesize
+                already_started = function_call_index < len(
+                    tool_call_ids
+                ) and bool(tool_call_ids[function_call_index])
+                stable_id = (
+                    tool_call_ids[function_call_index]
+                    if already_started
+                    else self._get_tool_call_id(part.function_call.id)
+                )
+
+                # Emit a start only the first time this call index is seen.
+                # GoogleProvider may re-emit the same function_call in later
+                # chunks; if the caller already knows an id for this index,
+                # the call has already started, so treat it as a delta.
+                if (
+                    function_call_index not in seen_in_frame
+                    and not already_started
+                ):
+                    tool_info = {
+                        "toolCallId": stable_id,
+                        "toolName": part.function_call.name,
+                        "args": json.dumps(part.function_call.args),
+                    }
+                    content.append((tool_info, "tool_call_start"))
+                    seen_in_frame.add(function_call_index)
+                else:
+                    # Subsequent occurrences for the same index => treat as delta (snapshot semantics)
+                    if part.function_call.args is not None:
+                        tool_delta = {
+                            "toolCallId": stable_id,
+                            "inputTextDelta": json.dumps(
+                                part.function_call.args
+                            ),
+                        }
+                        content.append((tool_delta, "tool_call_delta"))
+                continue
+
+            # Text/Reasoning handling
+            if part.text:
                 if part.thought:
-                    return (part.text, "reasoning")
+                    content.append((part.text, "reasoning"))
                 else:
-                    return (part.text, "text")
-            else:
+                    content.append((part.text, "text"))
                 continue
-            return None
+
+            # Ignore other non-text parts (e.g., images) at this layer
+            continue
+
+        return content

     def get_finish_reason(
         self, response: GenerateContentResponse
@@ -1007,9 +1114,8 @@ async def stream_completion(
     def extract_content(
         self,
        response: LitellmStreamResponse,
-        tool_call_id: Optional[str] = None,
-    ) -> Optional[ExtractedContent]:
-        del tool_call_id
+        tool_call_ids: list[str] = [],
+    ) -> Optional[ExtractedContentList]:
         if (
             hasattr(response, "choices")
             and response.choices
@@ -1020,30 +1126,42 @@ def extract_content(
             # Text content
             content = delta.content
             if content:
-                return (str(content), "text")
+                return [(str(content), "text")]

             # Tool call: LiteLLM follows OpenAI format for tool calls
-            if hasattr(delta, "tool_calls") and delta.tool_calls:
-                tool_calls = delta.tool_calls[0]
-
-                # Start of tool call
-                # id is only present for the first tool call chunk
-                if hasattr(tool_calls, "id") and tool_calls.id:
-                    tool_info = {
-                        "toolCallId": tool_calls.id,
-                        "toolName": tool_calls.function.name,
-                    }
-                    return (tool_info, "tool_call_start")

-                # Delta of tool call
-                # arguments is only present second chunk onwards
-                if (
-                    hasattr(tool_calls, "function")
-                    and tool_calls.function
-                    and hasattr(tool_calls.function, "arguments")
-                    and tool_calls.function.arguments
-                ):
-                    return (tool_calls.function.arguments, "tool_call_delta")
+            if hasattr(delta, "tool_calls") and delta.tool_calls:
+                tool_content: ExtractedContentList = []
+
+                for tool_call in delta.tool_calls:
+                    tool_index: int = tool_call.index
+
+                    # Start of tool call
+                    # id is only present for the first tool call chunk
+                    if hasattr(tool_call, "id") and tool_call.id:
+                        tool_info = {
+                            "toolCallId": tool_call.id,
+                            "toolName": tool_call.function.name,
+                        }
+                        tool_content.append((tool_info, "tool_call_start"))
+
+                    # Delta of tool call
+                    # arguments is only present second chunk onwards
+                    # Guard the index lookup so a delta for an unknown tool
+                    # call can't raise an IndexError
+                    if (
+                        hasattr(tool_call, "function")
+                        and tool_call.function
+                        and hasattr(tool_call.function, "arguments")
+                        and tool_call.function.arguments
+                        and tool_index < len(tool_call_ids)
+                        and tool_call_ids[tool_index]
+                    ):
+                        tool_delta = {
+                            "toolCallId": tool_call_ids[tool_index],
+                            "inputTextDelta": tool_call.function.arguments,
+                        }
+                        tool_content.append((tool_delta, "tool_call_delta"))
+
+                # return the tool content
+                return tool_content

         return None

From 9d2982e59b91a3fbd25f43a4892d28e6e681d185 Mon Sep 17 00:00:00 2001
From: Joaquin Coromina
Date: Fri, 3 Oct 2025 15:38:26 +0200
Subject: [PATCH 2/5] Update hasPendingToolCalls to ensure all tool calls are
 complete

---
 frontend/src/components/chat/chat-utils.ts | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/frontend/src/components/chat/chat-utils.ts b/frontend/src/components/chat/chat-utils.ts
index dc4646a4665..4f265196026 100644
--- a/frontend/src/components/chat/chat-utils.ts
+++ b/frontend/src/components/chat/chat-utils.ts
@@ -153,7 +153,7 @@ export async function handleToolCall({

 /**
  * Checks if we should send a message automatically based on the messages.
- * We only want to send a message if we have completed tool calls and there is no reply yet.
+ * We only want to send a message if all tool calls are completed and there is no reply yet.
  */
 export function hasPendingToolCalls(messages: UIMessage[]): boolean {
   if (messages.length === 0) {
@@ -177,7 +177,12 @@ export function hasPendingToolCalls(messages: UIMessage[]): boolean {
     part.type.startsWith("tool-"),
   ) as ToolUIPart[];

-  const hasCompletedToolCalls = toolParts.some(
+  // Guard against no tool parts
+  if (toolParts.length === 0) {
+    return false;
+  }
+
+  const allToolCallsCompleted = toolParts.every(
     (part) => part.state === "output-available",
   );

@@ -186,6 +191,6 @@ export function hasPendingToolCalls(messages: UIMessage[]): boolean {
   const hasTextContent =
     lastPart.type === "text" && lastPart.text?.trim().length > 0;

   // Only auto-send if we have completed tool calls and there is no reply yet
-  return hasCompletedToolCalls && !hasTextContent;
+  return allToolCallsCompleted && !hasTextContent;
 }

From 9013b0145dfa225e6a3e28039a69b2901fb83e09 Mon Sep 17 00:00:00 2001
From: Joaquin Coromina
Date: Fri, 3 Oct 2025 15:43:31 +0200
Subject: [PATCH 3/5] Update tool_call_ids to optional

---
 marimo/_server/ai/providers.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/marimo/_server/ai/providers.py b/marimo/_server/ai/providers.py
index 992510af7e7..57c716be8a5 100644
--- a/marimo/_server/ai/providers.py
+++ b/marimo/_server/ai/providers.py
@@ -141,7 +141,7 @@ async def stream_completion(

     @abstractmethod
     def extract_content(
-        self, response: ResponseT, tool_call_ids: list[str] = []
+        self, response: ResponseT, tool_call_ids: Optional[list[str]] = None
     ) -> Optional[ExtractedContentList]:
         """Extract content from a response chunk."""
         pass
@@ -555,8 +555,9 @@ async def stream_completion(
     def extract_content(
         self,
         response: ChatCompletionChunk,
-        tool_call_ids: list[str] = [],
+        tool_call_ids: Optional[list[str]] = None,
     ) -> Optional[ExtractedContentList]:
+        tool_call_ids = tool_call_ids or []
         if (
             hasattr(response, "choices")
             and response.choices
@@ -791,7 +792,7 @@ async def stream_completion(
     def extract_content(
         self,
         response: RawMessageStreamEvent,
-        tool_call_ids: list[str] = [],
+        tool_call_ids: Optional[list[str]] = None,
     ) -> Optional[ExtractedContentList]:
         del tool_call_ids
         from anthropic.types import (
@@ -976,8 +976,9 @@ def _get_tool_call_id(self, tool_call_id: Optional[str]) -> str:
     def extract_content(
         self,
         response: GenerateContentResponse,
-        tool_call_ids: list[str] = [],
+        tool_call_ids: Optional[list[str]] = None,
     ) -> Optional[ExtractedContentList]:
+        tool_call_ids = tool_call_ids or []
         if not response.candidates:
             return None

@@ -1116,8 +1116,9 @@ async def stream_completion(
     def extract_content(
         self,
         response: LitellmStreamResponse,
-        tool_call_ids: list[str] = [],
+        tool_call_ids: Optional[list[str]] = None,
     ) -> Optional[ExtractedContentList]:
+        tool_call_ids = tool_call_ids or []
         if (
             hasattr(response, "choices")
             and response.choices

From 88006ec9c2ed2299c7e7a89404fc41ce6ea5dee6 Mon Sep 17 00:00:00 2001
From: Joaquin Coromina
Date: Sat, 4 Oct 2025 02:11:51 +0200
Subject: [PATCH 4/5] Add tests for updated extract content functionality and
 provider specific extraction and id mapping

---
 tests/_server/ai/test_providers.py | 151 ++++++++++++++++++++++++++++-
 1 file changed, 150 insertions(+), 1 deletion(-)

diff --git a/tests/_server/ai/test_providers.py b/tests/_server/ai/test_providers.py
index 4a593375fbb..df184d29173 100644
--- a/tests/_server/ai/test_providers.py
+++ b/tests/_server/ai/test_providers.py
@@ -1,6 +1,6 @@
 """Tests for the LLM providers in marimo._server.ai.providers."""

-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

@@ -172,3 +172,152 @@ async def test_azure_openai_provider() -> None:
     assert api_version == "2023-05-15"
     assert deployment_name == "gpt-4-1"
     assert endpoint == "https://unknown_domain.openai"
+
+
+@pytest.mark.parametrize(
+    "provider_type",
+    [
+        pytest.param(OpenAIProvider, id="openai"),
+        pytest.param(BedrockProvider, id="bedrock"),
+    ],
+)
+def test_extract_content_with_none_tool_call_ids(
+    provider_type: type,
+) -> None:
+    """Test extract_content handles None tool_call_ids without errors."""
+    config = AnyProviderConfig(api_key="test-key", base_url="http://test")
+    provider = provider_type("test-model", config)
+
+    mock_response = MagicMock()
+    mock_delta = MagicMock()
+    mock_delta.content = "Hello"
+    mock_delta.tool_calls = None
+    mock_choice = MagicMock()
+    mock_choice.delta = mock_delta
+    mock_response.choices = [mock_choice]
+
+    result = provider.extract_content(mock_response, None)
+    assert result == [("Hello", "text")]
+
+
+def test_google_extract_content_with_none_tool_call_ids() -> None:
+    """Test Google extract_content handles None tool_call_ids without errors."""
+    config = AnyProviderConfig(api_key="test-key", base_url="http://test")
+    provider = GoogleProvider("gemini-1.5-flash", config)
+
+    mock_response = MagicMock()
+    mock_candidate = MagicMock()
+    mock_content = MagicMock()
+    mock_part = MagicMock()
+    mock_part.text = "Hello"
+    mock_part.thought = False
+    mock_part.function_call = None
+    mock_content.parts = [mock_part]
+    mock_candidate.content = mock_content
+    mock_response.candidates = [mock_candidate]
+
+    result = provider.extract_content(mock_response, None)
+    assert result == [("Hello", "text")]
+
+
+def test_openai_extract_content_multiple_tool_calls() -> None:
+    """Test OpenAI extracts multiple tool calls correctly."""
+    config = AnyProviderConfig(api_key="test-key", base_url="http://test")
+    provider = OpenAIProvider("gpt-4", config)
+
+    mock_response = MagicMock()
+    mock_delta = MagicMock()
+    mock_delta.content = None
+
+    mock_tool_1 = MagicMock()
+    mock_tool_1.index = 0
+    mock_tool_1.id = "call_1"
+    mock_tool_1.function = MagicMock()
+    mock_tool_1.function.name = "get_weather"
+    mock_tool_1.function.arguments = None
+
+    mock_tool_2 = MagicMock()
+    mock_tool_2.index = 1
+    mock_tool_2.id = "call_2"
+    mock_tool_2.function = MagicMock()
+    mock_tool_2.function.name = "get_time"
+    mock_tool_2.function.arguments = None
+
+    mock_delta.tool_calls = [mock_tool_1, mock_tool_2]
+    mock_choice = MagicMock()
+    mock_choice.delta = mock_delta
+    mock_response.choices = [mock_choice]
+
+    result = provider.extract_content(mock_response, None)
+    assert result is not None
+    assert len(result) == 2
+    tool_data_0, _ = result[0]
+    tool_data_1, _ = result[1]
+    assert isinstance(tool_data_0, dict)
+    assert isinstance(tool_data_1, dict)
+    assert tool_data_0["toolName"] == "get_weather"
+    assert tool_data_1["toolName"] == "get_time"
+
+
+def test_google_extract_content_id_rectification() -> None:
+    """Test Google uses provided tool_call_ids for ID rectification."""
+    config = AnyProviderConfig(api_key="test-key", base_url="http://test")
+    provider = GoogleProvider("gemini-1.5-flash", config)
+
+    mock_response = MagicMock()
+    mock_candidate = MagicMock()
+    mock_content = MagicMock()
+    mock_func_call = MagicMock()
+    mock_func_call.name = "get_weather"
+    mock_func_call.args = {"location": "SF"}
+    mock_func_call.id = None
+    mock_part = MagicMock()
+    mock_part.text = None
+    mock_part.function_call = mock_func_call
+    mock_content.parts = [mock_part]
+    mock_candidate.content = mock_content
+    mock_response.candidates = [mock_candidate]
+
+    result = provider.extract_content(mock_response, ["stable_id"])
+    assert result is not None
+    tool_data, _ = result[0]
+    assert isinstance(tool_data, dict)
+    assert tool_data["toolCallId"] == "stable_id"
+
+
+def test_anthropic_extract_content_tool_call_id_mapping() -> None:
+    """Test Anthropic maps tool call IDs via block index."""
+    try:
+        from anthropic.types import (
+            InputJSONDelta,
+            RawContentBlockDeltaEvent,
+            RawContentBlockStartEvent,
+            ToolUseBlock,
+        )
+    except ImportError:
+        pytest.skip("Anthropic not installed")
+
+    config = AnyProviderConfig(api_key="test-key", base_url="http://test")
+    provider = AnthropicProvider("claude-3-opus-20240229", config)
+
+    start_event = RawContentBlockStartEvent(
+        type="content_block_start",
+        index=0,
+        content_block=ToolUseBlock(
+            type="tool_use", id="toolu_123", name="get_weather", input={}
+        ),
+    )
+    provider.extract_content(start_event, None)
+
+    delta_event = RawContentBlockDeltaEvent(
+        type="content_block_delta",
+        index=0,
+        delta=InputJSONDelta(
+            type="input_json_delta", partial_json='{"location": "SF"}'
+        ),
+    )
+    result = provider.extract_content(delta_event, None)
+    assert result is not None
+    tool_data, _ = result[0]
+    assert isinstance(tool_data, dict)
+    assert tool_data["toolCallId"] == "toolu_123"

From c0d81ca2d310237ca81d7d3891d278e75be909b0 Mon Sep 17 00:00:00 2001
From: Joaquin Coromina
Date: Sun, 5 Oct 2025 13:45:18 +0200
Subject: [PATCH 5/5] Fix test unpacking error

---
 tests/_server/api/endpoints/test_ai.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/_server/api/endpoints/test_ai.py b/tests/_server/api/endpoints/test_ai.py
index 2e6256686c9..bd1f1d39845 100644
--- a/tests/_server/api/endpoints/test_ai.py
+++ b/tests/_server/api/endpoints/test_ai.py
@@ -891,9 +891,11 @@ def test_extract_content_with_delta_content(self) -> None:
         # Call get_content with the mock response
         config = AnyProviderConfig(base_url=None, api_key="test-key")
         provider = OpenAIProvider(model="gpt-4o", config=config)
-        result_text, result_type = provider.extract_content(mock_response)
+        result = provider.extract_content(mock_response)

-        # Assert that the result is the expected content
+        # Assert that the result is not None and has expected content
+        assert result is not None
+        result_text, result_type = result[0]
         assert result_text == "Test content"
         assert result_type == "text"
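
Note for reviewers: the snippet below is an illustrative sketch, not part of the
patch series. It shows how a caller is expected to drive the new multi-tool
extract_content contract from PATCH 1/5: tool_call_start events are recorded in
call order, and that order (tool_calls_order) is passed back on later chunks so
providers without stable ids (GoogleProvider) can rectify them by position.
Here `chunks` stands in for a provider stream and `provider` for any
CompletionProvider subclass from this series; both are assumptions for the
example, as is the simplified accumulation loop itself.

    from marimo._server.ai.providers import ActiveToolCall

    tool_calls: dict[str, ActiveToolCall] = {}
    tool_calls_order: list[str] = []

    for chunk in chunks:  # hypothetical stream of provider response chunks
        events = provider.extract_content(chunk, tool_calls_order) or []
        for data, kind in events:
            if kind == "tool_call_start" and isinstance(data, dict):
                # Record each call once, in the order it started; the order
                # doubles as the positional id map passed back above.
                call_id = data["toolCallId"]
                if call_id not in tool_calls:
                    tool_calls_order.append(call_id)
                    tool_calls[call_id] = ActiveToolCall(
                        tool_call_id=call_id,
                        tool_call_name=data["toolName"],
                        tool_call_args=data.get("args", ""),
                    )
            elif kind == "tool_call_delta" and isinstance(data, dict):
                # Deltas always carry the rectified toolCallId. Note that the
                # real as_stream_response overwrites instead of appending for
                # GoogleProvider, whose deltas are full snapshots.
                call = tool_calls.get(data["toolCallId"])
                if call is not None:
                    call.tool_call_args += data["inputTextDelta"]

Once the stream ends, each entry in tool_calls_order holds the complete
arguments for one call, which is what the tool_call_end loop in
as_stream_response validates and emits.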