diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index eec6bb646b..e32a88c3f6 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,6 +14,7 @@ from __future__ import annotations +import copy import json import logging import os @@ -350,7 +351,24 @@ def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" async def _get_a2a_runner_async() -> Runner: - return await adk_web_server.get_runner_async(captured_app_name) + original_runner = await adk_web_server.get_runner_async( + captured_app_name + ) + kwargs = {} + if original_runner.app: + kwargs["app"] = original_runner.app + else: + kwargs["app_name"] = original_runner.app_name + kwargs["agent"] = original_runner.agent + + runner = Runner( + session_service=InMemorySessionService(), + artifact_service=InMemoryArtifactService(), + memory_service=InMemoryMemoryService(), + credential_service=InMemoryCredentialService(), + **kwargs, + ) + return runner return _get_a2a_runner_async diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d50bfcd8e5..0182663369 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -30,6 +30,8 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService from google.adk.cli.fast_api import get_fast_api_app from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation @@ -38,6 +40,7 @@ from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session @@ -991,6 +994,103 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def test_a2a_runner_factory_creates_isolated_runner(temp_agents_dir_with_a2a): + """Verify the A2A runner factory creates a copy of the runner with in-memory services.""" + # 1. Setup Mocks for the original runner and its services + original_runner = Runner( + agent=MagicMock(), + app_name="test_app", + session_service=MagicMock(), + ) + original_runner.memory_service = MagicMock() + original_runner.artifact_service = MagicMock() + original_runner.credential_service = MagicMock() + + # Mock the AdkWebServer to control the runner it returns + mock_web_server_instance = MagicMock() + mock_web_server_instance.get_runner_async = AsyncMock( + return_value=original_runner + ) + # The factory captures the app_name, so we need to mock list_agents + mock_web_server_instance.list_agents.return_value = ["test_a2a_agent"] + + # 2. Patch dependencies in the fast_api module + with ( + patch("google.adk.cli.fast_api.AdkWebServer") as mock_web_server, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.types.AgentCard") as mock_agent_card, + patch("a2a.utils.constants.AGENT_CARD_WELL_KNOWN_PATH", "/agent.json"), + ): + mock_web_server.return_value = mock_web_server_instance + mock_task_store.return_value = MagicMock() + mock_executor.return_value = MagicMock() + mock_handler.return_value = MagicMock() + mock_agent_card.return_value = MagicMock() + + # Change to temp directory + original_cwd = os.getcwd() + os.chdir(temp_agents_dir_with_a2a) + try: + # 3. Call get_fast_api_app to trigger the factory creation + get_fast_api_app( + agents_dir=".", + web=False, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=[], + a2a=True, # Enable A2A to create the factory + host="127.0.0.1", + port=8000, + ) + finally: + os.chdir(original_cwd) + + # 4. Capture the factory from the mocked A2aAgentExecutor + assert mock_executor.call_args is not None, "A2aAgentExecutor not called" + kwargs = mock_executor.call_args.kwargs + assert "runner" in kwargs + runner_factory = kwargs["runner"] + + # 5. Execute the factory to get the new runner + # Since runner_factory is an async function, we need to run it. + # We run it in a separate thread to avoid event loop conflicts if an event loop is already running. + from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=1) as executor: + a2a_runner = executor.submit(asyncio.run, runner_factory()).result() + + # 6. Assert that the new runner is a separate, modified copy + assert a2a_runner is not original_runner, "Runner should be a copy" + + # Assert that services have been replaced with InMemory versions + assert isinstance(a2a_runner.memory_service, InMemoryMemoryService) + assert isinstance(a2a_runner.session_service, InMemorySessionService) + assert isinstance(a2a_runner.artifact_service, InMemoryArtifactService) + assert isinstance(a2a_runner.credential_service, InMemoryCredentialService) + + # Assert that the original runner's services are unchanged + assert not isinstance(original_runner.memory_service, InMemoryMemoryService) + assert not isinstance( + original_runner.session_service, InMemorySessionService + ) + assert not isinstance( + original_runner.artifact_service, InMemoryArtifactService + ) + assert not isinstance( + original_runner.credential_service, InMemoryCredentialService + ) + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="A2A requires Python 3.10+" )