29 changes: 21 additions & 8 deletions python/packages/core/agent_framework/_middleware.py
@@ -1536,27 +1536,40 @@ def _merge_and_filter_chat_middleware(
     return middleware["chat"]  # type: ignore[return-value]
 
 
-def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None:
+def extract_and_merge_function_middleware(
+    chat_client: Any, kwargs: dict[str, Any]
+) -> "FunctionMiddlewarePipeline | None":
     """Extract function middleware from chat client and merge with existing pipeline in kwargs.
 
     Args:
         chat_client: The chat client instance to extract middleware from.
+        kwargs: Dictionary containing middleware and pipeline information.
 
-    Keyword Args:
-        **kwargs: Dictionary containing middleware and pipeline information.
+    Returns:
+        A FunctionMiddlewarePipeline if function middleware is found, None otherwise.
     """
+    # Check if a pipeline was already created by use_chat_middleware
+    existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline")
+
     # Get middleware sources
     client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None
     run_level_middleware = kwargs.get("middleware")
-    existing_pipeline = kwargs.get("_function_middleware_pipeline")
 
-    # Extract existing pipeline middlewares if present
-    existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None
+    # If we have an existing pipeline but no additional middleware sources, return it directly
+    if existing_pipeline and not client_middleware and not run_level_middleware:
+        return existing_pipeline
+
+    # If we have an existing pipeline with additional middleware, we need to merge
+    # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility
+    existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None
 
     # Create combined pipeline from all sources using existing helper
     combined_pipeline = create_function_middleware_pipeline(
         client_middleware, run_level_middleware, existing_middlewares
     )
 
-    if combined_pipeline:
-        kwargs["_function_middleware_pipeline"] = combined_pipeline
+    # If we have an existing pipeline but combined is None (no new middlewares), return existing
+    if existing_pipeline and combined_pipeline is None:
+        return existing_pipeline
+
+    return combined_pipeline
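
Note for reviewers: the helper previously reported its work by mutating kwargs (the removed kwargs["_function_middleware_pipeline"] = combined_pipeline line); it now returns the pipeline and leaves kwargs untouched. A minimal sketch of the empty case under the new contract, using a stand-in client (SimpleNamespace is illustrative, not a framework type):

from types import SimpleNamespace
from typing import Any

from agent_framework._middleware import extract_and_merge_function_middleware

client = SimpleNamespace()   # stand-in client with no `middleware` attribute
kwargs: dict[str, Any] = {}  # no run-level middleware, no pre-built pipeline

# Per the docstring above: no middleware from any source -> None, and the
# helper no longer writes the pipeline back into kwargs.
assert extract_and_merge_function_middleware(client, kwargs) is None
assert "_function_middleware_pipeline" not in kwargs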
43 changes: 31 additions & 12 deletions python/packages/core/agent_framework/_tools.py
@@ -1467,6 +1467,7 @@ async def final_function_handler(context_obj: Any) -> Any:
                 return FunctionResultContent(
                     call_id=function_call_content.call_id,
                     result=function_result,
+                    terminate_loop=middleware_context.terminate,
                 )
             except Exception as exc:
                 message = "Error: Function failed."
@@ -1695,12 +1696,8 @@ async def function_invocation_wrapper(
                 prepare_messages,
             )
 
-            # Extract and merge function middleware from chat client with kwargs pipeline
-            extract_and_merge_function_middleware(self, **kwargs)
-
-            # Extract the middleware pipeline before calling the underlying function
-            # because the underlying function may not preserve it in kwargs
-            stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
+            # Extract and merge function middleware from chat client with kwargs
+            stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
 
            # Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
            config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
@@ -1798,6 +1795,21 @@ async def function_invocation_wrapper(
                 # the function calls are already in the response, so we just continue
                 return response
 
+            # Check if any function result signals loop termination (middleware set context.terminate=True)
+            # This allows middleware to short-circuit the tool loop without another LLM call
+            if any(
+                getattr(fcr, "terminate_loop", False)
+                for fcr in function_call_results
+                if isinstance(fcr, FunctionResultContent)
+            ):
+                # Add tool results to response and return immediately without calling LLM again
+                result_message = ChatMessage(role="tool", contents=function_call_results)
+                response.messages.append(result_message)
+                if fcc_messages:
+                    for msg in reversed(fcc_messages):
+                        response.messages.insert(0, msg)
+                return response
+
             if any(
                 fcr.exception is not None
                 for fcr in function_call_results
@@ -1890,12 +1902,8 @@ async def streaming_function_invocation_wrapper(
                 prepare_messages,
             )
 
-            # Extract and merge function middleware from chat client with kwargs pipeline
-            extract_and_merge_function_middleware(self, **kwargs)
-
-            # Extract the middleware pipeline before calling the underlying function
-            # because the underlying function may not preserve it in kwargs
-            stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
+            # Extract and merge function middleware from chat client with kwargs
+            stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
 
            # Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
            config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
@@ -2005,6 +2013,17 @@ async def streaming_function_invocation_wrapper(
                 # the function calls were already yielded.
                 return
 
+            # Check if any function result signals loop termination (middleware set context.terminate=True)
+            # This allows middleware to short-circuit the tool loop without another LLM call
+            if any(
+                getattr(fcr, "terminate_loop", False)
+                for fcr in function_call_results
+                if isinstance(fcr, FunctionResultContent)
+            ):
+                # Yield tool results and return immediately without calling LLM again
+                yield ChatResponseUpdate(contents=function_call_results, role="tool")
+                return
+
             if any(
                 fcr.exception is not None
                 for fcr in function_call_results
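
For reviewers tracing the streaming path: when termination fires, the stream ends with a tool-role update carrying the middleware-produced results, and no trailing assistant text follows. A consumer-side sketch of that observable behavior (the consume function and printout are illustrative, not framework API; it relies on the hunk above yielding the FunctionResultContent objects themselves, so their terminate_loop flag is visible in update.contents):

# Illustrative only: detect a middleware-terminated turn from the consumer side.
async def consume(stream) -> None:
    async for update in stream:
        for content in update.contents:
            if getattr(content, "terminate_loop", False):
                # The loop ended without a final LLM message; this update
                # carries the tool results produced by middleware.
                print("loop terminated by middleware:", content.result)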
6 changes: 6 additions & 0 deletions python/packages/core/agent_framework/_types.py
@@ -1417,6 +1417,9 @@ class FunctionResultContent(BaseContent):
         call_id: The identifier of the function call for which this is the result.
         result: The result of the function call, or a generic error message if the function call failed.
         exception: An exception that occurred if the function call failed.
+        terminate_loop: If True, signals that the function invocation loop should terminate
+            immediately without calling the LLM again. This is set when middleware sets
+            context.terminate=True during function execution.
         type: The type of content, which is always "function_result" for this class.
         annotations: Optional annotations associated with the content.
         additional_properties: Optional additional properties associated with the content.
@@ -1447,6 +1450,7 @@ def __init__(
         call_id: str,
         result: Any | None = None,
         exception: Exception | None = None,
+        terminate_loop: bool = False,
         annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None,
         additional_properties: dict[str, Any] | None = None,
         raw_representation: Any | None = None,
@@ -1458,6 +1462,7 @@ def __init__(
             call_id: The identifier of the function call for which this is the result.
             result: The result of the function call, or a generic error message if the function call failed.
             exception: An exception that occurred if the function call failed.
+            terminate_loop: If True, signals the function invocation loop to terminate immediately.
             annotations: Optional annotations associated with the content.
             additional_properties: Optional additional properties associated with the content.
             raw_representation: Optional raw representation of the content.
@@ -1472,6 +1477,7 @@ def __init__(
         self.call_id = call_id
         self.result = result
         self.exception = exception
+        self.terminate_loop = terminate_loop
         self.type: Literal["function_result"] = "function_result"
 
 
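
Since the loop-side checks in _tools.py use getattr(fcr, "terminate_loop", False), content objects that predate the field default to non-terminating, and existing FunctionResultContent call sites are unaffected by the new keyword. A small sketch of both properties (assuming the top-level FunctionResultContent export that the tests below also rely on):

from agent_framework import FunctionResultContent

# Default keeps existing construction sites behaviorally unchanged.
plain = FunctionResultContent(call_id="1", result="ok")
assert plain.terminate_loop is False

# Set by final_function_handler when middleware flags context.terminate=True.
stopping = FunctionResultContent(call_id="2", result="done", terminate_loop=True)
assert getattr(stopping, "terminate_loop", False) is True

# Objects lacking the attribute are treated as non-terminating by the loop.
assert getattr(object(), "terminate_loop", False) is False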
175 changes: 175 additions & 0 deletions python/packages/core/tests/core/test_function_invocation_logic.py
@@ -1,5 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 
+from collections.abc import Awaitable, Callable
+
 import pytest
 
 from agent_framework import (
@@ -16,6 +18,7 @@
     TextContent,
     ai_function,
 )
+from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
 
 
 async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol):
@@ -2206,3 +2209,175 @@ def sometimes_fails(arg1: str) -> str:
     assert len(error_results) >= 1
     assert len(success_results) >= 1
     assert call_count == 2  # Both calls executed
+
+
+class TerminateLoopMiddleware(FunctionMiddleware):
+    """Middleware that sets terminate=True to exit the function calling loop."""
+
+    async def process(
+        self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
+    ) -> None:
+        # Set result to a simple value - the framework will wrap it in FunctionResultContent
+        context.result = "terminated by middleware"
+        context.terminate = True
+
+
+async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol):
+    """Test that terminate_loop=True exits the function calling loop after single function call."""
+    exec_counter = 0
+
+    @ai_function(name="test_function")
+    def ai_func(arg1: str) -> str:
+        nonlocal exec_counter
+        exec_counter += 1
+        return f"Processed {arg1}"
+
+    # Queue up two responses: function call, then final text
+    # If terminate_loop works, only the first response should be consumed
+    chat_client_base.run_responses = [
+        ChatResponse(
+            messages=ChatMessage(
+                role="assistant",
+                contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
+            )
+        ),
+        ChatResponse(messages=ChatMessage(role="assistant", text="done")),
+    ]
+
+    response = await chat_client_base.get_response(
+        "hello",
+        tool_choice="auto",
+        tools=[ai_func],
+        middleware=[TerminateLoopMiddleware()],
+    )
+
+    # Function should NOT have been executed - middleware intercepted it
+    assert exec_counter == 0
+
+    # There should be 2 messages: assistant with function call, tool result from middleware
+    # The loop should NOT have continued to call the LLM again
+    assert len(response.messages) == 2
+    assert response.messages[0].role == Role.ASSISTANT
+    assert isinstance(response.messages[0].contents[0], FunctionCallContent)
+    assert response.messages[1].role == Role.TOOL
+    assert isinstance(response.messages[1].contents[0], FunctionResultContent)
+    assert response.messages[1].contents[0].result == "terminated by middleware"
+
+    # Verify the second response is still in the queue (wasn't consumed)
+    assert len(chat_client_base.run_responses) == 1
+
+
+class SelectiveTerminateMiddleware(FunctionMiddleware):
+    """Only terminates for terminating_function."""
+
+    async def process(
+        self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
+    ) -> None:
+        if context.function.name == "terminating_function":
+            # Set result to a simple value - the framework will wrap it in FunctionResultContent
+            context.result = "terminated by middleware"
+            context.terminate = True
+        else:
+            await next_handler(context)
+
+
+async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol):
+    """Test that any(terminate_loop=True) exits loop even with multiple function calls."""
+    normal_call_count = 0
+    terminating_call_count = 0
+
+    @ai_function(name="normal_function")
+    def normal_func(arg1: str) -> str:
+        nonlocal normal_call_count
+        normal_call_count += 1
+        return f"Normal {arg1}"
+
+    @ai_function(name="terminating_function")
+    def terminating_func(arg1: str) -> str:
+        nonlocal terminating_call_count
+        terminating_call_count += 1
+        return f"Terminating {arg1}"
+
+    # Queue up two responses: parallel function calls, then final text
+    chat_client_base.run_responses = [
+        ChatResponse(
+            messages=ChatMessage(
+                role="assistant",
+                contents=[
+                    FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'),
+                    FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'),
+                ],
+            )
+        ),
+        ChatResponse(messages=ChatMessage(role="assistant", text="done")),
+    ]
+
+    response = await chat_client_base.get_response(
+        "hello",
+        tool_choice="auto",
+        tools=[normal_func, terminating_func],
+        middleware=[SelectiveTerminateMiddleware()],
+    )
+
+    # normal_function should have executed (middleware calls next_handler)
+    # terminating_function should NOT have executed (middleware intercepts it)
+    assert normal_call_count == 1
+    assert terminating_call_count == 0
+
+    # There should be 2 messages: assistant with function calls, tool results
+    # The loop should NOT have continued to call the LLM again
+    assert len(response.messages) == 2
+    assert response.messages[0].role == Role.ASSISTANT
+    assert len(response.messages[0].contents) == 2
+    assert response.messages[1].role == Role.TOOL
+    # Both function results should be present
+    assert len(response.messages[1].contents) == 2
+
+    # Verify the second response is still in the queue (wasn't consumed)
+    assert len(chat_client_base.run_responses) == 1
+
+
+async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol):
+    """Test that terminate_loop=True exits the streaming function calling loop."""
+    exec_counter = 0
+
+    @ai_function(name="test_function")
+    def ai_func(arg1: str) -> str:
+        nonlocal exec_counter
+        exec_counter += 1
+        return f"Processed {arg1}"
+
+    # Queue up two streaming responses
+    chat_client_base.streaming_responses = [
+        [
+            ChatResponseUpdate(
+                contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
+                role="assistant",
+            ),
+        ],
+        [
+            ChatResponseUpdate(
+                contents=[TextContent(text="done")],
+                role="assistant",
+            )
+        ],
+    ]
+
+    updates = []
+    async for update in chat_client_base.get_streaming_response(
+        "hello",
+        tool_choice="auto",
+        tools=[ai_func],
+        middleware=[TerminateLoopMiddleware()],
+    ):
+        updates.append(update)
+
+    # Function should NOT have been executed - middleware intercepted it
+    assert exec_counter == 0
+
+    # Should have function call update and function result update
+    # The loop should NOT have continued to call the LLM again
+    assert len(updates) == 2
+
+    # Verify the second streaming response is still in the queue (wasn't consumed)
+    assert len(chat_client_base.streaming_responses) == 1