diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py b/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py index 8bba842705..35e648c100 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py @@ -11,21 +11,61 @@ class ConfirmationStrategy(ABC): - """Strategy for generating confirmation messages during human-in-the-loop flows.""" + """Strategy for generating confirmation messages during human-in-the-loop flows. + Subclasses must define the message properties. The methods use those properties + by default, but can be overridden for complete customization. + """ + + @property + @abstractmethod + def approval_header(self) -> str: + """Header for approval accepted message. Must be overridden.""" + ... + + @property + @abstractmethod + def approval_footer(self) -> str: + """Footer for approval accepted message. Must be overridden.""" + ... + + @property + @abstractmethod + def rejection_message(self) -> str: + """Message when user rejects. Must be overridden.""" + ... + + @property @abstractmethod + def state_confirmed_message(self) -> str: + """Message when state is confirmed. Must be overridden.""" + ... + + @property + @abstractmethod + def state_rejected_message(self) -> str: + """Message when state is rejected. Must be overridden.""" + ... + def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: """Generate message when user approves function execution. + Default implementation uses header/footer properties. + Override for complete customization. + Args: steps: List of approved steps with 'description', 'status', etc. Returns: Message to display to user """ - ... + enabled_steps = [s for s in steps if s.get("status") == "enabled"] + message_parts = [self.approval_header.format(count=len(enabled_steps))] + for i, step in enumerate(enabled_steps, 1): + message_parts.append(f"{i}. {step['description']}\n") + message_parts.append(self.approval_footer) + return "".join(message_parts) - @abstractmethod def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: """Generate message when user rejects function execution. @@ -35,141 +75,143 @@ def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: Returns: Message to display to user """ - ... + return self.rejection_message - @abstractmethod def on_state_confirmed(self) -> str: """Generate message when user confirms predictive state changes. Returns: Message to display to user """ - ... + return self.state_confirmed_message - @abstractmethod def on_state_rejected(self) -> str: """Generate message when user rejects predictive state changes. Returns: Message to display to user """ - ... + return self.state_rejected_message class DefaultConfirmationStrategy(ConfirmationStrategy): - """Generic confirmation messages suitable for most agents. - - This preserves the original behavior from v1. - """ + """Generic confirmation messages suitable for most agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate generic approval message with step list.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nAll steps completed successfully!") + @property + def approval_header(self) -> str: + return "Executing {count} approved steps:\n\n" - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nAll steps completed successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate generic rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! What would you like me to change about the plan?" - def on_state_confirmed(self) -> str: - """Generate generic state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Changes confirmed and applied successfully!" - def on_state_rejected(self) -> str: - """Generate generic state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What would you like me to change?" class TaskPlannerConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for task planning agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate task-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Executing your requested tasks:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") + @property + def approval_header(self) -> str: + return "Executing your requested tasks:\n\n" - message_parts.append("\nAll tasks completed successfully!") + @property + def approval_footer(self) -> str: + return "\nAll tasks completed successfully!" - return "".join(message_parts) - - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate task-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! Let me revise the plan. What would you like me to change?" - def on_state_confirmed(self) -> str: - """Task planners typically don't use state confirmation.""" + @property + def state_confirmed_message(self) -> str: return "Tasks confirmed and ready to execute!" - def on_state_rejected(self) -> str: - """Task planners typically don't use state confirmation.""" + @property + def state_rejected_message(self) -> str: return "No problem! How should I adjust the task list?" class RecipeConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for recipe agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate recipe-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Updating your recipe:\n\n"] + @property + def approval_header(self) -> str: + return "Updating your recipe:\n\n" - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nRecipe updated successfully!") - - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nRecipe updated successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate recipe-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! What ingredients or steps should I change?" - def on_state_confirmed(self) -> str: - """Generate recipe-specific state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Recipe changes applied successfully!" - def on_state_rejected(self) -> str: - """Generate recipe-specific state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What would you like me to adjust in the recipe?" class DocumentWriterConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for document writing agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate document-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Applying your edits:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nDocument updated successfully!") + @property + def approval_header(self) -> str: + return "Applying your edits:\n\n" - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nDocument updated successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate document-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! Which changes should I keep or modify?" - def on_state_confirmed(self) -> str: - """Generate document-specific state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Document edits applied!" - def on_state_rejected(self) -> str: - """Generate document-specific state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What should I change about the document?" + + +def apply_confirmation_strategy( + strategy: ConfirmationStrategy | None, + accepted: bool, + steps: list[dict[str, Any]], +) -> str: + """Apply a confirmation strategy to generate a message. + + This helper consolidates the pattern used in multiple orchestrators. + + Args: + strategy: Strategy to use, or None for default + accepted: Whether the user approved + steps: List of steps (may be empty for state confirmations) + + Returns: + Generated message string + """ + if strategy is None: + strategy = DefaultConfirmationStrategy() + + if not steps: + # State confirmation (no steps) + return strategy.on_state_confirmed() if accepted else strategy.on_state_rejected() + # Step-based approval + return strategy.on_approval_accepted(steps) if accepted else strategy.on_approval_rejected(steps) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 449b7eac87..633fc501db 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -11,8 +11,6 @@ from ag_ui.core import ( BaseEvent, CustomEvent, - EventType, - MessagesSnapshotEvent, RunFinishedEvent, RunStartedEvent, StateDeltaEvent, @@ -34,7 +32,7 @@ prepare_function_call_results, ) -from ._utils import generate_event_id +from ._utils import extract_state_from_tool_args, generate_event_id, safe_json_parse logger = logging.getLogger(__name__) @@ -49,8 +47,8 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, current_state: dict[str, Any] | None = None, skip_text_content: bool = False, - input_messages: list[Any] | None = None, require_confirmation: bool = True, + approval_tool_name: str | None = None, ) -> None: """ Initialize the event bridge. @@ -62,7 +60,6 @@ def __init__( Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} current_state: Reference to the current state dict for tracking updates. skip_text_content: If True, skip emitting TextMessageContentEvents (for structured outputs). - input_messages: The input messages from the conversation history. require_confirmation: Whether predictive state updates require user confirmation. """ self.run_id = run_id @@ -75,6 +72,7 @@ def __init__( self.pending_state_updates: dict[str, Any] = {} # Track updates from tool calls self.skip_text_content = skip_text_content self.require_confirmation = require_confirmation + self.approval_tool_name = approval_tool_name # For predictive state updates: accumulate streaming arguments self.streaming_tool_args: str = "" # Accumulated JSON string @@ -83,13 +81,6 @@ def __init__( self.should_stop_after_confirm: bool = False # Flag to stop run after confirm_changes self.suppressed_summary: str = "" # Store LLM summary to show after confirmation - # For MessagesSnapshotEvent: track tool calls and results - self.input_messages = input_messages or [] - self.pending_tool_calls: list[dict[str, Any]] = [] # Track tool calls for assistant message - self.tool_results: list[dict[str, Any]] = [] # Track tool results - self.tool_calls_ended: set[str] = set() # Track which tool calls have had ToolCallEndEvent emitted - self.accumulated_text_content: str = "" # Track accumulated text for final MessagesSnapshotEvent - async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: """ Convert an AgentRunResponseUpdate to AG-UI events. @@ -155,7 +146,6 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: message_id=self.current_message_id, delta=content.text, ) - self.accumulated_text_content += content.text logger.info(f" EMITTING TextMessageContentEvent with text_len={len(content.text)}") events.append(event) return events @@ -184,17 +174,6 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba ) logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'") events.append(tool_start_event) - - self.pending_tool_calls.append( - { - "id": tool_call_id, - "type": "function", - "function": { - "name": content.name, - "arguments": "", - }, - } - ) elif tool_call_id: self.current_tool_call_id = tool_call_id @@ -207,13 +186,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba ) events.append(args_event) - for tool_call in self.pending_tool_calls: - if tool_call["id"] == tool_call_id: - tool_call["function"]["arguments"] += delta_str - break - events.extend(self._emit_predictive_state_deltas(delta_str)) - events.extend(self._legacy_predictive_state(content)) return events @@ -236,10 +209,8 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.current_tool_call_name, ) - parsed_args = None - try: - parsed_args = json.loads(self.streaming_tool_args) - except json.JSONDecodeError: + parsed_args = safe_json_parse(self.streaming_tool_args) + if parsed_args is None: for state_key, config in self.predict_state_config.items(): if config["tool"] != self.current_tool_call_name: continue @@ -283,11 +254,8 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: continue tool_arg_name = config["tool_argument"] - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: + state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) + if state_value is None: continue if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: @@ -318,59 +286,6 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.pending_state_updates[state_key] = state_value return events - def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if not (content.name and content.arguments): - return events - parsed_args = content.parse_arguments() - if not parsed_args: - return events - - logger.info( - "Checking predict_state_config keys: %s", - list(self.predict_state_config.keys()) if self.predict_state_config else "None", - ) - for state_key, config in self.predict_state_config.items(): - logger.info(f"Checking state_key='{state_key}'") - if config["tool"] != content.name: - continue - tool_arg_name = config["tool_argument"] - logger.info(f"MATCHED tool '{content.name}' for state key '{state_key}', arg='{tool_arg_name}'") - - state_value: Any - if tool_arg_name == "*": - state_value = parsed_args - logger.info(f"Using all args as state value, keys: {list(state_value.keys())}") - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - logger.info(f"Using specific arg '{tool_arg_name}' as state value") - else: - logger.warning(f"Tool argument '{tool_arg_name}' not found in parsed args") - continue - - previous_value = self.last_emitted_state.get(state_key, object()) - if previous_value == state_value: - logger.info( - "Skipping duplicate StateDeltaEvent for key '%s' - value unchanged", - state_key, - ) - continue - - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", - "path": f"/{state_key}", - "value": state_value, - } - ], - ) - logger.info(f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}") # type: ignore - events.append(state_delta_event) - self.pending_state_updates[state_key] = state_value - self.last_emitted_state[state_key] = state_value - return events - def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]: events: list[BaseEvent] = [] if content.call_id: @@ -379,7 +294,6 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis ) logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'") events.append(end_event) - self.tool_calls_ended.add(content.call_id) if self.state_delta_count > 0: logger.info( @@ -401,55 +315,10 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis role="tool", ) events.append(result_event) - - self.tool_results.append( - { - "id": result_message_id, - "role": "tool", - "toolCallId": content.call_id, - "content": result_content, - } - ) - - events.extend(self._emit_snapshot_for_tool_result()) events.extend(self._emit_state_snapshot_and_confirmation()) return events - def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: - events: list[BaseEvent] = [] - should_emit_snapshot = self.pending_tool_calls and self.tool_results - - is_predictive_without_confirmation = False - if should_emit_snapshot and self.current_tool_call_name and self.predict_state_config: - for _, config in self.predict_state_config.items(): - if config["tool"] == self.current_tool_call_name and not self.require_confirmation: - is_predictive_without_confirmation = True - logger.info( - "Skipping intermediate MessagesSnapshotEvent for predictive tool '%s' - delaying until summary", - self.current_tool_call_name, - ) - break - - if should_emit_snapshot and not is_predictive_without_confirmation: - from ._message_adapters import agent_framework_messages_to_agui - - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), - } - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() - - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] - ) - logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages") - events.append(messages_snapshot_event) - return events - def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: events: list[BaseEvent] = [] if self.pending_state_updates: @@ -498,31 +367,46 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: self.current_tool_call_name = None return events - def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: + def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | None = None) -> list[BaseEvent]: + """Emit a confirm_changes tool call for Dojo UI compatibility. + + Args: + function_call: Optional function call that needs confirmation. + If provided, includes function info in the confirm_changes args + so Dojo UI can display what's being confirmed. + """ events: list[BaseEvent] = [] confirm_call_id = generate_event_id() logger.info("Emitting confirm_changes tool call for predictive update") - self.pending_tool_calls.append( - { - "id": confirm_call_id, - "type": "function", - "function": { - "name": "confirm_changes", - "arguments": "{}", - }, - } - ) - confirm_start = ToolCallStartEvent( tool_call_id=confirm_call_id, tool_call_name="confirm_changes", + parent_message_id=self.current_message_id, ) events.append(confirm_start) + # Include function info if this is for a function approval + # This helps Dojo UI display meaningful confirmation info + if function_call: + args_dict = { + "function_name": function_call.name, + "function_call_id": function_call.call_id, + "function_arguments": function_call.parse_arguments() or {}, + "steps": [ + { + "description": f"Execute {function_call.name}", + "status": "enabled", + } + ], + } + args_json = json.dumps(args_dict) + else: + args_json = "{}" + confirm_args = ToolCallArgsEvent( tool_call_id=confirm_call_id, - delta="{}", + delta=args_json, ) events.append(confirm_args) @@ -531,23 +415,48 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: ) events.append(confirm_end) - from ._message_adapters import agent_framework_messages_to_agui + self.should_stop_after_confirm = True + logger.info("Set flag to stop run after confirm_changes") + return events - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), - } + def _emit_function_approval_tool_call(self, function_call: FunctionCallContent) -> list[BaseEvent]: + """Emit a tool call that can drive UI approval for function requests.""" + tool_call_name = "confirm_changes" + if self.approval_tool_name and self.approval_tool_name != function_call.name: + tool_call_name = self.approval_tool_name + + tool_call_id = generate_event_id() + tool_start = ToolCallStartEvent( + tool_call_id=tool_call_id, + tool_call_name=tool_call_name, + parent_message_id=self.current_message_id, + ) + events: list[BaseEvent] = [tool_start] - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() + args_dict = { + "function_name": function_call.name, + "function_call_id": function_call.call_id, + "function_arguments": function_call.parse_arguments() or {}, + "steps": [ + { + "description": f"Execute {function_call.name}", + "status": "enabled", + } + ], + } + args_json = json.dumps(args_dict) - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] + events.append( + ToolCallArgsEvent( + tool_call_id=tool_call_id, + delta=args_json, + ) + ) + events.append( + ToolCallEndEvent( + tool_call_id=tool_call_id, + ) ) - logger.info(f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages") - events.append(messages_snapshot_event) self.should_stop_after_confirm = True logger.info("Set flag to stop run after confirm_changes") @@ -579,12 +488,8 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq tool_arg_name, ) - state_value: Any - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: + state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) + if state_value is None: logger.warning(f" Tool argument '{tool_arg_name}' not found in parsed args") continue @@ -601,8 +506,8 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq ) logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") events.append(end_event) - self.tool_calls_ended.add(content.function_call.call_id) + # Emit the function_approval_request custom event for UI implementations that support it approval_event = CustomEvent( name="function_approval_request", value={ @@ -616,6 +521,14 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq ) logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") events.append(approval_event) + + # Emit a UI-friendly approval tool call for function approvals. + if self.require_confirmation: + events.extend(self._emit_function_approval_tool_call(content.function_call)) + + # Signal orchestrator to stop the run and wait for user approval response + self.should_stop_after_confirm = True + logger.info("Set flag to stop run - waiting for function approval response") return events def create_run_started_event(self) -> RunStartedEvent: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index b87f3b1827..1ff858e9f5 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -3,6 +3,7 @@ """Message format conversion between AG-UI and Agent Framework.""" import json +import logging from typing import Any, cast from agent_framework import ( @@ -15,18 +16,226 @@ prepare_function_call_results, ) -# Role mapping constants -_AGUI_TO_FRAMEWORK_ROLE = { - "user": Role.USER, - "assistant": Role.ASSISTANT, - "system": Role.SYSTEM, -} +from ._utils import ( + AGUI_TO_FRAMEWORK_ROLE, + FRAMEWORK_TO_AGUI_ROLE, + get_role_value, + normalize_agui_role, + safe_json_parse, +) + +logger = logging.getLogger(__name__) + + +def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: + """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" + sanitized: list[ChatMessage] = [] + pending_tool_call_ids: set[str] | None = None + pending_confirm_changes_id: str | None = None + + for msg in messages: + role_value = get_role_value(msg) + + if role_value == "assistant": + tool_ids = { + str(content.call_id) + for content in msg.contents or [] + if isinstance(content, FunctionCallContent) and content.call_id + } + confirm_changes_call = None + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": + confirm_changes_call = content + break + + sanitized.append(msg) + pending_tool_call_ids = tool_ids if tool_ids else None + pending_confirm_changes_id = ( + str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None + ) + continue + + if role_value == "user": + approval_call_ids: set[str] = set() + approval_accepted: bool | None = None + for content in msg.contents or []: + if type(content) is FunctionApprovalResponseContent: + if content.function_call and content.function_call.call_id: + approval_call_ids.add(str(content.function_call.call_id)) + if approval_accepted is None: + approval_accepted = bool(content.approved) + else: + approval_accepted = approval_accepted and bool(content.approved) -_FRAMEWORK_TO_AGUI_ROLE = { - Role.USER: "user", - Role.ASSISTANT: "assistant", - Role.SYSTEM: "system", -} + if approval_call_ids and pending_tool_call_ids: + pending_tool_call_ids -= approval_call_ids + logger.info( + f"FunctionApprovalResponseContent found for call_ids={sorted(approval_call_ids)} - " + "framework will handle execution" + ) + + if pending_confirm_changes_id and approval_accepted is not None: + logger.info(f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}") + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_confirm_changes_id, + result="Confirmed" if approval_accepted else "Rejected", + ) + ], + ) + sanitized.append(synthetic_result) + if pending_tool_call_ids: + pending_tool_call_ids.discard(pending_confirm_changes_id) + pending_confirm_changes_id = None + + if pending_confirm_changes_id: + user_text = "" + for content in msg.contents or []: + if isinstance(content, TextContent): + user_text = content.text + break + + try: + parsed = json.loads(user_text) + if "accepted" in parsed: + logger.info( + f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" + ) + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_confirm_changes_id, + result="Confirmed" if parsed.get("accepted") else "Rejected", + ) + ], + ) + sanitized.append(synthetic_result) + if pending_tool_call_ids: + pending_tool_call_ids.discard(pending_confirm_changes_id) + pending_confirm_changes_id = None + continue + except (json.JSONDecodeError, KeyError) as exc: + logger.debug(f"Could not parse user message as confirm_changes response: {type(exc).__name__}") + + if pending_tool_call_ids: + logger.info( + f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - " + "injecting synthetic results" + ) + for pending_call_id in pending_tool_call_ids: + logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_call_id, + result="Tool execution skipped - user provided follow-up message", + ) + ], + ) + sanitized.append(synthetic_result) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + sanitized.append(msg) + pending_confirm_changes_id = None + continue + + if role_value == "tool": + if not pending_tool_call_ids: + continue + keep = False + for content in msg.contents or []: + if isinstance(content, FunctionResultContent): + call_id = str(content.call_id) + if call_id in pending_tool_call_ids: + keep = True + if call_id == pending_confirm_changes_id: + pending_confirm_changes_id = None + break + if keep: + sanitized.append(msg) + continue + + sanitized.append(msg) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + return sanitized + + +def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: + """Remove duplicate messages while preserving order.""" + seen_keys: dict[Any, int] = {} + unique_messages: list[ChatMessage] = [] + + for idx, msg in enumerate(messages): + role_value = get_role_value(msg) + + if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): + call_id = str(msg.contents[0].call_id) + key: Any = (role_value, call_id) + + if key in seen_keys: + existing_idx = seen_keys[key] + existing_msg = unique_messages[existing_idx] + + existing_result = None + if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): + existing_result = existing_msg.contents[0].result + new_result = msg.contents[0].result + + if (not existing_result or existing_result == "") and new_result: + logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}") + unique_messages[existing_idx] = msg + else: + logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + elif ( + role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) + ): + tool_call_ids = tuple( + sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) + ) + key = (role_value, tool_call_ids) + + if key in seen_keys: + logger.info(f"Skipping duplicate assistant tool call at index {idx}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + else: + content_str = str([str(c) for c in msg.contents]) if msg.contents else "" + key = (role_value, hash(content_str)) + + if key in seen_keys: + logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + return unique_messages + + +def normalize_agui_input_messages( + messages: list[dict[str, Any]], +) -> tuple[list[ChatMessage], list[dict[str, Any]]]: + """Normalize raw AG-UI messages into provider and snapshot formats.""" + provider_messages = agui_messages_to_agent_framework(messages) + provider_messages = _sanitize_tool_history(provider_messages) + provider_messages = _deduplicate_messages(provider_messages) + snapshot_messages = agui_messages_to_snapshot_format(messages) + return provider_messages, snapshot_messages def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[ChatMessage]: @@ -38,11 +247,108 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha Returns: List of Agent Framework ChatMessage objects """ + + def _update_tool_call_arguments( + raw_messages: list[dict[str, Any]], + tool_call_id: str, + modified_args: dict[str, Any], + ) -> None: + for raw_msg in raw_messages: + tool_calls = raw_msg.get("tool_calls") or raw_msg.get("toolCalls") + if not isinstance(tool_calls, list): + continue + tool_calls_list = cast(list[Any], tool_calls) + for tool_call in tool_calls_list: + if not isinstance(tool_call, dict): + continue + tool_call_dict = cast(dict[str, Any], tool_call) + if str(tool_call_dict.get("id", "")) != tool_call_id: + continue + function_payload = tool_call_dict.get("function") + if not isinstance(function_payload, dict): + return + function_payload_dict = cast(dict[str, Any], function_payload) + existing_args = function_payload_dict.get("arguments") + if isinstance(existing_args, str): + function_payload_dict["arguments"] = json.dumps(modified_args) + else: + function_payload_dict["arguments"] = modified_args + return + + def _find_matching_func_call(call_id: str) -> FunctionCallContent | None: + for prev_msg in result: + role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) + if role_val != "assistant": + continue + for content in prev_msg.contents or []: + if isinstance(content, FunctionCallContent): + if content.call_id == call_id and content.name != "confirm_changes": + return content + return None + + def _parse_arguments(arguments: Any) -> dict[str, Any] | None: + return safe_json_parse(arguments) + + def _resolve_approval_call_id(tool_call_id: str, parsed_payload: dict[str, Any] | None) -> str | None: + if parsed_payload: + explicit_call_id = parsed_payload.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + for prev_msg in result: + role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) + if role_val != "assistant": + continue + direct_call = None + confirm_call = None + sibling_calls: list[FunctionCallContent] = [] + for content in prev_msg.contents or []: + if not isinstance(content, FunctionCallContent): + continue + if content.call_id == tool_call_id: + direct_call = content + if content.name == "confirm_changes" and content.call_id == tool_call_id: + confirm_call = content + elif content.name != "confirm_changes": + sibling_calls.append(content) + + if direct_call: + direct_args = direct_call.parse_arguments() or {} + if isinstance(direct_args, dict): + explicit_call_id = direct_args.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + if not confirm_call: + continue + + confirm_args = confirm_call.parse_arguments() or {} + if isinstance(confirm_args, dict): + explicit_call_id = confirm_args.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + if len(sibling_calls) == 1 and sibling_calls[0].call_id: + return str(sibling_calls[0].call_id) + + return None + + def _filter_modified_args( + modified_args: dict[str, Any], + original_args: dict[str, Any] | None, + ) -> dict[str, Any]: + if not modified_args: + return {} + if not isinstance(original_args, dict) or not original_args: + return {} + allowed_keys = set(original_args.keys()) + return {key: value for key, value in modified_args.items() if key in allowed_keys} + result: list[ChatMessage] = [] for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants # This path maps AG‑UI tool messages to FunctionResultContent with the correct tool_call_id - role_str = msg.get("role", "user") + role_str = normalize_agui_role(msg.get("role", "user")) if role_str == "tool": # Prefer explicit tool_call_id fields; fall back to backend fields only if necessary tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId") @@ -59,29 +365,153 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha result_content = msg.get("result", "") # Distinguish approval payloads from actual tool results - is_approval = False + parsed: dict[str, Any] | None = None if isinstance(result_content, str) and result_content: try: - parsed = json.loads(result_content) - is_approval = isinstance(parsed, dict) and "accepted" in parsed + parsed_candidate = json.loads(result_content) except Exception: - is_approval = False + parsed_candidate = None + if isinstance(parsed_candidate, dict): + parsed = cast(dict[str, Any], parsed_candidate) + elif isinstance(result_content, dict): + parsed = cast(dict[str, Any], result_content) + + is_approval = parsed is not None and "accepted" in parsed if is_approval: - # Approval responses should be treated as user messages to trigger human-in-the-loop flow - chat_msg = ChatMessage( - role=Role.USER, - contents=[TextContent(text=str(result_content))], - additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, - ) + # Look for the matching function call in previous messages to create + # a proper FunctionApprovalResponseContent. This enables the agent framework + # to execute the approved tool (fix for GitHub issue #3034). + accepted = parsed.get("accepted", False) if parsed is not None else False + approval_payload_text = result_content if isinstance(result_content, str) else json.dumps(parsed) + + # Log the full approval payload to debug modified arguments + import logging + + logger = logging.getLogger(__name__) + logger.info(f"Approval payload received: {parsed}") + + approval_call_id = tool_call_id + resolved_call_id = _resolve_approval_call_id(tool_call_id, parsed) + if resolved_call_id: + approval_call_id = resolved_call_id + matching_func_call = _find_matching_func_call(approval_call_id) + + if matching_func_call: + # Remove any existing tool result for this call_id since the framework + # will re-execute the tool after approval. Keeping old results causes + # OpenAI API errors ("tool message must follow assistant with tool_calls"). + result = [ + m + for m in result + if not ( + (m.role.value if hasattr(m.role, "value") else str(m.role)) == "tool" + and any( + isinstance(c, FunctionResultContent) and c.call_id == approval_call_id + for c in (m.contents or []) + ) + ) + ] + + # Check if the approval payload contains modified arguments + # The UI sends back the modified state (e.g., deselected steps) in the approval payload + modified_args = {k: v for k, v in parsed.items() if k != "accepted"} if parsed else {} + original_args = matching_func_call.parse_arguments() + filtered_args = _filter_modified_args(modified_args, original_args) + state_args: dict[str, Any] | None = None + if filtered_args: + original_args = original_args or {} + merged_args: dict[str, Any] + if isinstance(original_args, dict) and original_args: + merged_args = {**original_args, **filtered_args} + else: + merged_args = dict(filtered_args) + + if isinstance(filtered_args.get("steps"), list): + original_steps = original_args.get("steps") if isinstance(original_args, dict) else None + if isinstance(original_steps, list): + approved_steps_list = list(filtered_args.get("steps") or []) + approved_by_description: dict[str, dict[str, Any]] = {} + for step_item in approved_steps_list: + if isinstance(step_item, dict): + step_item_dict = cast(dict[str, Any], step_item) + desc = step_item_dict.get("description") + if desc: + approved_by_description[str(desc)] = step_item_dict + merged_steps: list[Any] = [] + original_steps_list = cast(list[Any], original_steps) + for orig_step in original_steps_list: + if not isinstance(orig_step, dict): + merged_steps.append(orig_step) + continue + orig_step_dict = cast(dict[str, Any], orig_step) + description = str(orig_step_dict.get("description", "")) + approved_step = approved_by_description.get(description) + status: str = ( + str(approved_step.get("status")) + if approved_step is not None and approved_step.get("status") + else "disabled" + ) + updated_step: dict[str, Any] = orig_step_dict.copy() + updated_step["status"] = status + merged_steps.append(updated_step) + merged_args["steps"] = merged_steps + state_args = merged_args + + # Keep the original tool call and AG-UI snapshot in sync with approved args. + updated_args = ( + json.dumps(merged_args) if isinstance(matching_func_call.arguments, str) else merged_args + ) + matching_func_call.arguments = updated_args + _update_tool_call_arguments(messages, str(approval_call_id), merged_args) + # Create a new FunctionCallContent with the modified arguments + func_call_for_approval = FunctionCallContent( + call_id=matching_func_call.call_id, + name=matching_func_call.name, + arguments=json.dumps(filtered_args), + ) + logger.info(f"Using modified arguments from approval: {filtered_args}") + else: + # No modified arguments - use the original function call + func_call_for_approval = matching_func_call + + # Create FunctionApprovalResponseContent for the agent framework + approval_response = FunctionApprovalResponseContent( + approved=accepted, + id=str(approval_call_id), + function_call=func_call_for_approval, + additional_properties={"ag_ui_state_args": state_args} if state_args else None, + ) + chat_msg = ChatMessage( + role=Role.USER, + contents=[approval_response], + ) + else: + # No matching function call found - this is likely a confirm_changes approval + # Keep the old behavior for backwards compatibility + chat_msg = ChatMessage( + role=Role.USER, + contents=[TextContent(text=approval_payload_text)], + additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, + ) if "id" in msg: chat_msg.message_id = msg["id"] result.append(chat_msg) continue + # Cast result_content to acceptable type for FunctionResultContent + func_result: str | dict[str, Any] | list[Any] + if isinstance(result_content, str): + func_result = result_content + elif isinstance(result_content, dict): + func_result = cast(dict[str, Any], result_content) + elif isinstance(result_content, list): + func_result = cast(list[Any], result_content) + else: + func_result = str(result_content) chat_msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id=str(tool_call_id), result=result_content)], + contents=[FunctionResultContent(call_id=str(tool_call_id), result=func_result)], ) if "id" in msg: chat_msg.message_id = msg["id"] @@ -142,7 +572,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha # No special handling required for assistant/plain messages here - role = _AGUI_TO_FRAMEWORK_ROLE.get(role_str, Role.USER) + role = AGUI_TO_FRAMEWORK_ROLE.get(role_str, Role.USER) # Check if this message contains function approvals if "function_approvals" in msg and msg["function_approvals"]: @@ -198,6 +628,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str if isinstance(msg, dict): # Always work on a copy to avoid mutating input normalized_msg = msg.copy() + normalized_msg["role"] = normalize_agui_role(normalized_msg.get("role")) # Ensure ID exists if "id" not in normalized_msg: normalized_msg["id"] = generate_event_id() @@ -214,7 +645,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str continue # Convert ChatMessage to AG-UI format - role = _FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user") + role = FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user") content_text = "" tool_calls: list[dict[str, Any]] = [] @@ -303,22 +734,44 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic content = normalized_msg.get("content") if isinstance(content, list): # Convert content array format to simple string - text_parts = [] - for item in content: + text_parts: list[str] = [] + content_list = cast(list[Any], content) + for item in content_list: if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) # Convert 'input_text' to 'text' type - if item.get("type") == "input_text": - text_parts.append(item.get("text", "")) - elif item.get("type") == "text": - text_parts.append(item.get("text", "")) + if item_dict.get("type") == "input_text": + text_parts.append(str(item_dict.get("text", ""))) + elif item_dict.get("type") == "text": + text_parts.append(str(item_dict.get("text", ""))) else: # Other types - just extract text field if present - text_parts.append(item.get("text", "")) + text_parts.append(str(item_dict.get("text", ""))) normalized_msg["content"] = "".join(text_parts) elif content is None: normalized_msg["content"] = "" + tool_calls = normalized_msg.get("tool_calls") or normalized_msg.get("toolCalls") + if isinstance(tool_calls, list): + tool_calls_list = cast(list[Any], tool_calls) + for tool_call in tool_calls_list: + if not isinstance(tool_call, dict): + continue + tool_call_dict = cast(dict[str, Any], tool_call) + function_payload = tool_call_dict.get("function") + if not isinstance(function_payload, dict): + continue + function_payload_dict = cast(dict[str, Any], function_payload) + if "arguments" not in function_payload_dict: + continue + arguments = function_payload_dict.get("arguments") + if arguments is None: + function_payload_dict["arguments"] = "" + elif not isinstance(arguments, str): + function_payload_dict["arguments"] = json.dumps(arguments) + # Normalize tool_call_id to toolCallId for tool messages + normalized_msg["role"] = normalize_agui_role(normalized_msg.get("role")) if normalized_msg.get("role") == "tool": if "tool_call_id" in normalized_msg: normalized_msg["toolCallId"] = normalized_msg["tool_call_id"] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py new file mode 100644 index 0000000000..ebf6ef6f57 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -0,0 +1,391 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Helper functions for orchestration logic.""" + +import json +import logging +from typing import TYPE_CHECKING, Any + +from ag_ui.core import StateSnapshotEvent +from agent_framework import ( + ChatMessage, + FunctionApprovalResponseContent, + FunctionCallContent, + FunctionResultContent, + TextContent, +) + +from .._utils import get_role_value, safe_json_parse + +if TYPE_CHECKING: + from .._events import AgentFrameworkEventBridge + from ._state_manager import StateManager + +logger = logging.getLogger(__name__) + + +def pending_tool_call_ids(messages: list[ChatMessage]) -> set[str]: + """Get IDs of tool calls without corresponding results. + + Args: + messages: List of messages to scan + + Returns: + Set of pending tool call IDs + """ + pending_ids: set[str] = set() + resolved_ids: set[str] = set() + for msg in messages: + for content in msg.contents: + if isinstance(content, FunctionCallContent) and content.call_id: + pending_ids.add(str(content.call_id)) + elif isinstance(content, FunctionResultContent) and content.call_id: + resolved_ids.add(str(content.call_id)) + return pending_ids - resolved_ids + + +def is_state_context_message(message: ChatMessage) -> bool: + """Check if a message is a state context system message. + + Args: + message: Message to check + + Returns: + True if this is a state context message + """ + if get_role_value(message) != "system": + return False + for content in message.contents: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + return True + return False + + +def ensure_tool_call_entry( + tool_call_id: str, + tool_calls_by_id: dict[str, dict[str, Any]], + pending_tool_calls: list[dict[str, Any]], +) -> dict[str, Any]: + """Get or create a tool call entry in the tracking dicts. + + Args: + tool_call_id: The tool call ID + tool_calls_by_id: Dict mapping IDs to tool call entries + pending_tool_calls: List of pending tool calls + + Returns: + The tool call entry dict + """ + entry = tool_calls_by_id.get(tool_call_id) + if entry is None: + entry = { + "id": tool_call_id, + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + tool_calls_by_id[tool_call_id] = entry + pending_tool_calls.append(entry) + return entry + + +def tool_name_for_call_id( + tool_calls_by_id: dict[str, dict[str, Any]], + tool_call_id: str, +) -> str | None: + """Get the tool name for a given call ID. + + Args: + tool_calls_by_id: Dict mapping IDs to tool call entries + tool_call_id: The tool call ID to look up + + Returns: + Tool name or None if not found + """ + entry = tool_calls_by_id.get(tool_call_id) + if not entry: + return None + function = entry.get("function") + if not isinstance(function, dict): + return None + name = function.get("name") + return str(name) if name else None + + +def tool_calls_match_state( + provider_messages: list[ChatMessage], + state_manager: "StateManager", +) -> bool: + """Check if tool calls in messages match current state. + + Args: + provider_messages: Messages to check + state_manager: State manager with config and current state + + Returns: + True if tool calls match state configuration + """ + if not state_manager.predict_state_config or not state_manager.current_state: + return False + + for state_key, config in state_manager.predict_state_config.items(): + tool_name = config["tool"] + tool_arg_name = config["tool_argument"] + tool_args: dict[str, Any] | None = None + + for msg in reversed(provider_messages): + if get_role_value(msg) != "assistant": + continue + for content in msg.contents: + if isinstance(content, FunctionCallContent) and content.name == tool_name: + tool_args = safe_json_parse(content.arguments) + break + if tool_args is not None: + break + + if not tool_args: + return False + + if tool_arg_name == "*": + state_value = tool_args + elif tool_arg_name in tool_args: + state_value = tool_args[tool_arg_name] + else: + return False + + if state_manager.current_state.get(state_key) != state_value: + return False + + return True + + +def schema_has_steps(schema: Any) -> bool: + """Check if a schema has a steps array property. + + Args: + schema: JSON schema to check + + Returns: + True if schema has steps array + """ + if not isinstance(schema, dict): + return False + properties = schema.get("properties") + if not isinstance(properties, dict): + return False + steps_schema = properties.get("steps") + if not isinstance(steps_schema, dict): + return False + return steps_schema.get("type") == "array" + + +def select_approval_tool_name(client_tools: list[Any] | None) -> str | None: + """Select appropriate approval tool from client tools. + + Args: + client_tools: List of client tool definitions + + Returns: + Name of approval tool, or None if not found + """ + if not client_tools: + return None + for tool in client_tools: + tool_name = getattr(tool, "name", None) + if not tool_name: + continue + params_fn = getattr(tool, "parameters", None) + if not callable(params_fn): + continue + schema = params_fn() + if schema_has_steps(schema): + return str(tool_name) + return None + + +def select_messages_to_run( + provider_messages: list[ChatMessage], + state_manager: "StateManager", +) -> list[ChatMessage]: + """Select and prepare messages for agent execution. + + Injects state context message when appropriate. + + Args: + provider_messages: Original messages from client + state_manager: State manager instance + + Returns: + Messages ready for agent execution + """ + if not provider_messages: + return [] + + is_new_user_turn = get_role_value(provider_messages[-1]) == "user" + conversation_has_tool_calls = tool_calls_match_state(provider_messages, state_manager) + state_context_msg = state_manager.state_context_message( + is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls + ) + if not state_context_msg: + return list(provider_messages) + + messages_to_run = [msg for msg in provider_messages if not is_state_context_message(msg)] + if pending_tool_call_ids(messages_to_run): + return messages_to_run + + insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) + if insert_index < 0: + insert_index = 0 + messages_to_run.insert(insert_index, state_context_msg) + return messages_to_run + + +def build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: + """Build metadata dict with truncated string values. + + Args: + thread_metadata: Raw metadata dict + + Returns: + Metadata with string values truncated to 512 chars + """ + if not thread_metadata: + return {} + safe_metadata: dict[str, Any] = {} + for key, value in thread_metadata.items(): + value_str = value if isinstance(value, str) else json.dumps(value) + if len(value_str) > 512: + value_str = value_str[:512] + safe_metadata[key] = value_str + return safe_metadata + + +def collect_approved_state_snapshots( + provider_messages: list[ChatMessage], + predict_state_config: dict[str, dict[str, str]] | None, + current_state: dict[str, Any], + event_bridge: "AgentFrameworkEventBridge", +) -> list[StateSnapshotEvent]: + """Collect state snapshots from approved function calls. + + Args: + provider_messages: Messages containing approvals + predict_state_config: Predictive state configuration + current_state: Current state dict (will be mutated) + event_bridge: Event bridge for creating events + + Returns: + List of state snapshot events + """ + if not predict_state_config: + return [] + + events: list[StateSnapshotEvent] = [] + for msg in provider_messages: + if get_role_value(msg) != "user": + continue + for content in msg.contents: + if type(content) is FunctionApprovalResponseContent: + if not content.function_call or not content.approved: + continue + parsed_args = content.function_call.parse_arguments() + state_args = None + if content.additional_properties: + state_args = content.additional_properties.get("ag_ui_state_args") + if not isinstance(state_args, dict): + state_args = parsed_args + if not state_args: + continue + for state_key, config in predict_state_config.items(): + if config["tool"] != content.function_call.name: + continue + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + state_value = state_args + elif isinstance(state_args, dict) and tool_arg_name in state_args: + state_value = state_args[tool_arg_name] + else: + continue + current_state[state_key] = state_value + event_bridge.current_state[state_key] = state_value + logger.info( + f"Emitting StateSnapshotEvent for approved state key '{state_key}' " + f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" + ) + events.append(StateSnapshotEvent(snapshot=current_state)) + break + return events + + +def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalResponseContent | None: + """Get the latest approval response from messages. + + Args: + messages: Messages to search + + Returns: + Latest approval response or None + """ + if not messages: + return None + last_message = messages[-1] + for content in last_message.contents: + if type(content) is FunctionApprovalResponseContent: + return content + return None + + +def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: + """Extract steps from an approval response. + + Args: + approval: Approval response content + + Returns: + List of steps, or empty list if none + """ + state_args: Any | None = None + if approval.additional_properties: + state_args = approval.additional_properties.get("ag_ui_state_args") + if isinstance(state_args, dict): + steps = state_args.get("steps") + if isinstance(steps, list): + return steps + + if approval.function_call: + parsed_args = approval.function_call.parse_arguments() + if isinstance(parsed_args, dict): + steps = parsed_args.get("steps") + if isinstance(steps, list): + return steps + + return [] + + +def is_step_based_approval( + approval: FunctionApprovalResponseContent, + predict_state_config: dict[str, dict[str, str]] | None, +) -> bool: + """Check if an approval is step-based. + + Args: + approval: Approval response to check + predict_state_config: Predictive state configuration + + Returns: + True if this is a step-based approval + """ + steps = approval_steps(approval) + if steps: + return True + if not approval.function_call: + return False + if not predict_state_config: + return False + tool_name = approval.function_call.name + for config in predict_state_config.values(): + if config.get("tool") == tool_name and config.get("tool_argument") == "steps": + return True + return False diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py deleted file mode 100644 index 97c990781b..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Message hygiene utilities for orchestrators.""" - -import json -import logging -from typing import Any - -from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent - -logger = logging.getLogger(__name__) - - -def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: - """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" - sanitized: list[ChatMessage] = [] - pending_tool_call_ids: set[str] | None = None - pending_confirm_changes_id: str | None = None - - for msg in messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - if role_value == "assistant": - tool_ids = { - str(content.call_id) - for content in msg.contents or [] - if isinstance(content, FunctionCallContent) and content.call_id - } - confirm_changes_call = None - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": - confirm_changes_call = content - break - - sanitized.append(msg) - pending_tool_call_ids = tool_ids if tool_ids else None - pending_confirm_changes_id = ( - str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None - ) - continue - - if role_value == "user": - if pending_confirm_changes_id: - user_text = "" - for content in msg.contents or []: - if isinstance(content, TextContent): - user_text = content.text - break - - try: - parsed = json.loads(user_text) - if "accepted" in parsed: - logger.info( - f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" - ) - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_confirm_changes_id, - result="Confirmed" if parsed.get("accepted") else "Rejected", - ) - ], - ) - sanitized.append(synthetic_result) - if pending_tool_call_ids: - pending_tool_call_ids.discard(pending_confirm_changes_id) - pending_confirm_changes_id = None - continue - except (json.JSONDecodeError, KeyError) as exc: - logger.debug("Could not parse user message as confirm_changes response: %s", type(exc).__name__) - - if pending_tool_call_ids: - logger.info( - f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results" - ) - for pending_call_id in pending_tool_call_ids: - logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_call_id, - result="Tool execution skipped - user provided follow-up message", - ) - ], - ) - sanitized.append(synthetic_result) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - sanitized.append(msg) - pending_confirm_changes_id = None - continue - - if role_value == "tool": - if not pending_tool_call_ids: - continue - keep = False - for content in msg.contents or []: - if isinstance(content, FunctionResultContent): - call_id = str(content.call_id) - if call_id in pending_tool_call_ids: - keep = True - if call_id == pending_confirm_changes_id: - pending_confirm_changes_id = None - break - if keep: - sanitized.append(msg) - continue - - sanitized.append(msg) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - return sanitized - - -def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: - """Remove duplicate messages while preserving order.""" - seen_keys: dict[Any, int] = {} - unique_messages: list[ChatMessage] = [] - - for idx, msg in enumerate(messages): - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): - call_id = str(msg.contents[0].call_id) - key: Any = (role_value, call_id) - - if key in seen_keys: - existing_idx = seen_keys[key] - existing_msg = unique_messages[existing_idx] - - existing_result = None - if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): - existing_result = existing_msg.contents[0].result - new_result = msg.contents[0].result - - if (not existing_result or existing_result == "") and new_result: - logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}") - unique_messages[existing_idx] = msg - else: - logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - elif ( - role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) - ): - tool_call_ids = tuple( - sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) - ) - key = (role_value, tool_call_ids) - - if key in seen_keys: - logger.info(f"Skipping duplicate assistant tool call at index {idx}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - else: - content_str = str([str(c) for c in msg.contents]) if msg.contents else "" - key = (role_value, hash(content_str)) - - if key in seen_keys: - logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - return unique_messages diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py new file mode 100644 index 0000000000..8662036bbf --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Predictive state handling utilities.""" + +import json +import logging +import re +from typing import Any + +from ag_ui.core import StateDeltaEvent + +from .._utils import safe_json_parse + +logger = logging.getLogger(__name__) + + +class PredictiveStateHandler: + """Handles predictive state updates from streaming tool calls.""" + + def __init__( + self, + predict_state_config: dict[str, dict[str, str]] | None = None, + current_state: dict[str, Any] | None = None, + ) -> None: + """Initialize the handler. + + Args: + predict_state_config: Configuration mapping state keys to tool/argument pairs + current_state: Reference to current state dict + """ + self.predict_state_config = predict_state_config or {} + self.current_state = current_state or {} + self.streaming_tool_args: str = "" + self.last_emitted_state: dict[str, Any] = {} + self.state_delta_count: int = 0 + self.pending_state_updates: dict[str, Any] = {} + + def reset_streaming(self) -> None: + """Reset streaming state for a new tool call.""" + self.streaming_tool_args = "" + self.state_delta_count = 0 + + def extract_state_value( + self, + tool_name: str, + args: dict[str, Any] | str | None, + ) -> tuple[str, Any] | None: + """Extract state value from tool arguments based on config. + + Args: + tool_name: Name of the tool being called + args: Tool arguments (dict or JSON string) + + Returns: + Tuple of (state_key, state_value) or None if no match + """ + if not self.predict_state_config: + return None + + parsed_args = safe_json_parse(args) if isinstance(args, str) else args + if not parsed_args: + return None + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + return (state_key, parsed_args) + if tool_arg_name in parsed_args: + return (state_key, parsed_args[tool_arg_name]) + + return None + + def is_predictive_tool(self, tool_name: str | None) -> bool: + """Check if a tool is configured for predictive state. + + Args: + tool_name: Name of the tool to check + + Returns: + True if tool is in predictive state config + """ + if not tool_name or not self.predict_state_config: + return False + for config in self.predict_state_config.values(): + if config["tool"] == tool_name: + return True + return False + + def emit_streaming_deltas( + self, + tool_name: str | None, + argument_chunk: str, + ) -> list[StateDeltaEvent]: + """Process streaming argument chunk and emit state deltas. + + Args: + tool_name: Name of the current tool + argument_chunk: New chunk of JSON arguments + + Returns: + List of state delta events to emit + """ + events: list[StateDeltaEvent] = [] + if not tool_name or not self.predict_state_config: + return events + + self.streaming_tool_args += argument_chunk + logger.debug( + "Predictive state: accumulated %s chars for tool '%s'", + len(self.streaming_tool_args), + tool_name, + ) + + # Try to parse complete JSON first + parsed_args = None + try: + parsed_args = json.loads(self.streaming_tool_args) + except json.JSONDecodeError: + # Fall back to regex matching for partial JSON + events.extend(self._emit_partial_deltas(tool_name)) + + if parsed_args: + events.extend(self._emit_complete_deltas(tool_name, parsed_args)) + + return events + + def _emit_partial_deltas(self, tool_name: str) -> list[StateDeltaEvent]: + """Emit deltas from partial JSON using regex matching. + + Args: + tool_name: Name of the current tool + + Returns: + List of state delta events + """ + events: list[StateDeltaEvent] = [] + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + pattern = rf'"{re.escape(tool_arg_name)}":\s*"([^"]*)' + match = re.search(pattern, self.streaming_tool_args) + + if match: + partial_value = match.group(1).replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\") + + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != partial_value: + event = self._create_delta_event(state_key, partial_value) + events.append(event) + self.last_emitted_state[state_key] = partial_value + self.pending_state_updates[state_key] = partial_value + + return events + + def _emit_complete_deltas( + self, + tool_name: str, + parsed_args: dict[str, Any], + ) -> list[StateDeltaEvent]: + """Emit deltas from complete parsed JSON. + + Args: + tool_name: Name of the current tool + parsed_args: Fully parsed arguments dict + + Returns: + List of state delta events + """ + events: list[StateDeltaEvent] = [] + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + + if tool_arg_name == "*": + state_value = parsed_args + elif tool_arg_name in parsed_args: + state_value = parsed_args[tool_arg_name] + else: + continue + + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: + event = self._create_delta_event(state_key, state_value) + events.append(event) + self.last_emitted_state[state_key] = state_value + self.pending_state_updates[state_key] = state_value + + return events + + def _create_delta_event(self, state_key: str, value: Any) -> StateDeltaEvent: + """Create a state delta event with logging. + + Args: + state_key: The state key being updated + value: The new value + + Returns: + StateDeltaEvent instance + """ + self.state_delta_count += 1 + if self.state_delta_count % 10 == 1: + logger.info( + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", + self.state_delta_count, + state_key, + state_key, + len(str(value)), + ) + elif self.state_delta_count % 100 == 0: + logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") + + return StateDeltaEvent( + delta=[ + { + "op": "replace", + "path": f"/{state_key}", + "value": value, + } + ], + ) + + def apply_pending_updates(self) -> None: + """Apply pending updates to current state and clear them.""" + for key, value in self.pending_state_updates.items(): + self.current_state[key] = value + self.pending_state_updates.clear() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py index 45c16afef4..7d8a23d84c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py @@ -22,9 +22,11 @@ def __init__( self.predict_state_config = predict_state_config or {} self.require_confirmation = require_confirmation self.current_state: dict[str, Any] = {} + self._state_from_input: bool = False def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]: """Initialize state with schema defaults.""" + self._state_from_input = initial_state is not None self.current_state = (initial_state or {}).copy() self._apply_schema_defaults() return self.current_state @@ -60,7 +62,9 @@ def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_ca """Inject state context only when starting a new user turn.""" if not self.current_state or not self.state_schema: return None - if not is_new_user_turn or conversation_has_tool_calls: + if not is_new_user_turn: + return None + if conversation_has_tool_calls and not self._state_from_input: return None state_json = json.dumps(self.current_state, indent=2) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6bdff552b6..3067e3e4a7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -16,6 +16,10 @@ TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, ) from agent_framework import ( AgentProtocol, @@ -25,8 +29,31 @@ FunctionResultContent, TextContent, ) +from agent_framework._middleware import extract_and_merge_function_middleware +from agent_framework._tools import ( + FunctionInvocationConfiguration, + _collect_approval_responses, # type: ignore + _replace_approval_contents_with_results, # type: ignore + _try_execute_function_calls, # type: ignore +) -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id +from ._orchestration._helpers import ( + approval_steps, + build_safe_metadata, + collect_approved_state_snapshots, + ensure_tool_call_entry, + is_step_based_approval, + latest_approval_response, + select_approval_tool_name, + select_messages_to_run, + tool_name_for_call_id, +) +from ._orchestration._tooling import ( + collect_server_tools, + merge_tools, + register_additional_client_tools, +) +from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value if TYPE_CHECKING: from ._agent import AgentConfig @@ -61,6 +88,7 @@ def __init__( # Lazy-loaded properties self._messages = None + self._snapshot_messages = None self._last_message = None self._run_id: str | None = None self._thread_id: str | None = None @@ -69,12 +97,27 @@ def __init__( def messages(self): """Get converted Agent Framework messages (lazy loaded).""" if self._messages is None: - from ._message_adapters import agui_messages_to_agent_framework + from ._message_adapters import normalize_agui_input_messages raw = self.input_data.get("messages", []) - self._messages = agui_messages_to_agent_framework(raw) + if not isinstance(raw, list): + raw = [] + self._messages, self._snapshot_messages = normalize_agui_input_messages(raw) return self._messages + @property + def snapshot_messages(self) -> list[dict[str, Any]]: + """Get normalized AG-UI snapshot messages (lazy loaded).""" + if self._snapshot_messages is None: + if self._messages is None: + _ = self.messages + else: + from ._message_adapters import agent_framework_messages_to_agui, agui_messages_to_snapshot_format + + raw_snapshot = agent_framework_messages_to_agui(self._messages) + self._snapshot_messages = agui_messages_to_snapshot_format(raw_snapshot) + return self._snapshot_messages or [] + @property def last_message(self): """Get the last message in the conversation (lazy loaded).""" @@ -270,14 +313,7 @@ async def run( AG-UI events """ from ._events import AgentFrameworkEventBridge - from ._message_adapters import agui_messages_to_snapshot_format - from ._orchestration._message_hygiene import deduplicate_messages, sanitize_tool_history from ._orchestration._state_manager import StateManager - from ._orchestration._tooling import ( - collect_server_tools, - merge_tools, - register_additional_client_tools, - ) logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}") @@ -286,12 +322,15 @@ async def run( response_format = context.agent.chat_options.response_format skip_text_content = response_format is not None + client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) + approval_tool_name = select_approval_tool_name(client_tools) + state_manager = StateManager( state_schema=context.config.state_schema, predict_state_config=context.config.predict_state_config, require_confirmation=context.config.require_confirmation, ) - current_state = state_manager.initialize(context.input_data.get("state", {})) + current_state = state_manager.initialize(context.input_data.get("state")) event_bridge = AgentFrameworkEventBridge( run_id=context.run_id, @@ -299,8 +338,8 @@ async def run( predict_state_config=context.config.predict_state_config, current_state=current_state, skip_text_content=skip_text_content, - input_messages=context.input_data.get("messages", []), require_confirmation=context.config.require_confirmation, + approval_tool_name=approval_tool_name, ) yield event_bridge.create_run_started_event() @@ -321,17 +360,18 @@ async def run( if current_state: thread.metadata["current_state"] = current_state # type: ignore[attr-defined] - raw_messages = context.messages or [] - if not raw_messages: + provider_messages = context.messages or [] + snapshot_messages = context.snapshot_messages + if not provider_messages: logger.warning("No messages provided in AG-UI input") yield event_bridge.create_run_finished_event() return - logger.info(f"Received {len(raw_messages)} raw messages from client") - for i, msg in enumerate(raw_messages): - role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + logger.info(f"Received {len(provider_messages)} provider messages from client") + for i, msg in enumerate(provider_messages): + role = get_role_value(msg) msg_id = getattr(msg, "message_id", None) - logger.info(f" Raw message {i}: role={role}, id={msg_id}") + logger.info(f" Message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ @@ -354,62 +394,26 @@ async def run( else: logger.debug(f" Content {j}: {content_type}") - sanitized_messages = sanitize_tool_history(raw_messages) - provider_messages = deduplicate_messages(sanitized_messages) - - if not provider_messages: - logger.info("No provider-eligible messages after filtering; finishing run without invoking agent.") - yield event_bridge.create_run_finished_event() - return - - logger.info(f"Processing {len(provider_messages)} provider messages after sanitization/deduplication") - for i, msg in enumerate(provider_messages): - role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - logger.info(f" Message {i}: role={role}") - if hasattr(msg, "contents") and msg.contents: - for j, content in enumerate(msg.contents): - content_type = type(content).__name__ - if isinstance(content, TextContent): - logger.info(f" Content {j}: {content_type} - text_length={len(content.text)}") - elif isinstance(content, FunctionCallContent): - arg_length = len(str(content.arguments)) if content.arguments else 0 - logger.info(" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length) - elif isinstance(content, FunctionResultContent): - result_preview = type(content.result).__name__ if content.result is not None else "None" - logger.info( - " Content %s: %s - call_id=%s, result_type=%s", - j, - content_type, - content.call_id, - result_preview, - ) - else: - logger.info(f" Content {j}: {content_type}") - - messages_to_run: list[Any] = [] - is_new_user_turn = False - if provider_messages: - last_msg = provider_messages[-1] - role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role) - is_new_user_turn = role_value == "user" - - conversation_has_tool_calls = False - for msg in provider_messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value == "assistant" and hasattr(msg, "contents") and msg.contents: - if any(isinstance(content, FunctionCallContent) for content in msg.contents): - conversation_has_tool_calls = True - break - - state_context_msg = state_manager.state_context_message( - is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls - ) - if state_context_msg: - messages_to_run.append(state_context_msg) - - messages_to_run.extend(provider_messages) + pending_tool_calls: list[dict[str, Any]] = [] + tool_calls_by_id: dict[str, dict[str, Any]] = {} + tool_results: list[dict[str, Any]] = [] + tool_calls_ended: set[str] = set() + messages_snapshot_emitted = False + accumulated_text_content = "" + active_message_id: str | None = None + + # Check for FunctionApprovalResponseContent and emit updated state snapshot + # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) + for snapshot_evt in collect_approved_state_snapshots( + provider_messages, + context.config.predict_state_config, + current_state, + event_bridge, + ): + yield snapshot_evt + + messages_to_run = select_messages_to_run(provider_messages, state_manager) - client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") if client_tools: for tool in client_tools: @@ -421,17 +425,11 @@ async def run( register_additional_client_tools(context.agent, client_tools) tools_param = merge_tools(server_tools, client_tools) - all_updates: list[Any] = [] + collect_updates = response_format is not None + all_updates: list[Any] | None = [] if collect_updates else None update_count = 0 # Prepare metadata for chat client (Azure requires string values) - safe_metadata: dict[str, Any] = {} - thread_metadata = getattr(thread, "metadata", None) - if thread_metadata: - for key, value in thread_metadata.items(): - value_str = value if isinstance(value, str) else json.dumps(value) - if len(value_str) > 512: - value_str = value_str[:512] - safe_metadata[key] = value_str + safe_metadata = build_safe_metadata(getattr(thread, "metadata", None)) run_kwargs: dict[str, Any] = { "thread": thread, @@ -441,27 +439,200 @@ async def run( if safe_metadata: run_kwargs["store"] = True + async def _resolve_approval_responses( + messages: list[Any], + tools_for_execution: list[Any], + ) -> None: + fcc_todo = _collect_approval_responses(messages) + if not fcc_todo: + return + + approved_responses = [resp for resp in fcc_todo.values() if resp.approved] + approved_function_results: list[Any] = [] + if approved_responses and tools_for_execution: + chat_client = getattr(context.agent, "chat_client", None) + config = ( + getattr(chat_client, "function_invocation_configuration", None) or FunctionInvocationConfiguration() + ) + middleware_pipeline = extract_and_merge_function_middleware(chat_client, run_kwargs) + try: + results, _ = await _try_execute_function_calls( + custom_args=run_kwargs, + attempt_idx=0, + function_calls=approved_responses, + tools=tools_for_execution, + middleware_pipeline=middleware_pipeline, + config=config, + ) + approved_function_results = list(results) + except Exception: + logger.error("Failed to execute approved tool calls; injecting error results.") + approved_function_results = [] + + normalized_results: list[FunctionResultContent] = [] + for idx, approval in enumerate(approved_responses): + if idx < len(approved_function_results) and isinstance( + approved_function_results[idx], FunctionResultContent + ): + normalized_results.append(approved_function_results[idx]) + continue + call_id = approval.function_call.call_id or approval.id + normalized_results.append( + FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.") + ) + + _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore + + def _should_emit_tool_snapshot(tool_name: str | None) -> bool: + if not pending_tool_calls or not tool_results: + return False + if tool_name and context.config.predict_state_config and not context.config.require_confirmation: + for config in context.config.predict_state_config.values(): + if config["tool"] == tool_name: + logger.info( + f"Skipping intermediate MessagesSnapshotEvent for predictive tool '{tool_name}' " + " - delaying until summary" + ) + return False + return True + + def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnapshotEvent: + has_text_content = bool(accumulated_text_content) + all_messages = snapshot_messages.copy() + + if pending_tool_calls: + if tool_message_id and not has_text_content: + tool_call_message_id = tool_message_id + else: + tool_call_message_id = ( + active_message_id if not has_text_content and active_message_id else generate_event_id() + ) + tool_call_message = { + "id": tool_call_message_id, + "role": "assistant", + "tool_calls": pending_tool_calls.copy(), + } + all_messages.append(tool_call_message) + + all_messages.extend(tool_results) + + if has_text_content and active_message_id: + assistant_text_message = { + "id": active_message_id, + "role": "assistant", + "content": accumulated_text_content, + } + all_messages.append(assistant_text_message) + + return MessagesSnapshotEvent( + messages=all_messages, # type: ignore[arg-type] + ) + + # Use tools_param if available (includes client tools), otherwise fall back to server_tools + # This ensures both server tools AND client tools can be executed after approval + tools_for_approval = tools_param if tools_param is not None else server_tools + latest_approval = latest_approval_response(messages_to_run) + await _resolve_approval_responses(messages_to_run, tools_for_approval) + + if latest_approval and is_step_based_approval(latest_approval, context.config.predict_state_config): + from ._confirmation_strategies import DefaultConfirmationStrategy + + strategy = context.confirmation_strategy + if strategy is None: + strategy = DefaultConfirmationStrategy() + + steps = approval_steps(latest_approval) + if steps: + if latest_approval.approved: + confirmation_message = strategy.on_approval_accepted(steps) + else: + confirmation_message = strategy.on_approval_rejected(steps) + else: + if latest_approval.approved: + confirmation_message = strategy.on_state_confirmed() + else: + confirmation_message = strategy.on_state_rejected() + + message_id = generate_event_id() + yield TextMessageStartEvent(message_id=message_id, role="assistant") + yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) + yield TextMessageEndEvent(message_id=message_id) + yield event_bridge.create_run_finished_event() + return + async for update in context.agent.run_stream(messages_to_run, **run_kwargs): update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") - all_updates.append(update) + if all_updates is not None: + all_updates.append(update) + if event_bridge.current_message_id is None and update.contents: + has_tool_call = any(isinstance(content, FunctionCallContent) for content in update.contents) + has_text = any(isinstance(content, TextContent) for content in update.contents) + if has_tool_call and not has_text: + tool_message_id = generate_event_id() + event_bridge.current_message_id = tool_message_id + active_message_id = tool_message_id + accumulated_text_content = "" + logger.info( + "[STREAM] Emitting TextMessageStartEvent for tool-only response message_id=%s", + tool_message_id, + ) + yield TextMessageStartEvent(message_id=tool_message_id, role="assistant") events = await event_bridge.from_agent_run_update(update) logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: + if isinstance(event, TextMessageStartEvent): + active_message_id = event.message_id + accumulated_text_content = "" + elif isinstance(event, TextMessageContentEvent): + accumulated_text_content += event.delta + elif isinstance(event, ToolCallStartEvent): + tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) + tool_call_entry["function"]["name"] = event.tool_call_name + elif isinstance(event, ToolCallArgsEvent): + tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) + tool_call_entry["function"]["arguments"] += event.delta + elif isinstance(event, ToolCallEndEvent): + tool_calls_ended.add(event.tool_call_id) + elif isinstance(event, ToolCallResultEvent): + tool_results.append( + { + "id": event.message_id, + "role": "tool", + "toolCallId": event.tool_call_id, + "content": event.content, + } + ) logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event + if isinstance(event, ToolCallResultEvent): + tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) + if _should_emit_tool_snapshot(tool_name): + messages_snapshot_emitted = True + messages_snapshot = _build_messages_snapshot() + logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") + yield messages_snapshot + elif isinstance(event, ToolCallEndEvent): + tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) + if tool_name == "confirm_changes": + messages_snapshot_emitted = True + messages_snapshot = _build_messages_snapshot() + logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") + yield messages_snapshot logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") if event_bridge.should_stop_after_confirm: - logger.info("Stopping run after confirm_changes - waiting for user response") + logger.info("Stopping run - waiting for user approval/confirmation response") + if event_bridge.current_message_id: + logger.info(f"[CONFIRM] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") + yield event_bridge.create_message_end_event(event_bridge.current_message_id) + event_bridge.current_message_id = None yield event_bridge.create_run_finished_event() return - if event_bridge.pending_tool_calls: - pending_without_end = [ - tc for tc in event_bridge.pending_tool_calls if tc.get("id") not in event_bridge.tool_calls_ended - ] + if pending_tool_calls: + pending_without_end = [tc for tc in pending_tool_calls if tc.get("id") not in tool_calls_ended] if pending_without_end: logger.info( "Found %s pending tool calls without end event - emitting ToolCallEndEvent", @@ -470,13 +641,11 @@ async def run( for tool_call in pending_without_end: tool_call_id = tool_call.get("id") if tool_call_id: - from ag_ui.core import ToolCallEndEvent - end_event = ToolCallEndEvent(tool_call_id=tool_call_id) logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") yield end_event - if all_updates and response_format: + if response_format and all_updates: from agent_framework import AgentRunResponse from pydantic import BaseModel @@ -508,37 +677,22 @@ async def run( logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) - assistant_text_message = { - "id": event_bridge.current_message_id, - "role": "assistant", - "content": event_bridge.accumulated_text_content, - } - - converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) - all_messages = converted_input_messages.copy() - - if event_bridge.pending_tool_calls: - tool_call_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": event_bridge.pending_tool_calls.copy(), - } - all_messages.append(tool_call_message) - - all_messages.extend(event_bridge.tool_results.copy()) - all_messages.append(assistant_text_message) - - messages_snapshot = MessagesSnapshotEvent( - messages=all_messages, # type: ignore[arg-type] - ) + messages_snapshot = _build_messages_snapshot(tool_message_id=event_bridge.current_message_id) + messages_snapshot_emitted = True logger.info( - "[FINALIZE] Emitting MessagesSnapshotEvent with %s messages (text content length: %s)", - len(all_messages), - len(event_bridge.accumulated_text_content), + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages " + f"(text content length: {len(accumulated_text_content)})" ) yield messages_snapshot else: logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") + if not messages_snapshot_emitted and (pending_tool_calls or tool_results): + messages_snapshot = _build_messages_snapshot() + messages_snapshot_emitted = True + logger.info( + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages" + ) + yield messages_snapshot logger.info("[FINALIZE] Emitting RUN_FINISHED event") yield event_bridge.create_run_finished_event() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index 8b271988dc..c0da986308 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -3,13 +3,29 @@ """Utility functions for AG-UI integration.""" import copy +import json import uuid from collections.abc import Callable, MutableMapping, Sequence from dataclasses import asdict, is_dataclass from datetime import date, datetime from typing import Any -from agent_framework import AIFunction, ToolProtocol +from agent_framework import AIFunction, Role, ToolProtocol + +# Role mapping constants +AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = { + "user": Role.USER, + "assistant": Role.ASSISTANT, + "system": Role.SYSTEM, +} + +FRAMEWORK_TO_AGUI_ROLE: dict[Role, str] = { + Role.USER: "user", + Role.ASSISTANT: "assistant", + Role.SYSTEM: "system", +} + +ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"} def generate_event_id() -> str: @@ -17,6 +33,85 @@ def generate_event_id() -> str: return str(uuid.uuid4()) +def safe_json_parse(value: Any) -> dict[str, Any] | None: + """Safely parse a value as JSON dict. + + Args: + value: String or dict to parse + + Returns: + Parsed dict or None if parsing fails + """ + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + return None + + +def get_role_value(message: Any) -> str: + """Extract role string from a message object. + + Handles both enum roles (with .value) and string roles. + + Args: + message: Message object with role attribute + + Returns: + Role as lowercase string, or empty string if not found + """ + role = getattr(message, "role", None) + if role is None: + return "" + if hasattr(role, "value"): + return str(role.value) + return str(role) + + +def normalize_agui_role(raw_role: Any) -> str: + """Normalize an AG-UI role to a standard role string. + + Args: + raw_role: Raw role value from AG-UI message + + Returns: + Normalized role string (user, assistant, system, or tool) + """ + if not isinstance(raw_role, str): + return "user" + role = raw_role.lower() + if role == "developer": + return "system" + if role in ALLOWED_AGUI_ROLES: + return role + return "user" + + +def extract_state_from_tool_args( + args: dict[str, Any] | None, + tool_arg_name: str, +) -> Any: + """Extract state value from tool arguments based on config. + + Args: + args: Parsed tool arguments dict + tool_arg_name: Name of the argument to extract, or "*" for entire args + + Returns: + Extracted state value, or None if not found + """ + if not args: + return None + if tool_arg_name == "*": + return args + return args.get(tool_arg_name) + + def merge_state(current: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]: """Merge state updates. diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py index abbd113418..ab7a3533cd 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py @@ -75,8 +75,10 @@ def human_in_the_loop_agent(chat_client: ChatClientProtocol) -> ChatAgent: 9. "Calibrate systems" 10. "Final testing" - After calling the function, provide a brief acknowledgment like: - "I've created a plan with 10 steps. You can customize which steps to enable before I proceed." + IMPORTANT: When you call generate_task_steps, the user will be shown the steps and asked to approve. + Do NOT output any text along with the function call - just call the function. + After the user approves and the function executes, THEN provide a brief acknowledgment like: + "The plan has been created with X steps selected." """, chat_client=chat_client, tools=[generate_task_steps], diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index beb6f8af2c..281b81c968 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -630,3 +630,179 @@ async def stream_fn( # Should contain some reference to the document full_text = "".join(e.delta for e in text_events) assert "written" in full_text.lower() or "document" in full_text.lower() + + +async def test_function_approval_mode_executes_tool(): + """Test that function approval with approval_mode='always_require' sends the correct messages.""" + from agent_framework import FunctionResultContent, ai_function + from agent_framework.ag_ui import AgentFrameworkAgent + + messages_received: list[Any] = [] + + @ai_function( + name="get_datetime", + description="Get the current date and time", + approval_mode="always_require", + ) + def get_datetime() -> str: + return "2025/12/01 12:00:00" + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + # Capture the messages received by the chat client + messages_received.clear() + messages_received.extend(messages) + yield ChatResponseUpdate(contents=[TextContent(text="Processing completed")]) + + agent = ChatAgent( + name="test_agent", + instructions="Test", + chat_client=StreamingChatClientStub(stream_fn), + tools=[get_datetime], + ) + wrapper = AgentFrameworkAgent(agent=agent) + + # Simulate the conversation history with: + # 1. User message asking for time + # 2. Assistant message with the function call that needs approval + # 3. Tool approval message from user + tool_result: dict[str, Any] = {"accepted": True} + input_data: dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": "What time is it?", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_get_datetime_123", + "type": "function", + "function": { + "name": "get_datetime", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps(tool_result), + "toolCallId": "call_get_datetime_123", + }, + ], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + # Verify the run completed successfully + run_started = [e for e in events if e.type == "RUN_STARTED"] + run_finished = [e for e in events if e.type == "RUN_FINISHED"] + assert len(run_started) == 1 + assert len(run_finished) == 1 + + # Verify that a FunctionResultContent was created and sent to the agent + # Approved tool calls are resolved before the model run. + tool_result_found = False + for msg in messages_received: + for content in msg.contents: + if isinstance(content, FunctionResultContent): + tool_result_found = True + assert content.call_id == "call_get_datetime_123" + assert content.result == "2025/12/01 12:00:00" + break + + assert tool_result_found, ( + "FunctionResultContent should be included in messages sent to agent. " + "This is required for the model to see the approved tool execution result." + ) + + +async def test_function_approval_mode_rejection(): + """Test that function approval rejection creates a rejection response.""" + from agent_framework import FunctionResultContent, ai_function + from agent_framework.ag_ui import AgentFrameworkAgent + + messages_received: list[Any] = [] + + @ai_function( + name="delete_all_data", + description="Delete all user data", + approval_mode="always_require", + ) + def delete_all_data() -> str: + return "All data deleted" + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + # Capture the messages received by the chat client + messages_received.clear() + messages_received.extend(messages) + yield ChatResponseUpdate(contents=[TextContent(text="Operation cancelled")]) + + agent = ChatAgent( + name="test_agent", + instructions="Test", + chat_client=StreamingChatClientStub(stream_fn), + tools=[delete_all_data], + ) + wrapper = AgentFrameworkAgent(agent=agent) + + # Simulate rejection + tool_result: dict[str, Any] = {"accepted": False} + input_data: dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": "Delete all my data", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_delete_123", + "type": "function", + "function": { + "name": "delete_all_data", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps(tool_result), + "toolCallId": "call_delete_123", + }, + ], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + # Verify the run completed + run_finished = [e for e in events if e.type == "RUN_FINISHED"] + assert len(run_finished) == 1 + + # Verify that a FunctionResultContent with rejection payload was created + rejection_found = False + for msg in messages_received: + for content in msg.contents: + if isinstance(content, FunctionResultContent): + rejection_found = True + assert content.call_id == "call_delete_123" + assert content.result == "Error: Tool call invocation was rejected by user." + break + + assert rejection_found, ( + "FunctionResultContent with rejection details should be included in messages sent to agent. " + "This tells the model that the tool was rejected." + ) diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py index 6fefc14665..97654182cf 100644 --- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py +++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py @@ -52,8 +52,8 @@ async def test_tool_call_flow(): update2 = AgentRunResponseUpdate(contents=[tool_result]) events2 = await bridge.from_agent_run_update(update2) - # Should have: ToolCallEndEvent, ToolCallResultEvent, MessagesSnapshotEvent - assert len(events2) == 3 + # Should have: ToolCallEndEvent, ToolCallResultEvent + assert len(events2) == 2 assert isinstance(events2[0], ToolCallEndEvent) assert isinstance(events2[1], ToolCallResultEvent) diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index 20b53cc18f..cfd45ea5c8 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -231,7 +231,12 @@ async def test_function_approval_request_basic(): """Test FunctionApprovalRequestContent conversion.""" from agent_framework_ag_ui._events import AgentFrameworkEventBridge - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + # Set require_confirmation=False to test just the function_approval_request event + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + require_confirmation=False, + ) func_call = FunctionCallContent( call_id="call_123", @@ -284,14 +289,12 @@ async def test_empty_predict_state_config(): assert "STATE_DELTA" not in event_types assert "STATE_SNAPSHOT" not in event_types - # Should have: ToolCallStart, ToolCallArgs, ToolCallEnd, ToolCallResult, MessagesSnapshot - # MessagesSnapshotEvent is emitted after tool results to track the conversation + # Should have: ToolCallStart, ToolCallArgs, ToolCallEnd, ToolCallResult assert event_types == [ "TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END", "TOOL_CALL_RESULT", - "MESSAGES_SNAPSHOT", ] diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/test_helpers_ag_ui.py index bfb528511e..fc82b11510 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/test_helpers_ag_ui.py @@ -18,6 +18,7 @@ from agent_framework._clients import BaseChatClient from agent_framework._types import ChatResponse, ChatResponseUpdate +from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history from agent_framework_ag_ui._orchestrators import ExecutionContext StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]] @@ -134,5 +135,9 @@ def get_new_thread(self, **kwargs: Any) -> AgentThread: class TestExecutionContext(ExecutionContext): """ExecutionContext helper that allows setting messages for tests.""" - def set_messages(self, messages: list[ChatMessage]) -> None: - self._messages = messages + def set_messages(self, messages: list[ChatMessage], *, normalize: bool = True) -> None: + if normalize: + self._messages = _deduplicate_messages(_sanitize_tool_history(messages)) + else: + self._messages = messages + self._snapshot_messages = None diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py index 92f6d69926..55a2869c91 100644 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py @@ -10,9 +10,11 @@ async def test_function_approval_request_emission(): """Test that CustomEvent is emitted for FunctionApprovalRequestContent.""" + # Set require_confirmation=False to test just the function_approval_request event bridge = AgentFrameworkEventBridge( run_id="test_run", thread_id="test_thread", + require_confirmation=False, ) # Create approval request @@ -47,11 +49,65 @@ async def test_function_approval_request_emission(): assert event.value["function_call"]["arguments"]["subject"] == "Test" +async def test_function_approval_request_with_confirm_changes(): + """Test that confirm_changes is also emitted when require_confirmation=True.""" + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + require_confirmation=True, + ) + + func_call = FunctionCallContent( + call_id="call_456", + name="delete_file", + arguments={"path": "/tmp/test.txt"}, + ) + approval_request = FunctionApprovalRequestContent( + id="approval_002", + function_call=func_call, + ) + + update = AgentRunResponseUpdate(contents=[approval_request]) + events = await bridge.from_agent_run_update(update) + + # Should emit: ToolCallEndEvent, CustomEvent, and confirm_changes (Start, Args, End) = 5 events + assert len(events) == 5 + + # Check ToolCallEndEvent + assert events[0].type == "TOOL_CALL_END" + assert events[0].tool_call_id == "call_456" + + # Check function_approval_request CustomEvent + assert events[1].type == "CUSTOM" + assert events[1].name == "function_approval_request" + + # Check confirm_changes tool call events + assert events[2].type == "TOOL_CALL_START" + assert events[2].tool_call_name == "confirm_changes" + assert events[3].type == "TOOL_CALL_ARGS" + # Verify confirm_changes includes function info for Dojo UI + import json + + args = json.loads(events[3].delta) + assert args["function_name"] == "delete_file" + assert args["function_call_id"] == "call_456" + assert args["function_arguments"] == {"path": "/tmp/test.txt"} + assert args["steps"] == [ + { + "description": "Execute delete_file", + "status": "enabled", + } + ] + assert events[4].type == "TOOL_CALL_END" + + async def test_multiple_approval_requests(): """Test handling multiple approval requests in one update.""" + # Set require_confirmation=False to simplify the test bridge = AgentFrameworkEventBridge( run_id="test_run", thread_id="test_thread", + require_confirmation=False, ) func_call_1 = FunctionCallContent( @@ -94,3 +150,32 @@ async def test_multiple_approval_requests(): assert events[3].type == "CUSTOM" assert events[3].name == "function_approval_request" assert events[3].value["id"] == "approval_2" + + +async def test_function_approval_request_sets_stop_flag(): + """Test that function approval request sets should_stop_after_confirm flag. + + This ensures the orchestrator stops the run after emitting the approval request, + allowing the UI to send back an approval response. + """ + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + ) + + assert bridge.should_stop_after_confirm is False + + func_call = FunctionCallContent( + call_id="call_stop_test", + name="get_datetime", + arguments={}, + ) + approval_request = FunctionApprovalRequestContent( + id="approval_stop_test", + function_call=func_call, + ) + + update = AgentRunResponseUpdate(contents=[approval_request]) + await bridge.from_agent_run_update(update) + + assert bridge.should_stop_after_confirm is True diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index 51a51c9fd4..9173314a28 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -2,12 +2,15 @@ """Tests for message adapters.""" +import json + import pytest from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent from agent_framework_ag_ui._message_adapters import ( agent_framework_messages_to_agui, agui_messages_to_agent_framework, + agui_messages_to_snapshot_format, extract_text_from_contents, ) @@ -43,6 +46,32 @@ def test_agent_framework_to_agui_basic(sample_agent_framework_message): assert messages[0]["id"] == "msg-123" +def test_agent_framework_to_agui_normalizes_dict_roles(): + """Dict inputs normalize unknown roles for UI compatibility.""" + messages = [ + {"role": "developer", "content": "policy"}, + {"role": "weird_role", "content": "payload"}, + ] + + converted = agent_framework_messages_to_agui(messages) + + assert converted[0]["role"] == "system" + assert converted[1]["role"] == "user" + + +def test_agui_snapshot_format_normalizes_roles(): + """Snapshot normalization coerces roles into supported AG-UI values.""" + messages = [ + {"role": "Developer", "content": "policy"}, + {"role": "unknown", "content": "payload"}, + ] + + normalized = agui_messages_to_snapshot_format(messages) + + assert normalized[0]["role"] == "system" + assert normalized[1]["role"] == "user" + + def test_agui_tool_result_to_agent_framework(): """Test converting AG-UI tool result message to Agent Framework.""" tool_result_message = { @@ -68,6 +97,237 @@ def test_agui_tool_result_to_agent_framework(): assert message.additional_properties.get("tool_call_id") == "call_123" +def test_agui_tool_approval_updates_tool_call_arguments(): + """Tool approval updates matching tool call arguments for snapshots and agent context.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "generate_task_steps", + "arguments": { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + }, + }, + } + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ], + } + ), + "toolCallId": "call_123", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + assert len(messages) == 2 + assistant_msg = messages[0] + func_call = next(content for content in assistant_msg.contents if isinstance(content, FunctionCallContent)) + assert func_call.arguments == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + assert approval_content.function_call.parse_arguments() == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + assert approval_content.additional_properties is not None + assert approval_content.additional_properties.get("ag_ui_state_args") == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + + +def test_agui_tool_approval_from_confirm_changes_maps_to_function_call(): + """Confirm_changes approvals map back to the original tool call when metadata is present.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_confirm", + "type": "function", + "function": { + "name": "confirm_changes", + "arguments": {"function_call_id": "call_tool"}, + }, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps({"accepted": True, "function_call_id": "call_tool"}), + "toolCallId": "call_confirm", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} + + +def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call(): + """Confirm_changes approvals map to the only sibling tool call when metadata is missing.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_confirm", + "type": "function", + "function": {"name": "confirm_changes", "arguments": {}}, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [{"description": "Approve get_datetime", "status": "enabled"}], + } + ), + "toolCallId": "call_confirm", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} + + +def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call(): + """Approval tool payloads map to the referenced function call when function_call_id is present.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_steps", + "type": "function", + "function": { + "name": "generate_task_steps", + "arguments": { + "function_name": "get_datetime", + "function_call_id": "call_tool", + "function_arguments": {}, + "steps": [{"description": "Execute get_datetime", "status": "enabled"}], + }, + }, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [{"description": "Execute get_datetime", "status": "enabled"}], + } + ), + "toolCallId": "call_steps", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + + def test_agui_multiple_messages_to_agent_framework(): """Test converting multiple AG-UI messages.""" messages_input = [ diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index ba775fa7d9..380ff438bd 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -2,10 +2,7 @@ from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent -from agent_framework_ag_ui._orchestration._message_hygiene import ( - deduplicate_messages, - sanitize_tool_history, -) +from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history def test_sanitize_tool_history_injects_confirm_changes_result() -> None: @@ -26,7 +23,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None: ), ] - sanitized = sanitize_tool_history(messages) + sanitized = _sanitize_tool_history(messages) tool_messages = [ msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" @@ -48,6 +45,6 @@ def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: ), ] - deduped = deduplicate_messages(messages) + deduped = _deduplicate_messages(messages) assert len(deduped) == 1 assert deduped[0].contents[0].result == "result data" diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index af90ea2e88..8c00602538 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -42,6 +42,29 @@ async def run_stream( yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") +class RecordingAgent: + """Agent stub that captures messages passed to run_stream.""" + + def __init__(self) -> None: + self.chat_options = SimpleNamespace(tools=[], response_format=None) + self.tools: list[Any] = [] + self.chat_client = SimpleNamespace( + function_invocation_configuration=FunctionInvocationConfiguration(), + ) + self.seen_messages: list[Any] | None = None + + async def run_stream( + self, + messages: list[Any], + *, + thread: Any, + tools: list[Any] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[AgentRunResponseUpdate, None]: + self.seen_messages = messages + yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + + async def test_default_orchestrator_merges_client_tools() -> None: """Client tool declarations are merged with server tools before running agent.""" @@ -151,3 +174,104 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: last_event = events[-1] assert last_event.run_id == "test-snakecase-runid" assert last_event.thread_id == "test-snakecase-threadid" + + +async def test_state_context_injected_when_tool_call_state_mismatch() -> None: + """State context should be injected when current state differs from tool call args.""" + + agent = RecordingAgent() + orchestrator = DefaultOrchestrator() + + tool_recipe = {"title": "Salad", "special_preferences": []} + current_recipe = {"title": "Salad", "special_preferences": ["Vegetarian"]} + + input_data = { + "state": {"recipe": current_recipe}, + "messages": [ + {"role": "system", "content": "Instructions"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "update_recipe", "arguments": {"recipe": tool_recipe}}, + } + ], + }, + {"role": "user", "content": "What are the dietary preferences?"}, + ], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + state_schema={"recipe": {"type": "object"}}, + predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, + require_confirmation=False, + ), + ) + + async for _event in orchestrator.run(context): + pass + + assert agent.seen_messages is not None + state_messages = [] + for msg in agent.seen_messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "system": + continue + for content in msg.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + state_messages.append(content.text) + assert state_messages + assert "Vegetarian" in state_messages[0] + + +async def test_state_context_not_injected_when_tool_call_matches_state() -> None: + """State context should be skipped when tool call args match current state.""" + + agent = RecordingAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "messages": [ + {"role": "system", "content": "Instructions"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "update_recipe", "arguments": {"recipe": {}}}, + } + ], + }, + {"role": "user", "content": "What are the dietary preferences?"}, + ], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + state_schema={"recipe": {"type": "object"}}, + predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, + require_confirmation=False, + ), + ) + + async for _event in orchestrator.run(context): + pass + + assert agent.seen_messages is not None + state_messages = [] + for msg in agent.seen_messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "system": + continue + for content in msg.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + state_messages.append(content.text) + assert not state_messages diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 1da11bffbc..041e25c3d2 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -62,7 +62,7 @@ async def test_human_in_the_loop_json_decode_error() -> None: agent=agent, config=AgentConfig(), ) - context.set_messages(messages) + context.set_messages(messages, normalize=False) assert orchestrator.can_handle(context) @@ -385,8 +385,8 @@ async def test_state_context_injection() -> None: assert "banana" in system_messages[0].contents[0].text -async def test_no_state_context_injection_with_tool_calls() -> None: - """Test state context is NOT injected if conversation has tool calls.""" +async def test_state_context_injection_with_tool_calls_and_input_state() -> None: + """Test state context is injected when state is provided, even with tool calls.""" from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent messages = [ @@ -420,13 +420,13 @@ async def test_no_state_context_injection_with_tool_calls() -> None: async for event in orchestrator.run(context): events.append(event) - # Should NOT inject state context system message since conversation has tool calls + # Should inject state context system message because input state is provided system_messages = [ msg for msg in agent.messages_received if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "system" ] - assert len(system_messages) == 0 + assert len(system_messages) == 1 async def test_structured_output_processing() -> None: @@ -685,6 +685,54 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: assert len(user_messages) == 1 +async def test_confirm_changes_closes_active_message_before_finish() -> None: + """Confirm-changes flow closes any active text message before run finishes.""" + from ag_ui.core import TextMessageEndEvent, TextMessageStartEvent + from agent_framework import FunctionCallContent, FunctionResultContent + + updates = [ + AgentRunResponseUpdate( + contents=[ + FunctionCallContent( + name="write_document_local", + call_id="call_1", + arguments='{"document": "Draft"}', + ) + ] + ), + AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]), + ] + + orchestrator = DefaultOrchestrator() + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Start"}]} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + updates=updates, + ) + context = TestExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + predict_state_config={"document": {"tool": "write_document_local", "tool_argument": "document"}}, + require_confirmation=True, + ), + ) + + events: list[Any] = [] + async for event in orchestrator.run(context): + events.append(event) + + start_events = [e for e in events if isinstance(e, TextMessageStartEvent)] + end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(start_events) == 1 + assert len(end_events) == 1 + assert end_events[0].message_id == start_events[0].message_id + + end_index = events.index(end_events[0]) + finished_index = events.index([e for e in events if e.type == "RUN_FINISHED"][0]) + assert end_index < finished_index + + async def test_tool_result_kept_when_call_id_matches() -> None: """Test tool result is kept when call_id matches pending tool calls.""" from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent diff --git a/python/pyproject.toml b/python/pyproject.toml index ed98cc8020..cbdf3f0d75 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -145,7 +145,8 @@ ignore = [ "D418", # allow overload to have a docstring "TD003", # allow missing link to todo issue "FIX002", # allow todo - "B027" # allow empty non-abstract method in ABC + "B027", # allow empty non-abstract method in ABC + "RUF067", # allow version detection in __init__.py ] [tool.ruff.lint.per-file-ignores] diff --git a/python/uv.lock b/python/uv.lock index 0ad64ed9ea..15d1fd855e 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -2237,7 +2237,7 @@ wheels = [ [[package]] name = "google-api-core" -version = "2.28.1" +version = "2.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -2246,9 +2246,9 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/61/da/83d7043169ac2c8c7469f0e375610d78ae2160134bf1b80634c482fa079c/google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8", size = 176759, upload-time = "2025-10-28T21:34:51.529Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/10/05572d33273292bac49c2d1785925f7bc3ff2fe50e3044cf1062c1dde32e/google_api_core-2.29.0.tar.gz", hash = "sha256:84181be0f8e6b04006df75ddfe728f24489f0af57c96a529ff7cf45bc28797f7", size = 177828, upload-time = "2026-01-08T22:21:39.269Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c", size = 173706, upload-time = "2025-10-28T21:34:50.151Z" }, + { url = "https://files.pythonhosted.org/packages/77/b6/85c4d21067220b9a78cfb81f516f9725ea6befc1544ec9bd2c1acd97c324/google_api_core-2.29.0-py3-none-any.whl", hash = "sha256:d30bc60980daa36e314b5d5a3e5958b0200cb44ca8fa1be2b614e932b75a3ea9", size = 173906, upload-time = "2026-01-08T22:21:36.093Z" }, ] [[package]] @@ -4277,11 +4277,11 @@ wheels = [ [[package]] name = "pathspec" -version = "1.0.1" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/28/2e/83722ece0f6ee24387d6cb830dd562ddbcd6ce0b9d76072c6849670c31b4/pathspec-1.0.1.tar.gz", hash = "sha256:e2769b508d0dd47b09af6ee2c75b2744a2cb1f474ae4b1494fd6a1b7a841613c", size = 129791, upload-time = "2026-01-06T13:02:55.15Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/b9/6eb731b52f132181a9144bbe77ff82117f6b2d2fbfba49aaab2c014c4760/pathspec-1.0.2.tar.gz", hash = "sha256:fa32b1eb775ed9ba8d599b22c5f906dc098113989da2c00bf8b210078ca7fb92", size = 130502, upload-time = "2026-01-08T04:33:27.613Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/fe/2257c71721aeab6a6e8aa1f00d01f2a20f58547d249a6c8fef5791f559fc/pathspec-1.0.1-py3-none-any.whl", hash = "sha256:8870061f22c58e6d83463cfce9a7dd6eca0512c772c1001fb09ac64091816721", size = 54584, upload-time = "2026-01-06T13:02:53.601Z" }, + { url = "https://files.pythonhosted.org/packages/78/6b/14fc9049d78435fd29e82846c777bd7ed9c470013dc8d0260fff3ff1c11e/pathspec-1.0.2-py3-none-any.whl", hash = "sha256:62f8558917908d237d399b9b338ef455a814801a4688bc41074b25feefd93472", size = 54844, upload-time = "2026-01-08T04:33:26.4Z" }, ] [[package]] @@ -4485,7 +4485,7 @@ wheels = [ [[package]] name = "posthog" -version = "7.5.0" +version = "7.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -4495,9 +4495,9 @@ dependencies = [ { name = "six", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/7d/7b81b79ab79de1d47230267a389df127d15761202123c0f705d00621ca61/posthog-7.5.0.tar.gz", hash = "sha256:ae57605508ff16bd5a89f392efb26c88e8f3019db8f35611fd94273bf51048e3", size = 144880, upload-time = "2026-01-07T13:11:52.07Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/3b/866af11cb12e9d35feffcd480d4ebf31f87b2164926b9c670cbdafabc814/posthog-7.5.1.tar.gz", hash = "sha256:d8a8165b3d47465023ea2f919982a34890e2dda76402ec47d6c68424b2534a55", size = 145244, upload-time = "2026-01-08T21:18:39.266Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/c7/42c0cf72d37256fed5552517ddcbe549b6b1408f38b78bc6d980a1d06bc2/posthog-7.5.0-py3-none-any.whl", hash = "sha256:e1cba868a804fe1a13d5c0aaf5bab70aa89fd067d73a4046fa9d3699e225c9d0", size = 167271, upload-time = "2026-01-07T13:11:48.948Z" }, + { url = "https://files.pythonhosted.org/packages/1f/03/ba011712ce9d07fe87dcfb72474c388d960e6d0c4f2262d2ae11fd27f0c5/posthog-7.5.1-py3-none-any.whl", hash = "sha256:fd3431ce32c9bbfb1e3775e3633c32ee589c052b0054fafe5ed9e4b17c1969d3", size = 167555, upload-time = "2026-01-08T21:18:37.437Z" }, ] [[package]] @@ -5017,15 +5017,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.407" +version = "1.1.408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/1b/0aa08ee42948b61745ac5b5b5ccaec4669e8884b53d31c8ec20b2fcd6b6f/pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262", size = 4122872, upload-time = "2025-10-24T23:17:15.145Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/93/b69052907d032b00c40cb656d21438ec00b3a471733de137a3f65a49a0a0/pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21", size = 5997008, upload-time = "2025-10-24T23:17:13.159Z" }, + { url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" }, ] [[package]] @@ -5626,28 +5626,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/08/52232a877978dd8f9cf2aeddce3e611b40a63287dfca29b6b8da791f5e8d/ruff-0.14.10.tar.gz", hash = "sha256:9a2e830f075d1a42cd28420d7809ace390832a490ed0966fe373ba288e77aaf4", size = 5859763, upload-time = "2025-12-18T19:28:57.98Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/60/01/933704d69f3f05ee16ef11406b78881733c186fe14b6a46b05cfcaf6d3b2/ruff-0.14.10-py3-none-linux_armv6l.whl", hash = "sha256:7a3ce585f2ade3e1f29ec1b92df13e3da262178df8c8bdf876f48fa0e8316c49", size = 13527080, upload-time = "2025-12-18T19:29:25.642Z" }, - { url = "https://files.pythonhosted.org/packages/df/58/a0349197a7dfa603ffb7f5b0470391efa79ddc327c1e29c4851e85b09cc5/ruff-0.14.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:674f9be9372907f7257c51f1d4fc902cb7cf014b9980152b802794317941f08f", size = 13797320, upload-time = "2025-12-18T19:29:02.571Z" }, - { url = "https://files.pythonhosted.org/packages/7b/82/36be59f00a6082e38c23536df4e71cdbc6af8d7c707eade97fcad5c98235/ruff-0.14.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d85713d522348837ef9df8efca33ccb8bd6fcfc86a2cde3ccb4bc9d28a18003d", size = 12918434, upload-time = "2025-12-18T19:28:51.202Z" }, - { url = "https://files.pythonhosted.org/packages/a6/00/45c62a7f7e34da92a25804f813ebe05c88aa9e0c25e5cb5a7d23dd7450e3/ruff-0.14.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6987ebe0501ae4f4308d7d24e2d0fe3d7a98430f5adfd0f1fead050a740a3a77", size = 13371961, upload-time = "2025-12-18T19:29:04.991Z" }, - { url = "https://files.pythonhosted.org/packages/40/31/a5906d60f0405f7e57045a70f2d57084a93ca7425f22e1d66904769d1628/ruff-0.14.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:16a01dfb7b9e4eee556fbfd5392806b1b8550c9b4a9f6acd3dbe6812b193c70a", size = 13275629, upload-time = "2025-12-18T19:29:21.381Z" }, - { url = "https://files.pythonhosted.org/packages/3e/60/61c0087df21894cf9d928dc04bcd4fb10e8b2e8dca7b1a276ba2155b2002/ruff-0.14.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7165d31a925b7a294465fa81be8c12a0e9b60fb02bf177e79067c867e71f8b1f", size = 14029234, upload-time = "2025-12-18T19:29:00.132Z" }, - { url = "https://files.pythonhosted.org/packages/44/84/77d911bee3b92348b6e5dab5a0c898d87084ea03ac5dc708f46d88407def/ruff-0.14.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c561695675b972effb0c0a45db233f2c816ff3da8dcfbe7dfc7eed625f218935", size = 15449890, upload-time = "2025-12-18T19:28:53.573Z" }, - { url = "https://files.pythonhosted.org/packages/e9/36/480206eaefa24a7ec321582dda580443a8f0671fdbf6b1c80e9c3e93a16a/ruff-0.14.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bb98fcbbc61725968893682fd4df8966a34611239c9fd07a1f6a07e7103d08e", size = 15123172, upload-time = "2025-12-18T19:29:23.453Z" }, - { url = "https://files.pythonhosted.org/packages/5c/38/68e414156015ba80cef5473d57919d27dfb62ec804b96180bafdeaf0e090/ruff-0.14.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f24b47993a9d8cb858429e97bdf8544c78029f09b520af615c1d261bf827001d", size = 14460260, upload-time = "2025-12-18T19:29:27.808Z" }, - { url = "https://files.pythonhosted.org/packages/b3/19/9e050c0dca8aba824d67cc0db69fb459c28d8cd3f6855b1405b3f29cc91d/ruff-0.14.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59aabd2e2c4fd614d2862e7939c34a532c04f1084476d6833dddef4afab87e9f", size = 14229978, upload-time = "2025-12-18T19:29:11.32Z" }, - { url = "https://files.pythonhosted.org/packages/51/eb/e8dd1dd6e05b9e695aa9dd420f4577debdd0f87a5ff2fedda33c09e9be8c/ruff-0.14.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:213db2b2e44be8625002dbea33bb9c60c66ea2c07c084a00d55732689d697a7f", size = 14338036, upload-time = "2025-12-18T19:29:09.184Z" }, - { url = "https://files.pythonhosted.org/packages/6a/12/f3e3a505db7c19303b70af370d137795fcfec136d670d5de5391e295c134/ruff-0.14.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b914c40ab64865a17a9a5b67911d14df72346a634527240039eb3bd650e5979d", size = 13264051, upload-time = "2025-12-18T19:29:13.431Z" }, - { url = "https://files.pythonhosted.org/packages/08/64/8c3a47eaccfef8ac20e0484e68e0772013eb85802f8a9f7603ca751eb166/ruff-0.14.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1484983559f026788e3a5c07c81ef7d1e97c1c78ed03041a18f75df104c45405", size = 13283998, upload-time = "2025-12-18T19:29:06.994Z" }, - { url = "https://files.pythonhosted.org/packages/12/84/534a5506f4074e5cc0529e5cd96cfc01bb480e460c7edf5af70d2bcae55e/ruff-0.14.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c70427132db492d25f982fffc8d6c7535cc2fd2c83fc8888f05caaa248521e60", size = 13601891, upload-time = "2025-12-18T19:28:55.811Z" }, - { url = "https://files.pythonhosted.org/packages/0d/1e/14c916087d8598917dbad9b2921d340f7884824ad6e9c55de948a93b106d/ruff-0.14.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5bcf45b681e9f1ee6445d317ce1fa9d6cba9a6049542d1c3d5b5958986be8830", size = 14336660, upload-time = "2025-12-18T19:29:16.531Z" }, - { url = "https://files.pythonhosted.org/packages/f2/1c/d7b67ab43f30013b47c12b42d1acd354c195351a3f7a1d67f59e54227ede/ruff-0.14.10-py3-none-win32.whl", hash = "sha256:104c49fc7ab73f3f3a758039adea978869a918f31b73280db175b43a2d9b51d6", size = 13196187, upload-time = "2025-12-18T19:29:19.006Z" }, - { url = "https://files.pythonhosted.org/packages/fb/9c/896c862e13886fae2af961bef3e6312db9ebc6adc2b156fe95e615dee8c1/ruff-0.14.10-py3-none-win_amd64.whl", hash = "sha256:466297bd73638c6bdf06485683e812db1c00c7ac96d4ddd0294a338c62fdc154", size = 14661283, upload-time = "2025-12-18T19:29:30.16Z" }, - { url = "https://files.pythonhosted.org/packages/74/31/b0e29d572670dca3674eeee78e418f20bdf97fa8aa9ea71380885e175ca0/ruff-0.14.10-py3-none-win_arm64.whl", hash = "sha256:e51d046cf6dda98a4633b8a8a771451107413b0f07183b2bef03f075599e44e6", size = 13729839, upload-time = "2025-12-18T19:28:48.636Z" }, +version = "0.14.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/77/9a7fe084d268f8855d493e5031ea03fa0af8cc05887f638bf1c4e3363eb8/ruff-0.14.11.tar.gz", hash = "sha256:f6dc463bfa5c07a59b1ff2c3b9767373e541346ea105503b4c0369c520a66958", size = 5993417, upload-time = "2026-01-08T19:11:58.322Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/a6/a4c40a5aaa7e331f245d2dc1ac8ece306681f52b636b40ef87c88b9f7afd/ruff-0.14.11-py3-none-linux_armv6l.whl", hash = "sha256:f6ff2d95cbd335841a7217bdfd9c1d2e44eac2c584197ab1385579d55ff8830e", size = 12951208, upload-time = "2026-01-08T19:12:09.218Z" }, + { url = "https://files.pythonhosted.org/packages/5c/5c/360a35cb7204b328b685d3129c08aca24765ff92b5a7efedbdd6c150d555/ruff-0.14.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f6eb5c1c8033680f4172ea9c8d3706c156223010b8b97b05e82c59bdc774ee6", size = 13330075, upload-time = "2026-01-08T19:12:02.549Z" }, + { url = "https://files.pythonhosted.org/packages/1b/9e/0cc2f1be7a7d33cae541824cf3f95b4ff40d03557b575912b5b70273c9ec/ruff-0.14.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2fc34cc896f90080fca01259f96c566f74069a04b25b6205d55379d12a6855e", size = 12257809, upload-time = "2026-01-08T19:12:00.366Z" }, + { url = "https://files.pythonhosted.org/packages/a7/e5/5faab97c15bb75228d9f74637e775d26ac703cc2b4898564c01ab3637c02/ruff-0.14.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53386375001773ae812b43205d6064dae49ff0968774e6befe16a994fc233caa", size = 12678447, upload-time = "2026-01-08T19:12:13.899Z" }, + { url = "https://files.pythonhosted.org/packages/1b/33/e9767f60a2bef779fb5855cab0af76c488e0ce90f7bb7b8a45c8a2ba4178/ruff-0.14.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a697737dce1ca97a0a55b5ff0434ee7205943d4874d638fe3ae66166ff46edbe", size = 12758560, upload-time = "2026-01-08T19:11:42.55Z" }, + { url = "https://files.pythonhosted.org/packages/eb/84/4c6cf627a21462bb5102f7be2a320b084228ff26e105510cd2255ea868e5/ruff-0.14.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6845ca1da8ab81ab1dce755a32ad13f1db72e7fba27c486d5d90d65e04d17b8f", size = 13599296, upload-time = "2026-01-08T19:11:30.371Z" }, + { url = "https://files.pythonhosted.org/packages/88/e1/92b5ed7ea66d849f6157e695dc23d5d6d982bd6aa8d077895652c38a7cae/ruff-0.14.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e36ce2fd31b54065ec6f76cb08d60159e1b32bdf08507862e32f47e6dde8bcbf", size = 15048981, upload-time = "2026-01-08T19:12:04.742Z" }, + { url = "https://files.pythonhosted.org/packages/61/df/c1bd30992615ac17c2fb64b8a7376ca22c04a70555b5d05b8f717163cf9f/ruff-0.14.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:590bcc0e2097ecf74e62a5c10a6b71f008ad82eb97b0a0079e85defe19fe74d9", size = 14633183, upload-time = "2026-01-08T19:11:40.069Z" }, + { url = "https://files.pythonhosted.org/packages/04/e9/fe552902f25013dd28a5428a42347d9ad20c4b534834a325a28305747d64/ruff-0.14.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:53fe71125fc158210d57fe4da26e622c9c294022988d08d9347ec1cf782adafe", size = 14050453, upload-time = "2026-01-08T19:11:37.555Z" }, + { url = "https://files.pythonhosted.org/packages/ae/93/f36d89fa021543187f98991609ce6e47e24f35f008dfe1af01379d248a41/ruff-0.14.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a35c9da08562f1598ded8470fcfef2afb5cf881996e6c0a502ceb61f4bc9c8a3", size = 13757889, upload-time = "2026-01-08T19:12:07.094Z" }, + { url = "https://files.pythonhosted.org/packages/b7/9f/c7fb6ecf554f28709a6a1f2a7f74750d400979e8cd47ed29feeaa1bd4db8/ruff-0.14.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:0f3727189a52179393ecf92ec7057c2210203e6af2676f08d92140d3e1ee72c1", size = 13955832, upload-time = "2026-01-08T19:11:55.064Z" }, + { url = "https://files.pythonhosted.org/packages/db/a0/153315310f250f76900a98278cf878c64dfb6d044e184491dd3289796734/ruff-0.14.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:eb09f849bd37147a789b85995ff734a6c4a095bed5fd1608c4f56afc3634cde2", size = 12586522, upload-time = "2026-01-08T19:11:35.356Z" }, + { url = "https://files.pythonhosted.org/packages/2f/2b/a73a2b6e6d2df1d74bf2b78098be1572191e54bec0e59e29382d13c3adc5/ruff-0.14.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c61782543c1231bf71041461c1f28c64b961d457d0f238ac388e2ab173d7ecb7", size = 12724637, upload-time = "2026-01-08T19:11:47.796Z" }, + { url = "https://files.pythonhosted.org/packages/f0/41/09100590320394401cd3c48fc718a8ba71c7ddb1ffd07e0ad6576b3a3df2/ruff-0.14.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:82ff352ea68fb6766140381748e1f67f83c39860b6446966cff48a315c3e2491", size = 13145837, upload-time = "2026-01-08T19:11:32.87Z" }, + { url = "https://files.pythonhosted.org/packages/3b/d8/e035db859d1d3edf909381eb8ff3e89a672d6572e9454093538fe6f164b0/ruff-0.14.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:728e56879df4ca5b62a9dde2dd0eb0edda2a55160c0ea28c4025f18c03f86984", size = 13850469, upload-time = "2026-01-08T19:12:11.694Z" }, + { url = "https://files.pythonhosted.org/packages/4e/02/bb3ff8b6e6d02ce9e3740f4c17dfbbfb55f34c789c139e9cd91985f356c7/ruff-0.14.11-py3-none-win32.whl", hash = "sha256:337c5dd11f16ee52ae217757d9b82a26400be7efac883e9e852646f1557ed841", size = 12851094, upload-time = "2026-01-08T19:11:45.163Z" }, + { url = "https://files.pythonhosted.org/packages/58/f1/90ddc533918d3a2ad628bc3044cdfc094949e6d4b929220c3f0eb8a1c998/ruff-0.14.11-py3-none-win_amd64.whl", hash = "sha256:f981cea63d08456b2c070e64b79cb62f951aa1305282974d4d5216e6e0178ae6", size = 14001379, upload-time = "2026-01-08T19:11:52.591Z" }, + { url = "https://files.pythonhosted.org/packages/c4/1c/1dbe51782c0e1e9cfce1d1004752672d2d4629ea46945d19d731ad772b3b/ruff-0.14.11-py3-none-win_arm64.whl", hash = "sha256:649fb6c9edd7f751db276ef42df1f3df41c38d67d199570ae2a7bd6cbc3590f0", size = 12938644, upload-time = "2026-01-08T19:11:50.027Z" }, ] [[package]] @@ -6764,14 +6764,14 @@ wheels = [ [[package]] name = "werkzeug" -version = "3.1.4" +version = "3.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" }, ] [[package]]