Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,12 @@ async def _handle_mcp_tool_invocation(

# Create or parse session ID
if thread_id and isinstance(thread_id, str) and thread_id.strip():
try:
session_id = AgentSessionId.parse(thread_id)
except ValueError as e:
logger.warning(
"Failed to parse AgentSessionId from thread_id '%s': %s. Falling back to new session ID.",
thread_id,
e,
)
# If thread_id is in @name@key format, extract only the key portion
if thread_id.startswith("@") and "@" in thread_id[1:]:
key = thread_id[1:].split("@", 1)[1]
session_id = AgentSessionId(name=agent_name, key=key)
else:
# Use thread_id as-is for the key
session_id = AgentSessionId(name=agent_name, key=thread_id)
else:
# Generate new session ID
Expand Down
34 changes: 34 additions & 0 deletions python/packages/azurefunctions/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,40 @@ async def test_handle_mcp_tool_invocation_runtime_error(self) -> None:
with pytest.raises(RuntimeError, match="Agent execution failed"):
await app._handle_mcp_tool_invocation("TestAgent", context, client)

async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) -> None:
"""Test that MCP tool invocation uses the agent_name parameter, not the name from thread_id."""
mock_agent = Mock()
mock_agent.name = "PlantAdvisor"

app = AgentFunctionApp(agents=[mock_agent])
client = AsyncMock()

# Mock the entity response
mock_state = Mock()
mock_state.entity_state = {
"schemaVersion": "1.0.0",
"data": {"conversationHistory": []},
}
client.read_entity_state.return_value = mock_state

# Thread ID contains a different agent name (@StockAdvisor@poc123)
# but we're invoking PlantAdvisor - it should use PlantAdvisor's entity
context = json.dumps({"arguments": {"query": "test query", "threadId": "@StockAdvisor@test123"}})

with patch.object(app, "_get_response_from_entity") as get_response_mock:
get_response_mock.return_value = {"status": "success", "response": "Test response"}

await app._handle_mcp_tool_invocation("PlantAdvisor", context, client)

# Verify signal_entity was called with PlantAdvisor's entity, not StockAdvisor's
client.signal_entity.assert_called_once()
call_args = client.signal_entity.call_args
entity_id = call_args[0][0]

# Entity name should be dafx-PlantAdvisor, not dafx-StockAdvisor
assert entity_id.name == "dafx-PlantAdvisor"
assert entity_id.key == "test123"

def test_health_check_includes_mcp_tool_enabled(self) -> None:
"""Test that health check endpoint includes mcp_tool_enabled field."""
mock_agent = Mock()
Expand Down
Loading