diff --git a/README.md b/README.md index 1abbb742b9..ca0655f579 100644 --- a/README.md +++ b/README.md @@ -808,10 +808,21 @@ Request additional information from users. This example shows an Elicitation dur ```python +"""Elicitation examples demonstrating form and URL mode elicitation. + +Form mode elicitation collects structured, non-sensitive data through a schema. +URL mode elicitation directs users to external URLs for sensitive operations +like OAuth flows, credential collection, or payment processing. +""" + +import uuid + from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ElicitRequestURLParams mcp = FastMCP(name="Elicitation Example") @@ -828,7 +839,10 @@ class BookingPreferences(BaseModel): @mcp.tool() async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: - """Book a table with date availability check.""" + """Book a table with date availability check. + + This demonstrates form mode elicitation for collecting non-sensitive user input. + """ # Check if date is available if date == "2024-12-25": # Date unavailable - ask user for alternative @@ -845,6 +859,54 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS # Date available return f"[SUCCESS] Booked for {date} at {time}" + + +@mcp.tool() +async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: + """Process a secure payment requiring URL confirmation. + + This demonstrates URL mode elicitation using ctx.elicit_url() for + operations that require out-of-band user interaction. + """ + elicitation_id = str(uuid.uuid4()) + + result = await ctx.elicit_url( + message=f"Please confirm payment of ${amount:.2f}", + url=f"https://payments.example.com/confirm?amount={amount}&id={elicitation_id}", + elicitation_id=elicitation_id, + ) + + if result.action == "accept": + # In a real app, the payment confirmation would happen out-of-band + # and you'd verify the payment status from your backend + return f"Payment of ${amount:.2f} initiated - check your browser to complete" + elif result.action == "decline": + return "Payment declined by user" + return "Payment cancelled" + + +@mcp.tool() +async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + """Connect to a third-party service requiring OAuth authorization. + + This demonstrates the "throw error" pattern using UrlElicitationRequiredError. + Use this pattern when the tool cannot proceed without user authorization. + """ + elicitation_id = str(uuid.uuid4()) + + # Raise UrlElicitationRequiredError to signal that the client must complete + # a URL elicitation before this request can be processed. + # The MCP framework will convert this to a -32042 error response. + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message=f"Authorization required to connect to {service_name}", + url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", + elicitationId=elicitation_id, + ) + ] + ) ``` _Full example: [examples/snippets/servers/elicitation.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/elicitation.py)_ diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index f42479af53..38dc5a9167 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -150,9 +150,15 @@ def get_state(self): class SimpleAuthClient: """Simple MCP client with auth support.""" - def __init__(self, server_url: str, transport_type: str = "streamable-http"): + def __init__( + self, + server_url: str, + transport_type: str = "streamable-http", + client_metadata_url: str | None = None, + ): self.server_url = server_url self.transport_type = transport_type + self.client_metadata_url = client_metadata_url self.session: ClientSession | None = None async def connect(self): @@ -185,12 +191,14 @@ async def _default_redirect_handler(authorization_url: str) -> None: webbrowser.open(authorization_url) # Create OAuth authentication handler using the new interface + # Use client_metadata_url to enable CIMD when the server supports it oauth_auth = OAuthClientProvider( server_url=self.server_url, client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), storage=InMemoryTokenStorage(), redirect_handler=_default_redirect_handler, callback_handler=callback_handler, + client_metadata_url=self.client_metadata_url, ) # Create transport with auth handler based on transport type @@ -334,6 +342,7 @@ async def main(): # Most MCP streamable HTTP servers use /mcp as the endpoint server_url = os.getenv("MCP_SERVER_PORT", 8000) transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable-http") + client_metadata_url = os.getenv("MCP_CLIENT_METADATA_URL") server_url = ( f"http://localhost:{server_url}/mcp" if transport_type == "streamable-http" @@ -343,9 +352,11 @@ async def main(): print("🚀 Simple MCP Auth Client") print(f"Connecting to: {server_url}") print(f"Transport type: {transport_type}") + if client_metadata_url: + print(f"Client metadata URL: {client_metadata_url}") # Start connection flow - OAuth will be handled automatically - client = SimpleAuthClient(server_url, transport_type) + client = SimpleAuthClient(server_url, transport_type, client_metadata_url) await client.connect() diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index eb632b4d63..4958341c7e 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -14,6 +14,13 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.prompts.base import UserMessage from mcp.server.session import ServerSession +from mcp.server.streamable_http import ( + EventCallback, + EventId, + EventMessage, + EventStore, + StreamId, +) from mcp.types import ( AudioContent, Completion, @@ -21,6 +28,7 @@ CompletionContext, EmbeddedResource, ImageContent, + JSONRPCMessage, PromptReference, ResourceTemplateReference, SamplingMessage, @@ -39,8 +47,47 @@ resource_subscriptions: set[str] = set() watched_resource_content = "Watched resource content" + +# Simple in-memory event store for SSE polling resumability (SEP-1699) +class SimpleEventStore(EventStore): + """Simple in-memory event store for testing resumability.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, EventId, JSONRPCMessage]] = [] + self._event_id_counter = 0 + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + """Store an event and return its ID.""" + self._event_id_counter += 1 + event_id = str(self._event_id_counter) + self._events.append((stream_id, event_id, message)) + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replay events after the specified ID.""" + target_stream_id = None + found = False + for stream_id, event_id, message in self._events: + if event_id == last_event_id: + target_stream_id = stream_id + found = True + continue + if found and stream_id == target_stream_id: + await send_callback(EventMessage(message=message, event_id=event_id)) + return target_stream_id + + +# Create event store for resumability +event_store = SimpleEventStore() + mcp = FastMCP( name="mcp-conformance-test-server", + event_store=event_store, + sse_retry_interval=3000, # 3 seconds ) @@ -257,6 +304,31 @@ async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> s return f"Elicitation not supported or error: {str(e)}" +@mcp.tool() +async def test_reconnection(ctx: Context[ServerSession, None]) -> str: + """Tests SSE polling via server-initiated disconnect (SEP-1699) + + This tool closes the SSE stream mid-call, requiring the client to reconnect + with Last-Event-ID to receive the remaining events. + """ + # Send notification before disconnect + await ctx.info("Notification before disconnect") + + # Use the close_sse_stream callback if available + # This is None if not on streamable HTTP transport or no event store configured + if ctx.close_sse_stream: + # Trigger server-initiated SSE disconnect with optional retry interval + await ctx.close_sse_stream(retry_interval=3000) # 3 seconds + + # Wait for client to reconnect + await asyncio.sleep(0.2) + + # Send notification after disconnect (will be replayed via event store) + await ctx.info("Notification after disconnect") + + return "Reconnection test completed successfully" + + @mcp.tool() def test_error_handling() -> str: """Tests error response handling""" diff --git a/examples/snippets/clients/sse_polling_client.py b/examples/snippets/clients/sse_polling_client.py new file mode 100644 index 0000000000..e8a0d0ce8f --- /dev/null +++ b/examples/snippets/clients/sse_polling_client.py @@ -0,0 +1,103 @@ +""" +SSE Polling Example Client + +Demonstrates client-side behavior during server-initiated SSE disconnect. + +Key features: +- Automatic reconnection when server closes SSE stream +- Event replay via Last-Event-ID header (handled internally by the transport) +- Progress notifications via logging callback + +This client connects to the SSE polling server and calls the `long-task` tool. +The server disconnects at 50% progress, and the client automatically reconnects +to receive remaining progress updates. + +Run: + # First start the server: + uv run examples/snippets/servers/sse_polling_server.py + + # Then run this client: + uv run examples/snippets/clients/sse_polling_client.py +""" + +import asyncio +import logging + +from mcp import ClientSession +from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client +from mcp.types import LoggingMessageNotificationParams, TextContent + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +async def main() -> None: + print("SSE Polling Example Client") + print("=" * 50) + print() + + # Track notifications received via the logging callback + notifications_received: list[str] = [] + + async def logging_callback(params: LoggingMessageNotificationParams) -> None: + """Called when a log message notification is received from the server.""" + data = params.data + if data: + data_str = str(data) + notifications_received.append(data_str) + print(f"[Progress] {data_str}") + + # Configure reconnection behavior + reconnection_options = StreamableHTTPReconnectionOptions( + initial_reconnection_delay=1.0, # Start with 1 second + max_reconnection_delay=30.0, # Cap at 30 seconds + reconnection_delay_grow_factor=1.5, # Exponential backoff + max_retries=5, # Try up to 5 times + ) + + print("[Client] Connecting to server...") + + async with streamablehttp_client( + "http://localhost:3001/mcp", + reconnection_options=reconnection_options, + ) as (read_stream, write_stream, get_session_id): + # Create session with logging callback to receive progress notifications + async with ClientSession( + read_stream, + write_stream, + logging_callback=logging_callback, + ) as session: + # Initialize the session + await session.initialize() + session_id = get_session_id() + print(f"[Client] Connected! Session ID: {session_id}") + + # List available tools + tools = await session.list_tools() + tool_names = [t.name for t in tools.tools] + print(f"[Client] Available tools: {tool_names}") + print() + + # Call the long-running task + print("[Client] Calling long-task tool...") + print("[Client] The server will disconnect at 50% and we'll auto-reconnect") + print() + + # Call the tool + result = await session.call_tool("long-task", {}) + + print() + print("[Client] Task completed!") + if result.content and isinstance(result.content[0], TextContent): + print(f"[Result] {result.content[0].text}") + else: + print("[Result] No content") + print() + print(f"[Summary] Received {len(notifications_received)} progress notifications") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py new file mode 100644 index 0000000000..56457512c6 --- /dev/null +++ b/examples/snippets/clients/url_elicitation_client.py @@ -0,0 +1,318 @@ +"""URL Elicitation Client Example. + +Demonstrates how clients handle URL elicitation requests from servers. +This is the Python equivalent of TypeScript SDK's elicitationUrlExample.ts, +focused on URL elicitation patterns without OAuth complexity. + +Features demonstrated: +1. Client elicitation capability declaration +2. Handling elicitation requests from servers via callback +3. Catching UrlElicitationRequiredError from tool calls +4. Browser interaction with security warnings +5. Interactive CLI for testing + +Run with: + cd examples/snippets + uv run elicitation-client + +Requires a server with URL elicitation tools running. Start the elicitation +server first: + uv run server elicitation sse +""" + +from __future__ import annotations + +import asyncio +import json +import subprocess +import sys +import webbrowser +from typing import Any +from urllib.parse import urlparse + +from mcp import ClientSession, types +from mcp.client.sse import sse_client +from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError +from mcp.types import URL_ELICITATION_REQUIRED + + +async def handle_elicitation( + context: RequestContext[ClientSession, Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + """Handle elicitation requests from the server. + + This callback is invoked when the server sends an elicitation/request. + For URL mode, we prompt the user and optionally open their browser. + """ + if params.mode == "url": + return await handle_url_elicitation(params) + else: + # We only support URL mode in this example + return types.ErrorData( + code=types.INVALID_REQUEST, + message=f"Unsupported elicitation mode: {params.mode}", + ) + + +async def handle_url_elicitation( + params: types.ElicitRequestParams, +) -> types.ElicitResult: + """Handle URL mode elicitation - show security warning and optionally open browser. + + This function demonstrates the security-conscious approach to URL elicitation: + 1. Display the full URL and domain for user inspection + 2. Show the server's reason for requesting this interaction + 3. Require explicit user consent before opening any URL + """ + # Extract URL parameters - these are available on URL mode requests + url = getattr(params, "url", None) + elicitation_id = getattr(params, "elicitationId", None) + message = params.message + + if not url: + print("Error: No URL provided in elicitation request") + return types.ElicitResult(action="cancel") + + # Extract domain for security display + domain = extract_domain(url) + + # Security warning - always show the user what they're being asked to do + print("\n" + "=" * 60) + print("SECURITY WARNING: External URL Request") + print("=" * 60) + print("\nThe server is requesting you to open an external URL.") + print(f"\n Domain: {domain}") + print(f" Full URL: {url}") + print("\n Server's reason:") + print(f" {message}") + print(f"\n Elicitation ID: {elicitation_id}") + print("\n" + "-" * 60) + + # Get explicit user consent + try: + response = input("\nOpen this URL in your browser? (y/n): ").strip().lower() + except EOFError: + return types.ElicitResult(action="cancel") + + if response in ("n", "no"): + print("URL navigation declined.") + return types.ElicitResult(action="decline") + elif response not in ("y", "yes"): + print("Invalid response. Cancelling.") + return types.ElicitResult(action="cancel") + + # Open the browser + print(f"\nOpening browser to: {url}") + open_browser(url) + + print("Waiting for you to complete the interaction in your browser...") + print("(The server will continue once you've finished)") + + return types.ElicitResult(action="accept") + + +def extract_domain(url: str) -> str: + """Extract domain from URL for security display.""" + try: + return urlparse(url).netloc + except Exception: + return "unknown" + + +def open_browser(url: str) -> None: + """Open URL in the default browser.""" + try: + if sys.platform == "darwin": + subprocess.run(["open", url], check=False) + elif sys.platform == "win32": + subprocess.run(["start", url], shell=True, check=False) + else: + webbrowser.open(url) + except Exception as e: + print(f"Failed to open browser: {e}") + print(f"Please manually open: {url}") + + +async def call_tool_with_error_handling( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], +) -> types.CallToolResult | None: + """Call a tool, handling UrlElicitationRequiredError if raised. + + When a server tool needs URL elicitation before it can proceed, + it can either: + 1. Send an elicitation request directly (handled by elicitation_callback) + 2. Return an error with code -32042 (URL_ELICITATION_REQUIRED) + + This function demonstrates handling case 2 - catching the error + and processing the required URL elicitations. + """ + try: + result = await session.call_tool(tool_name, arguments) + + # Check if the tool returned an error in the result + if result.isError: + print(f"Tool returned error: {result.content}") + return None + + return result + + except McpError as e: + # Check if this is a URL elicitation required error + if e.error.code == URL_ELICITATION_REQUIRED: + print("\n[Tool requires URL elicitation to proceed]") + + # Convert to typed error to access elicitations + url_error = UrlElicitationRequiredError.from_error(e.error) + + # Process each required elicitation + for elicitation in url_error.elicitations: + await handle_url_elicitation(elicitation) + + return None + else: + # Re-raise other MCP errors + print(f"MCP Error: {e.error.message} (code: {e.error.code})") + return None + + +def print_help() -> None: + """Print available commands.""" + print("\nAvailable commands:") + print(" list-tools - List available tools") + print(" call [json-args] - Call a tool with optional JSON arguments") + print(" secure-payment - Test URL elicitation via ctx.elicit_url()") + print(" connect-service - Test URL elicitation via UrlElicitationRequiredError") + print(" help - Show this help") + print(" quit - Exit the program") + + +def print_tool_result(result: types.CallToolResult | None) -> None: + """Print a tool call result.""" + if not result: + return + print("\nTool result:") + for content in result.content: + if isinstance(content, types.TextContent): + print(f" {content.text}") + else: + print(f" [{content.type}]") + + +async def handle_list_tools(session: ClientSession) -> None: + """Handle the list-tools command.""" + tools = await session.list_tools() + if tools.tools: + print("\nAvailable tools:") + for tool in tools.tools: + print(f" - {tool.name}: {tool.description or 'No description'}") + else: + print("No tools available") + + +async def handle_call_command(session: ClientSession, command: str) -> None: + """Handle the call command.""" + parts = command.split(maxsplit=2) + if len(parts) < 2: + print("Usage: call [json-args]") + return + + tool_name = parts[1] + args: dict[str, Any] = {} + if len(parts) > 2: + try: + args = json.loads(parts[2]) + except json.JSONDecodeError as e: + print(f"Invalid JSON arguments: {e}") + return + + print(f"\nCalling tool '{tool_name}' with args: {args}") + result = await call_tool_with_error_handling(session, tool_name, args) + print_tool_result(result) + + +async def process_command(session: ClientSession, command: str) -> bool: + """Process a single command. Returns False if should exit.""" + if command in {"quit", "exit"}: + print("Goodbye!") + return False + + if command == "help": + print_help() + elif command == "list-tools": + await handle_list_tools(session) + elif command.startswith("call "): + await handle_call_command(session, command) + elif command == "secure-payment": + print("\nTesting secure_payment tool (uses ctx.elicit_url())...") + result = await call_tool_with_error_handling(session, "secure_payment", {"amount": 99.99}) + print_tool_result(result) + elif command == "connect-service": + print("\nTesting connect_service tool (raises UrlElicitationRequiredError)...") + result = await call_tool_with_error_handling(session, "connect_service", {"service_name": "github"}) + print_tool_result(result) + else: + print(f"Unknown command: {command}") + print("Type 'help' for available commands.") + + return True + + +async def run_command_loop(session: ClientSession) -> None: + """Run the interactive command loop.""" + while True: + try: + command = input("> ").strip() + except EOFError: + break + except KeyboardInterrupt: + print("\n") + break + + if not command: + continue + + if not await process_command(session, command): + break + + +async def main() -> None: + """Run the interactive URL elicitation client.""" + server_url = "http://localhost:8000/sse" + + print("=" * 60) + print("URL Elicitation Client Example") + print("=" * 60) + print(f"\nConnecting to: {server_url}") + print("(Start server with: cd examples/snippets && uv run server elicitation sse)") + + try: + async with sse_client(server_url) as (read, write): + async with ClientSession( + read, + write, + elicitation_callback=handle_elicitation, + ) as session: + await session.initialize() + print("\nConnected! Type 'help' for available commands.\n") + await run_command_loop(session) + + except ConnectionRefusedError: + print(f"\nError: Could not connect to {server_url}") + print("Make sure the elicitation server is running:") + print(" cd examples/snippets && uv run server elicitation sse") + except Exception as e: + print(f"\nError: {e}") + raise + + +def run() -> None: + """Entry point for the client script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + run() diff --git a/examples/snippets/pyproject.toml b/examples/snippets/pyproject.toml index 76791a55a7..4e68846a09 100644 --- a/examples/snippets/pyproject.toml +++ b/examples/snippets/pyproject.toml @@ -21,3 +21,4 @@ completion-client = "clients.completion_client:main" direct-execution-server = "servers.direct_execution:main" display-utilities-client = "clients.display_utilities:main" oauth-client = "clients.oauth_client:run" +elicitation-client = "clients.url_elicitation_client:run" diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 2c8a3b35ac..a1a65fb32c 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,18 @@ +"""Elicitation examples demonstrating form and URL mode elicitation. + +Form mode elicitation collects structured, non-sensitive data through a schema. +URL mode elicitation directs users to external URLs for sensitive operations +like OAuth flows, credential collection, or payment processing. +""" + +import uuid + from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ElicitRequestURLParams mcp = FastMCP(name="Elicitation Example") @@ -18,7 +29,10 @@ class BookingPreferences(BaseModel): @mcp.tool() async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: - """Book a table with date availability check.""" + """Book a table with date availability check. + + This demonstrates form mode elicitation for collecting non-sensitive user input. + """ # Check if date is available if date == "2024-12-25": # Date unavailable - ask user for alternative @@ -35,3 +49,51 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS # Date available return f"[SUCCESS] Booked for {date} at {time}" + + +@mcp.tool() +async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: + """Process a secure payment requiring URL confirmation. + + This demonstrates URL mode elicitation using ctx.elicit_url() for + operations that require out-of-band user interaction. + """ + elicitation_id = str(uuid.uuid4()) + + result = await ctx.elicit_url( + message=f"Please confirm payment of ${amount:.2f}", + url=f"https://payments.example.com/confirm?amount={amount}&id={elicitation_id}", + elicitation_id=elicitation_id, + ) + + if result.action == "accept": + # In a real app, the payment confirmation would happen out-of-band + # and you'd verify the payment status from your backend + return f"Payment of ${amount:.2f} initiated - check your browser to complete" + elif result.action == "decline": + return "Payment declined by user" + return "Payment cancelled" + + +@mcp.tool() +async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + """Connect to a third-party service requiring OAuth authorization. + + This demonstrates the "throw error" pattern using UrlElicitationRequiredError. + Use this pattern when the tool cannot proceed without user authorization. + """ + elicitation_id = str(uuid.uuid4()) + + # Raise UrlElicitationRequiredError to signal that the client must complete + # a URL elicitation before this request can be processed. + # The MCP framework will convert this to a -32042 error response. + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message=f"Authorization required to connect to {service_name}", + url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", + elicitationId=elicitation_id, + ) + ] + ) diff --git a/examples/snippets/servers/sse_polling_server.py b/examples/snippets/servers/sse_polling_server.py new file mode 100644 index 0000000000..ef1ccd58dd --- /dev/null +++ b/examples/snippets/servers/sse_polling_server.py @@ -0,0 +1,192 @@ +""" +SSE Polling Example Server + +Demonstrates server-initiated SSE stream disconnection for polling behavior. + +Key features: +- retryInterval: Tells clients how long to wait before reconnecting (2 seconds) +- eventStore: Persists events for replay after reconnection +- close_sse_stream(): Gracefully disconnects clients mid-operation + +The server creates a `long-task` tool that: +1. Sends progress notifications at 25%, 50%, 75%, 100% +2. At 50%, closes the SSE stream to trigger client reconnection +3. Continues processing - events are stored and replayed on reconnect + +Run: + uv run examples/snippets/servers/sse_polling_server.py +""" + +import contextlib +import logging +from collections.abc import AsyncIterator +from typing import Any + +import anyio +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + +# Configure logging to show progress +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +class InMemoryEventStore(EventStore): + """Simple in-memory event store for demonstrating SSE polling resumability.""" + + def __init__(self) -> None: + self._events: dict[StreamId, list[tuple[EventId, types.JSONRPCMessage]]] = {} + self._event_index: dict[EventId, tuple[StreamId, types.JSONRPCMessage]] = {} + self._counter = 0 + + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: + event_id = f"evt-{self._counter}" + self._counter += 1 + + if stream_id not in self._events: + self._events[stream_id] = [] + self._events[stream_id].append((event_id, message)) + self._event_index[event_id] = (stream_id, message) + + logger.debug(f"Stored event {event_id} for stream {stream_id}") + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + if last_event_id not in self._event_index: + logger.warning(f"Event {last_event_id} not found") + return None + + stream_id, _ = self._event_index[last_event_id] + events = self._events.get(stream_id, []) + + # Find events after last_event_id + found = False + for event_id, message in events: + if found: + await send_callback(EventMessage(message, event_id)) + logger.info(f"Replayed event {event_id}") + elif event_id == last_event_id: + found = True + + return stream_id + + +def create_app() -> Starlette: + """Create the Starlette application with SSE polling example server.""" + app = Server("sse-polling-example") + + @app.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + if name != "long-task": + raise ValueError(f"Unknown tool: {name}") + + ctx = app.request_context + request_id = ctx.request_id + + logger.info(f"[{request_id}] Starting long-task...") + + # Progress 25% + await ctx.session.send_log_message( + level="info", + data="Progress: 25% - Starting work...", + related_request_id=request_id, + ) + logger.info(f"[{request_id}] Progress: 25%") + await anyio.sleep(1) + + # Progress 50% + await ctx.session.send_log_message( + level="info", + data="Progress: 50% - Halfway there...", + related_request_id=request_id, + ) + logger.info(f"[{request_id}] Progress: 50%") + await anyio.sleep(1) + + # Server-initiated disconnect - client will reconnect + # Use the close_sse_stream callback if available + # This is None if not on streamable HTTP transport or no event store configured + if ctx.close_sse_stream: + logger.info(f"[{request_id}] Closing SSE stream to trigger polling reconnect...") + await ctx.close_sse_stream(retry_interval=2000) # 2 seconds + + # Wait a bit for client to reconnect + await anyio.sleep(0.5) + + # Progress 75% - sent while client was disconnected, stored for replay + await ctx.session.send_log_message( + level="info", + data="Progress: 75% - Almost done (sent while disconnected)...", + related_request_id=request_id, + ) + logger.info(f"[{request_id}] Progress: 75% (client may be disconnected)") + await anyio.sleep(0.5) + + # Progress 100% + await ctx.session.send_log_message( + level="info", + data="Progress: 100% - Complete!", + related_request_id=request_id, + ) + logger.info(f"[{request_id}] Progress: 100%") + + return [types.TextContent(type="text", text="Long task completed successfully!")] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="long-task", + description=( + "A long-running task that demonstrates server-initiated SSE disconnect. " + "The server closes the connection at 50% progress, and the client " + "automatically reconnects to receive the remaining updates." + ), + inputSchema={"type": "object", "properties": {}}, + ) + ] + + # Create event store and session manager + event_store = InMemoryEventStore() + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, + # Tell clients to reconnect after 2 seconds + retry_interval=2000, + ) + + async def handle_mcp(scope: Scope, receive: Receive, send: Send) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + logger.info("SSE Polling Example Server started on http://localhost:3001/mcp") + yield + logger.info("Server shutting down...") + + return Starlette( + debug=True, + routes=[Mount("/mcp", app=handle_mcp)], + lifespan=lifespan, + ) + + +if __name__ == "__main__": + import uvicorn + + app = create_app() + uvicorn.run(app, host="127.0.0.1", port=3001) diff --git a/pyproject.toml b/pyproject.toml index 078a1dfdcb..03fe23fb4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ venv = ".venv" executionEnvironments = [ { root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false }, { root = "examples/servers", reportUnusedFunction = false }, + { root = "examples/snippets", reportUnusedFunction = false }, ] [tool.ruff] diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 077ff9af64..203a516613 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -3,7 +3,7 @@ from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server -from .shared.exceptions import McpError +from .shared.exceptions import McpError, UrlElicitationRequiredError from .types import ( CallToolRequest, ClientCapabilities, @@ -125,6 +125,7 @@ "ToolsCapability", "ToolUseContent", "UnsubscribeRequest", + "UrlElicitationRequiredError", "stdio_client", "stdio_server", ] diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 502c901c42..368bdd9df4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -23,6 +23,7 @@ from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, create_client_registration_request, create_oauth_metadata_request, extract_field_from_www_auth, @@ -33,6 +34,8 @@ handle_protected_resource_response, handle_registration_response, handle_token_response_scopes, + is_valid_client_metadata_url, + should_use_client_metadata_url, ) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( @@ -96,6 +99,7 @@ class OAuthContext: redirect_handler: Callable[[str], Awaitable[None]] | None callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None timeout: float = 300.0 + client_metadata_url: str | None = None # Discovered metadata protected_resource_metadata: ProtectedResourceMetadata | None = None @@ -226,8 +230,32 @@ def __init__( redirect_handler: Callable[[str], Awaitable[None]] | None = None, callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, + client_metadata_url: str | None = None, ): - """Initialize OAuth2 authentication.""" + """Initialize OAuth2 authentication. + + Args: + server_url: The MCP server URL. + client_metadata: OAuth client metadata for registration. + storage: Token storage implementation. + redirect_handler: Handler for authorization redirects. + callback_handler: Handler for authorization callbacks. + timeout: Timeout for the OAuth flow. + client_metadata_url: URL-based client ID. When provided and the server + advertises client_id_metadata_document_supported=true, this URL will be + used as the client_id instead of performing dynamic client registration. + Must be a valid HTTPS URL with a non-root pathname. + + Raises: + ValueError: If client_metadata_url is provided but not a valid HTTPS URL + with a non-root pathname. + """ + # Validate client_metadata_url if provided + if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url): + raise ValueError( + f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}" + ) + self.context = OAuthContext( server_url=server_url, client_metadata=client_metadata, @@ -235,6 +263,7 @@ def __init__( redirect_handler=redirect_handler, callback_handler=callback_handler, timeout=timeout, + client_metadata_url=client_metadata_url, ) self._initialized = False @@ -566,17 +595,30 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.oauth_metadata, ) - # Step 4: Register client if needed - registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - self.context.get_authorization_base_url(self.context.server_url), - ) + # Step 4: Register client or use URL-based client ID (CIMD) if not self.context.client_info: - registration_response = yield registration_request - client_information = await handle_registration_response(registration_response) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) + if should_use_client_metadata_url( + self.context.oauth_metadata, self.context.client_metadata_url + ): + # Use URL-based client ID (CIMD) + logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") + client_information = create_client_info_from_metadata_url( + self.context.client_metadata_url, # type: ignore[arg-type] + redirect_uris=self.context.client_metadata.redirect_uris, + ) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + else: + # Fallback to Dynamic Client Registration + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) + registration_response = yield registration_request + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index bbb3ff52f1..b4426be7f8 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -3,7 +3,7 @@ from urllib.parse import urljoin, urlparse from httpx import Request, Response -from pydantic import ValidationError +from pydantic import AnyUrl, ValidationError from mcp.client.auth import OAuthRegistrationError, OAuthTokenError from mcp.client.streamable_http import MCP_PROTOCOL_VERSION @@ -243,6 +243,75 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma raise OAuthRegistrationError(f"Invalid registration response: {e}") +def is_valid_client_metadata_url(url: str | None) -> bool: + """Validate that a URL is suitable for use as a client_id (CIMD). + + The URL must be HTTPS with a non-root pathname. + + Args: + url: The URL to validate + + Returns: + True if the URL is a valid HTTPS URL with a non-root pathname + """ + if not url: + return False + try: + parsed = urlparse(url) + return parsed.scheme == "https" and parsed.path not in ("", "/") + except Exception: + return False + + +def should_use_client_metadata_url( + oauth_metadata: OAuthMetadata | None, + client_metadata_url: str | None, +) -> bool: + """Determine if URL-based client ID (CIMD) should be used instead of DCR. + + URL-based client IDs should be used when: + 1. The server advertises client_id_metadata_document_supported=true + 2. The client has a valid client_metadata_url configured + + Args: + oauth_metadata: OAuth authorization server metadata + client_metadata_url: URL-based client ID (already validated) + + Returns: + True if CIMD should be used, False if DCR should be used + """ + if not client_metadata_url: + return False + + if not oauth_metadata: + return False + + return oauth_metadata.client_id_metadata_document_supported is True + + +def create_client_info_from_metadata_url( + client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None +) -> OAuthClientInformationFull: + """Create client information using a URL-based client ID (CIMD). + + When using URL-based client IDs, the URL itself becomes the client_id + and no client_secret is used (token_endpoint_auth_method="none"). + + Args: + client_metadata_url: The URL to use as the client_id + redirect_uris: The redirect URIs from the client metadata (passed through for + compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata) + + Returns: + OAuthClientInformationFull with the URL as client_id + """ + return OAuthClientInformationFull( + client_id=client_metadata_url, + token_endpoint_auth_method="none", + redirect_uris=redirect_uris, + ) + + async def handle_token_response_scopes( response: Response, ) -> OAuthToken: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3817ca6b5d..be47d681fb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -138,7 +138,12 @@ def __init__( async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None elicitation = ( - types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None + types.ElicitationCapability( + form=types.FormElicitationCapability(), + url=types.UrlElicitationCapability(), + ) + if self._elicitation_callback is not _default_elicitation_callback + else None ) roots = ( # TODO: Should this be based on whether we @@ -552,5 +557,10 @@ async def _received_notification(self, notification: types.ServerNotification) - match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) + case types.ElicitCompleteNotification(params=params): + # Handle elicitation completion notification + # Clients MAY use this to retry requests or update UI + # The notification contains the elicitationId of the completed elicitation + pass case _: pass diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 03b65b0a57..fdccaf60fe 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -58,6 +58,27 @@ class ResumptionError(StreamableHTTPError): """Raised when resumption request is invalid.""" +@dataclass +class StreamableHTTPReconnectionOptions: + """Configuration options for reconnection behavior of StreamableHTTPTransport. + + Attributes: + initial_reconnection_delay: Initial backoff time in seconds. Default is 1.0. + max_reconnection_delay: Maximum backoff time in seconds. Default is 30.0. + reconnection_delay_grow_factor: Factor by which delay increases. Default is 1.5. + max_retries: Maximum reconnection attempts. Default is 2. + """ + + initial_reconnection_delay: float = 1.0 + max_reconnection_delay: float = 30.0 + reconnection_delay_grow_factor: float = 1.5 + max_retries: int = 2 + + def __post_init__(self) -> None: + if self.initial_reconnection_delay > self.max_reconnection_delay: + raise ValueError("initial_reconnection_delay cannot exceed max_reconnection_delay") + + @dataclass class RequestContext: """Context for a request operation.""" @@ -81,16 +102,9 @@ def __init__( timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, auth: httpx.Auth | None = None, + reconnection_options: StreamableHTTPReconnectionOptions | None = None, ) -> None: - """Initialize the StreamableHTTP transport. - - Args: - url: The endpoint URL. - headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. - auth: Optional HTTPX authentication handler. - """ + """Initialize the StreamableHTTP transport.""" self.url = url self.headers = headers or {} self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout @@ -100,6 +114,8 @@ def __init__( self.auth = auth self.session_id = None self.protocol_version = None + self.reconnection_options = reconnection_options or StreamableHTTPReconnectionOptions() + self._server_retry_seconds: float | None = None # Server-provided retry delay self.request_headers = { ACCEPT: f"{JSON}, {SSE}", CONTENT_TYPE: JSON, @@ -150,6 +166,24 @@ def _maybe_extract_protocol_version_from_message( ) # pragma: no cover logger.warning(f"Raw result: {message.root.result}") + def _get_next_reconnection_delay(self, attempt: int) -> float: + """Calculate the next reconnection delay using exponential backoff. + + Args: + attempt: Current reconnection attempt count + + Returns: + Time to wait in seconds before next reconnection attempt + """ + # Use server-provided retry value if available + if self._server_retry_seconds is not None: + return self._server_retry_seconds + + # Fall back to exponential backoff + opts = self.reconnection_options + delay = opts.initial_reconnection_delay * (opts.reconnection_delay_grow_factor**attempt) + return min(delay, opts.max_reconnection_delay) + async def _handle_sse_event( self, sse: ServerSentEvent, @@ -157,9 +191,29 @@ async def _handle_sse_event( original_request_id: RequestId | None = None, resumption_callback: Callable[[str], Awaitable[None]] | None = None, is_initialization: bool = False, - ) -> bool: - """Handle an SSE event, returning True if the response is complete.""" + ) -> tuple[bool, bool]: + """Handle an SSE event. + + Returns: + Tuple of (is_complete, has_event_id) where: + - is_complete: True if the response stream is complete (got response/error) + - has_event_id: True if this event had an ID (indicating resumability) + """ + event_id = sse.id # httpx_sse defaults to "" for missing ID + has_event_id = bool(event_id) # True if non-empty string + + # Capture server-provided retry value for reconnection timing + if sse.retry is not None: # pragma: no cover + self._server_retry_seconds = sse.retry / 1000.0 # Convert ms to seconds + if sse.event == "message": + # Check for priming event (empty data but may have ID for resumption) + if not sse.data or not sse.data.strip(): + # Priming event - just track the ID for resumption + if has_event_id and resumption_callback: + await resumption_callback(event_id) + return False, has_event_id + try: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"SSE message: {message}") @@ -176,20 +230,24 @@ async def _handle_sse_event( await read_stream_writer.send(session_message) # Call resumption token callback if we have an ID - if sse.id and resumption_callback: - await resumption_callback(sse.id) + if has_event_id and resumption_callback: + await resumption_callback(event_id) # If this is a response or error return True indicating completion # Otherwise, return False to continue listening - return isinstance(message.root, JSONRPCResponse | JSONRPCError) + return isinstance(message.root, JSONRPCResponse | JSONRPCError), has_event_id except Exception as exc: # pragma: no cover logger.exception("Error parsing SSE message") await read_stream_writer.send(exc) - return False + return False, has_event_id else: # pragma: no cover - logger.warning(f"Unknown SSE event: {sse.event}") - return False + # Empty event or priming event - not a completion, but may have ID + # httpx_sse defaults event to "message", so this handles non-standard events + if has_event_id and resumption_callback: + # Priming event - call resumption callback + await resumption_callback(event_id) + return False, has_event_id async def handle_get_stream( self, @@ -214,7 +272,7 @@ async def handle_get_stream( logger.debug("GET SSE connection established") async for sse in event_source.aiter_sse(): - await self._handle_sse_event(sse, read_stream_writer) + _is_complete, _has_event_id = await self._handle_sse_event(sse, read_stream_writer) except Exception as exc: logger.debug(f"GET stream error (non-fatal): {exc}") # pragma: no cover @@ -243,7 +301,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: logger.debug("Resumption GET SSE connection established") async for sse in event_source.aiter_sse(): # pragma: no branch - is_complete = await self._handle_sse_event( + is_complete, _has_event_id = await self._handle_sse_event( sse, ctx.read_stream_writer, original_request_id, @@ -321,25 +379,98 @@ async def _handle_sse_response( response: httpx.Response, ctx: RequestContext, is_initialization: bool = False, - ) -> None: - """Handle SSE response from the server.""" + attempt: int = 0, + ) -> tuple[bool, str | None]: + """Handle SSE response from the server with automatic reconnection. + + Returns: + Tuple of (has_priming_event, last_event_id) where: + - has_priming_event: True if any event had an ID (priming event received) + - last_event_id: The last event ID received, for resumption + """ + has_priming_event = False + last_event_id: str | None = None + is_complete = False + try: event_source = EventSource(response) async for sse in event_source.aiter_sse(): # pragma: no branch - is_complete = await self._handle_sse_event( + is_complete, has_event_id = await self._handle_sse_event( sse, ctx.read_stream_writer, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), is_initialization=is_initialization, ) - # If the SSE event indicates completion, like returning respose/error + + # Track priming events + if has_event_id: + has_priming_event = True + last_event_id = sse.id + + # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: await response.aclose() break + except Exception as e: # pragma: no cover + logger.exception("Error reading SSE stream:") + # Don't send exception if we can reconnect + if not (has_priming_event and last_event_id): + await ctx.read_stream_writer.send(e) + + # Auto-reconnect if stream ended without completion and we have priming event + if not is_complete and has_priming_event and last_event_id: # pragma: no cover + await self._attempt_sse_reconnection(ctx, last_event_id, attempt) + + return has_priming_event, last_event_id + + async def _attempt_sse_reconnection( # pragma: no cover + self, + ctx: RequestContext, + last_event_id: str, + attempt: int, + ) -> None: + """Attempt to reconnect to SSE stream using resumption token. + + Called when SSE stream ends without receiving a response/error, + but we have a priming event indicating resumability. + """ + max_retries = self.reconnection_options.max_retries + + if attempt >= max_retries: + error_msg = f"Max reconnection attempts ({max_retries}) exceeded" + logger.error(error_msg) + await ctx.read_stream_writer.send(StreamableHTTPError(error_msg)) + return + + # Calculate delay (uses server retry if available, else exponential backoff) + delay = self._get_next_reconnection_delay(attempt) + logger.info(f"SSE stream closed, reconnecting in {delay:.1f}s (attempt {attempt + 1}/{max_retries})") + + await anyio.sleep(delay) + + # Build resumption context with last_event_id + resumption_metadata = ClientMessageMetadata( + resumption_token=last_event_id, + on_resumption_token_update=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + ) + + resumption_ctx = RequestContext( + client=ctx.client, + headers=ctx.headers, + session_id=ctx.session_id, + session_message=ctx.session_message, + metadata=resumption_metadata, + read_stream_writer=ctx.read_stream_writer, + sse_read_timeout=ctx.sse_read_timeout, + ) + + try: + await self._handle_resumption_request(resumption_ctx) except Exception as e: - logger.exception("Error reading SSE stream:") # pragma: no cover - await ctx.read_stream_writer.send(e) # pragma: no cover + logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}") + # Recursive retry with incremented attempt counter + await self._attempt_sse_reconnection(ctx, last_event_id, attempt + 1) async def _handle_unexpected_content_type( self, @@ -442,6 +573,66 @@ def get_session_id(self) -> str | None: """Get the current session ID.""" return self.session_id + async def resume_stream( + self, + client: httpx.AsyncClient, + read_stream_writer: StreamWriter, + last_event_id: str, + on_resumption_token: Callable[[str], Awaitable[None]] | None = None, + ) -> None: + """Resume SSE stream from a previous event ID. + + This method allows clients to reconnect and resume receiving events + from where they left off using the Last-Event-ID header. + + Args: + client: The HTTP client to use for the request + read_stream_writer: Stream writer for sending received messages + last_event_id: The last event ID received, to resume from + on_resumption_token: Optional callback invoked with new event IDs + """ + if not self.session_id: + logger.warning("Cannot resume stream without a session ID") + return + + headers = self._prepare_request_headers(self.request_headers) + headers[LAST_EVENT_ID] = last_event_id + + try: + async with aconnect_sse( + client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug(f"Resumed SSE stream from event ID: {last_event_id}") # pragma: no cover + + async for sse in event_source.aiter_sse(): # pragma: no cover + _is_complete, has_event_id = await self._handle_sse_event( + sse, + read_stream_writer, + resumption_callback=on_resumption_token, + ) + + # Call resumption callback if we have a new event ID + if has_event_id and sse.id and on_resumption_token: + await on_resumption_token(sse.id) + + except httpx.HTTPStatusError as exc: + # Read response body so consumers can access error details + try: + await exc.response.aread() + except Exception: + pass # Best effort - don't fail if we can't read body + if exc.response.status_code == 405: + logger.debug("Server does not support SSE resumption via GET") # pragma: no cover + else: + logger.warning(f"Failed to resume stream: {exc}") + except Exception as exc: # pragma: no cover + logger.debug(f"Resume stream error: {exc}") + @asynccontextmanager async def streamablehttp_client( @@ -452,6 +643,7 @@ async def streamablehttp_client( terminate_on_close: bool = True, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + reconnection_options: StreamableHTTPReconnectionOptions | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -465,14 +657,8 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. - - Yields: - Tuple containing: - - read_stream: Stream for reading messages from the server - - write_stream: Stream for sending messages to the server - - get_session_id_callback: Function to retrieve the current session ID """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth, reconnection_options) read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index c2d98de384..49195415bf 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -36,6 +36,15 @@ class CancelledElicitation(BaseModel): ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation +class AcceptedUrlElicitation(BaseModel): + """Result when user accepts a URL mode elicitation.""" + + action: Literal["accept"] = "accept" + + +UrlElicitationResult = AcceptedUrlElicitation | DeclinedElicitation | CancelledElicitation + + # Primitive types allowed in elicitation schemas _ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) @@ -99,20 +108,22 @@ async def elicit_with_validation( schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user with schema validation. + """Elicit information from the client/user with schema validation (form mode). This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the user and collect a response according to the provided schema. Or in case a client is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. + + For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ # Validate that schema only contains primitive types and fail loudly if not _validate_elicitation_schema(schema) json_schema = schema.model_json_schema() - result = await session.elicit( + result = await session.elicit_form( message=message, requestedSchema=json_schema, related_request_id=related_request_id, @@ -129,3 +140,51 @@ async def elicit_with_validation( else: # pragma: no cover # This should never happen, but handle it just in case raise ValueError(f"Unexpected elicitation action: {result.action}") + + +async def elicit_url( + session: ServerSession, + message: str, + url: str, + elicitation_id: str, + related_request_id: RequestId | None = None, +) -> UrlElicitationResult: + """Elicit information from the user via out-of-band URL navigation (URL mode). + + This method directs the user to an external URL where sensitive interactions can + occur without passing data through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + the server should send an ElicitCompleteNotification to notify the client. + + Args: + session: The server session + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + result = await session.elicit_url( + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=related_request_id, + ) + + if result.action == "accept": + return AcceptedUrlElicitation() + elif result.action == "decline": + return DeclinedElicitation() + elif result.action == "cancel": + return CancelledElicitation() + else: # pragma: no cover + # This should never happen, but handle it just in case + raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e72..9a438e9ac2 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -42,8 +42,12 @@ from mcp.server.elicitation import ( ElicitationResult, ElicitSchemaModelT, + UrlElicitationResult, elicit_with_validation, ) +from mcp.server.elicitation import ( + elicit_url as _elicit_url, +) from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -60,7 +64,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import LifespanContextT, RequestContext, RequestT +from mcp.shared.context import CloseSSEStreamCallback, LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument @@ -102,6 +106,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): json_response: bool stateless_http: bool """Define if the server should create a new transport per request.""" + sse_retry_interval: int | None + """SSE retry interval in milliseconds sent in priming event for client reconnection.""" # resource settings warn_on_duplicate_resources: bool @@ -161,6 +167,7 @@ def __init__( # noqa: PLR0913 streamable_http_path: str = "/mcp", json_response: bool = False, stateless_http: bool = False, + sse_retry_interval: int | None = None, warn_on_duplicate_resources: bool = True, warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, @@ -180,6 +187,7 @@ def __init__( # noqa: PLR0913 streamable_http_path=streamable_http_path, json_response=json_response, stateless_http=stateless_http, + sse_retry_interval=sse_retry_interval, warn_on_duplicate_resources=warn_on_duplicate_resources, warn_on_duplicate_tools=warn_on_duplicate_tools, warn_on_duplicate_prompts=warn_on_duplicate_prompts, @@ -697,6 +705,10 @@ def custom_route( The handler function must be an async function that accepts a Starlette Request and returns a Response. + Routes using this decorator will not require authorization. It is intended + for uses that are either a part of authorization flows or intended to be + public such as health check endpoints. + Args: path: URL path for the route (e.g., "/oauth/callback") methods: List of HTTP methods to support (e.g., ["GET", "POST"]) @@ -935,6 +947,7 @@ def streamable_http_app(self) -> Starlette: json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting security_settings=self.settings.transport_security, + retry_interval=self.settings.sse_retry_interval, ) # Create the ASGI handler @@ -1200,6 +1213,41 @@ async def elicit( related_request_id=self.request_id, ) + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> UrlElicitationResult: + """Request URL mode elicitation from the client. + + This directs the user to an external URL for out-of-band interactions + that must not pass through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + call `self.session.send_elicit_complete(elicitation_id)` to notify the client. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + return await _elicit_url( + session=self.request_context.session, + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=self.request_id, + ) + async def log( self, level: Literal["debug", "info", "warning", "error"], @@ -1239,6 +1287,23 @@ def session(self): """Access to the underlying session for advanced usage.""" return self.request_context.session + @property + def close_sse_stream(self) -> CloseSSEStreamCallback | None: + """Callback to close SSE stream for polling behavior (SEP-1699). + + This allows tools to trigger server-initiated SSE disconnect during + long-running operations, enabling client reconnection with polling. + + Returns None if: + - Not running on streamable HTTP transport + - No event store configured (events would be lost) + + Usage: + if ctx.close_sse_stream: + await ctx.close_sse_stream(retry_interval=3000) # Reconnect after 3s + """ + return self.request_context.close_sse_stream + # Convenience methods for common log levels async def debug(self, message: str, **extra: Any) -> None: """Send a debug log message.""" diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a0617036f9..9f9178043e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -680,12 +680,14 @@ async def _handle_request( token = None try: - # Extract request context from message metadata + # Extract request context and callback from message metadata request_data = None + close_sse_stream_callback = None if message.message_metadata is not None and isinstance( message.message_metadata, ServerMessageMetadata ): # pragma: no cover request_data = message.message_metadata.request_context + close_sse_stream_callback = message.message_metadata.close_sse_stream # Set our global state that can be retrieved via # app.get_request_context() @@ -696,6 +698,7 @@ async def _handle_request( session, lifespan_context, request=request_data, + close_sse_stream=close_sse_stream_callback, ) ) response = await handler(req) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 677ffef89f..b116fbe384 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -335,19 +335,42 @@ async def elicit( requestedSchema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: - """Send an elicitation/create request. + """Send a form mode elicitation/create request. Args: message: The message to present to the user requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation Returns: The client's response + + Note: + This method is deprecated in favor of elicit_form(). It remains for + backward compatibility but new code should use elicit_form(). + """ + return await self.elicit_form(message, requestedSchema, related_request_id) + + async def elicit_form( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response with form data """ return await self.send_request( types.ServerRequest( types.ElicitRequest( - params=types.ElicitRequestParams( + params=types.ElicitRequestFormParams( message=message, requestedSchema=requestedSchema, ), @@ -357,6 +380,41 @@ async def elicit( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request. + + This directs the user to an external URL for out-of-band interactions + like OAuth flows, credential collection, or payment processing. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response indicating acceptance, decline, or cancellation + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + params=types.ElicitRequestURLParams( + message=message, + url=url, + elicitationId=elicitation_id, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + async def send_ping(self) -> types.EmptyResult: # pragma: no cover """Send a ping request.""" return await self.send_request( @@ -399,6 +457,30 @@ async def send_prompt_list_changed(self) -> None: # pragma: no cover """Send a prompt list changed notification.""" await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + async def send_elicit_complete( + self, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send an elicitation completion notification. + + This should be sent when a URL mode elicitation has been completed + out-of-band to inform the client that it may retry any requests + that were waiting for this elicitation. + + Args: + elicitation_id: The unique identifier of the completed elicitation + related_request_id: Optional ID of the request that triggered this + """ + await self.send_notification( + types.ServerNotification( + types.ElicitCompleteNotification( + params=types.ElicitCompleteNotificationParams(elicitationId=elicitation_id) + ) + ), + related_request_id, + ) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index d6ccfd5a82..45851ebf65 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -28,6 +28,7 @@ TransportSecurityMiddleware, TransportSecuritySettings, ) +from mcp.shared.context import CloseSSEStreamCallback from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -140,6 +141,7 @@ def __init__( is_json_response_enabled: bool = False, event_store: EventStore | None = None, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -153,6 +155,10 @@ def __init__( resumability will be enabled, allowing clients to reconnect and resume messages. security_settings: Optional security settings for DNS rebinding protection. + retry_interval: Optional SSE retry interval in milliseconds. + When set, this value is sent to clients in the SSE + `retry` field, telling them how long to wait before + reconnecting after a disconnect. Raises: ValueError: If the session ID contains invalid characters. @@ -164,6 +170,7 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._security = TransportSecurityMiddleware(security_settings) + self._retry_interval = retry_interval self._request_streams: dict[ RequestId, tuple[ @@ -233,9 +240,9 @@ def _get_session_id(self, request: Request) -> str | None: # pragma: no cover """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover + def _create_event_data(self, event_message: EventMessage) -> dict[str, str | int]: # pragma: no cover """Create event data dictionary from an EventMessage.""" - event_data = { + event_data: dict[str, str | int] = { "event": "message", "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), } @@ -246,6 +253,57 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # return event_data + async def _create_priming_event(self, stream_id: str) -> dict[str, str | int] | None: + """Create a priming event to establish resumption capability. + + Only sends if eventStore is configured (opt-in for resumability). + + Args: + stream_id: The ID of the stream to create the priming event for + + Returns: + Event data dictionary for the priming event, or None if no event store + """ + if self._event_store is None: + return None + + # Store an empty message to get an event ID + # Using an empty dict as a placeholder - it won't be sent as actual data + priming_event_id = await self._event_store.store_event( + stream_id, JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "_priming"}) + ) + + event_data: dict[str, str | int] = { + "id": priming_event_id, + "data": "", # Empty data for priming event + } + + # Add retry interval if configured (sse_starlette expects int, not str) + if self._retry_interval is not None: + event_data["retry"] = self._retry_interval + + return event_data + + def _create_close_sse_stream_callback(self, request_id: RequestId) -> CloseSSEStreamCallback | None: + """Create a bound callback for closing SSE streams. + + Args: + request_id: The request ID to bind to the callback + + Returns: + A callback that closes the SSE stream for this request, + or None if no event store is configured (events would be lost). + """ + # Only provide callback if event store is configured + # Without an event store, closing the stream would lose events + if self._event_store is None: + return None + + async def callback(retry_interval: int | None = None) -> bool: + return await self.close_sse_stream(request_id, retry_interval) + + return callback + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover """Clean up memory streams for a given request ID.""" if request_id in self._request_streams: @@ -457,12 +515,17 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await self._clean_up_memory_streams(request_id) else: # pragma: no cover # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str | int]](0) async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: + # Send priming event first if event store is configured + priming_event = await self._create_priming_event(request_id) + if priming_event is not None: + await sse_stream_writer.send(priming_event) + # Process messages from the request-specific stream async for event_message in request_stream_reader: # Build the event data @@ -502,7 +565,12 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - metadata = ServerMessageMetadata(request_context=request) + # Create callback for closing SSE stream (only if event store configured) + close_callback = self._create_close_sse_stream_callback(request_id) + metadata = ServerMessageMetadata( + request_context=request, + close_sse_stream=close_callback, + ) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) except Exception: @@ -573,7 +641,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr return # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str | int]](0) async def standalone_sse_writer(): try: @@ -583,6 +651,11 @@ async def standalone_sse_writer(): standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: + # Send priming event first if event store is configured + priming_event = await self._create_priming_event(GET_STREAM_KEY) + if priming_event is not None: + await sse_stream_writer.send(priming_event) + # Process messages from the standalone stream async for event_message in standalone_stream_reader: # For the standalone stream, we handle: @@ -669,6 +742,36 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") + async def close_sse_stream(self, request_id: RequestId, retry_interval: int | None = None) -> bool: + """Close an SSE stream for a specific request, triggering client reconnection. + + Use this to implement polling behavior during long-running operations - + client will reconnect after the retry interval specified in the priming event. + + Args: + request_id: The request ID (or stream key) of the stream to close + retry_interval: Optional retry interval in ms to send before closing. + If provided, overrides the transport's default retry interval. + + Returns: + True if the stream was found and closed, False otherwise. + """ + request_id_str = str(request_id) + if request_id_str not in self._request_streams: + return False + + try: + sender, receiver = self._request_streams[request_id_str] + await sender.aclose() + await receiver.aclose() + return True + except Exception: # pragma: no cover + # Stream might already be closed + logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed") + return False + finally: + self._request_streams.pop(request_id_str, None) + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover """Handle unsupported HTTP methods.""" headers = { @@ -763,7 +866,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Create SSE stream for replay - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str | int]](0) async def replay_sender(): try: diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 04c7de2d7b..47feb69efd 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -60,12 +60,14 @@ def __init__( json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings + self.retry_interval = retry_interval # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -226,6 +228,7 @@ async def _handle_stateful_request( is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) security_settings=self.security_settings, + retry_interval=self.retry_interval, ) assert http_transport.mcp_session_id is not None @@ -277,3 +280,25 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE status_code=HTTPStatus.BAD_REQUEST, ) await response(scope, receive, send) + + async def close_sse_stream( # pragma: no cover + self, session_id: str, request_id: str | int, retry_interval: int | None = None + ) -> bool: + """Close an SSE stream for a specific request, triggering client reconnection. + + Use this to implement polling behavior during long-running operations. + The client will reconnect after the retry interval specified in the priming event. + + Args: + session_id: The MCP session ID (from mcp-session-id header) + request_id: The request ID of the stream to close + retry_interval: Optional retry interval in ms to send before closing. + If provided, overrides the transport's default retry interval. + + Returns: + True if the stream was found and closed, False otherwise + """ + if session_id not in self._server_instances: + return False + transport = self._server_instances[session_id] + return await transport.close_sse_stream(request_id, retry_interval) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5f..4638564eb2 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol from typing_extensions import TypeVar @@ -11,6 +11,20 @@ RequestT = TypeVar("RequestT", default=Any) +class CloseSSEStreamCallback(Protocol): # pragma: no cover + """Callback to close SSE stream for polling behavior (SEP-1699). + + Args: + retry_interval: Optional retry interval in ms to send before closing. + If None, uses the transport's default retry interval. + + Returns: + True if the stream was found and closed, False otherwise. + """ + + async def __call__(self, retry_interval: int | None = None) -> bool: ... + + @dataclass class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId @@ -18,3 +32,6 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): session: SessionT lifespan_context: LifespanContextT request: RequestT | None = None + # Callback to close SSE stream for polling behavior (SEP-1699) + # None if not on streamable HTTP transport or no event store configured + close_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 97a1c09a9f..4943114912 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -1,4 +1,8 @@ -from mcp.types import ErrorData +from __future__ import annotations + +from typing import Any, cast + +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData class McpError(Exception): @@ -12,3 +16,56 @@ def __init__(self, error: ErrorData): """Initialize McpError.""" super().__init__(error.message) self.error = error + + +class UrlElicitationRequiredError(McpError): + """ + Specialized error for when a tool requires URL mode elicitation(s) before proceeding. + + Servers can raise this error from tool handlers to indicate that the client + must complete one or more URL elicitations before the request can be processed. + + Example: + raise UrlElicitationRequiredError([ + ElicitRequestURLParams( + mode="url", + message="Authorization required for your files", + url="https://example.com/oauth/authorize", + elicitationId="auth-001" + ) + ]) + """ + + def __init__( + self, + elicitations: list[ElicitRequestURLParams], + message: str | None = None, + ): + """Initialize UrlElicitationRequiredError.""" + if message is None: + message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required" + + self._elicitations = elicitations + + error = ErrorData( + code=URL_ELICITATION_REQUIRED, + message=message, + data={"elicitations": [e.model_dump(by_alias=True, exclude_none=True) for e in elicitations]}, + ) + super().__init__(error) + + @property + def elicitations(self) -> list[ElicitRequestURLParams]: + """The list of URL elicitations required before the request can proceed.""" + return self._elicitations + + @classmethod + def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError: + """Reconstruct from an ErrorData received over the wire.""" + if error.code != URL_ELICITATION_REQUIRED: + raise ValueError(f"Expected error code {URL_ELICITATION_REQUIRED}, got {error.code}") + + data = cast(dict[str, Any], error.data or {}) + raw_elicitations = cast(list[dict[str, Any]], data.get("elicitations", [])) + elicitations = [ElicitRequestURLParams.model_validate(e) for e in raw_elicitations] + return cls(elicitations, error.message) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 4b6df23eb6..866228daa6 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -7,9 +7,13 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import TYPE_CHECKING from mcp.types import JSONRPCMessage, RequestId +if TYPE_CHECKING: + from mcp.shared.context import CloseSSEStreamCallback + ResumptionToken = str ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] @@ -30,6 +34,9 @@ class ServerMessageMetadata: related_request_id: RequestId | None = None # Request-specific context (e.g., headers, auth info) request_context: object | None = None + # Callback to close SSE stream for polling behavior (SEP-1699) + # None if not on streamable HTTP transport or no event store configured + close_sse_stream: "CloseSSEStreamCallback | None" = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/src/mcp/shared/tool_name_validation.py b/src/mcp/shared/tool_name_validation.py index 188d5fb146..f35efa5a61 100644 --- a/src/mcp/shared/tool_name_validation.py +++ b/src/mcp/shared/tool_name_validation.py @@ -6,7 +6,7 @@ digits (0-9), underscore (_), dash (-), and dot (.). Tool names SHOULD NOT contain spaces, commas, or other special characters. -See: https://modelcontextprotocol.io/specification/draft/server/tools#tool-names +See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names """ from __future__ import annotations @@ -21,7 +21,7 @@ TOOL_NAME_REGEX = re.compile(r"^[A-Za-z0-9._-]{1,128}$") # SEP reference URL for warning messages -SEP_986_URL = "https://github.com/modelcontextprotocol/modelcontextprotocol/issues/986" +SEP_986_URL = "https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names" @dataclass diff --git a/src/mcp/types.py b/src/mcp/types.py index 8955de694e..dd9775f8c8 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -146,6 +146,10 @@ class JSONRPCResponse(BaseModel): model_config = ConfigDict(extra="allow") +# MCP-specific error codes in the range [-32000, -32099] +URL_ELICITATION_REQUIRED = -32042 +"""Error code indicating that a URL mode elicitation is required before the request can be processed.""" + # SDK error codes CONNECTION_CLOSED = -32000 # REQUEST_TIMEOUT = -32001 # the typescript sdk uses this @@ -272,8 +276,29 @@ class SamplingToolsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class FormElicitationCapability(BaseModel): + """Capability for form mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + +class UrlElicitationCapability(BaseModel): + """Capability for URL mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + class ElicitationCapability(BaseModel): - """Capability for elicitation operations.""" + """Capability for elicitation operations. + + Clients must support at least one mode (form or url). + """ + + form: FormElicitationCapability | None = None + """Present if the client supports form mode elicitation.""" + + url: UrlElicitationCapability | None = None + """Present if the client supports URL mode elicitation.""" model_config = ConfigDict(extra="allow") @@ -1411,6 +1436,32 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n params: CancelledNotificationParams +class ElicitCompleteNotificationParams(NotificationParams): + """Parameters for elicitation completion notifications.""" + + elicitationId: str + """The unique identifier of the elicitation that was completed.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitCompleteNotification( + Notification[ElicitCompleteNotificationParams, Literal["notifications/elicitation/complete"]] +): + """ + A notification from the server to the client, informing it that a URL mode + elicitation has been completed. + + Clients MAY use the notification to automatically retry requests that received a + URLElicitationRequiredError, update the user interface, or otherwise continue + an interaction. However, because delivery of the notification is not guaranteed, + clients must not wait indefinitely for a notification from the server. + """ + + method: Literal["notifications/elicitation/complete"] = "notifications/elicitation/complete" + params: ElicitCompleteNotificationParams + + class ClientRequest( RootModel[ PingRequest @@ -1442,14 +1493,58 @@ class ClientNotification( """Schema for elicitation requests.""" -class ElicitRequestParams(RequestParams): - """Parameters for elicitation requests.""" +class ElicitRequestFormParams(RequestParams): + """Parameters for form mode elicitation requests. + + Form mode collects non-sensitive information from the user via an in-band form + rendered by the client. + """ + + mode: Literal["form"] = "form" + """The elicitation mode (always "form" for this type).""" message: str + """The message to present to the user describing what information is being requested.""" + requestedSchema: ElicitRequestedSchema + """ + A restricted subset of JSON Schema defining the structure of expected response. + Only top-level properties are allowed, without nesting. + """ + model_config = ConfigDict(extra="allow") +class ElicitRequestURLParams(RequestParams): + """Parameters for URL mode elicitation requests. + + URL mode directs users to external URLs for sensitive out-of-band interactions + like OAuth flows, credential collection, or payment processing. + """ + + mode: Literal["url"] = "url" + """The elicitation mode (always "url" for this type).""" + + message: str + """The message to present to the user explaining why the interaction is needed.""" + + url: str + """The URL that the user should navigate to.""" + + elicitationId: str + """ + The ID of the elicitation, which must be unique within the context of the server. + The client MUST treat this ID as an opaque value. + """ + + model_config = ConfigDict(extra="allow") + + +# Union type for elicitation request parameters +ElicitRequestParams: TypeAlias = ElicitRequestURLParams | ElicitRequestFormParams +"""Parameters for elicitation requests - either form or URL mode.""" + + class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): """A request from the server to elicit information from the client.""" @@ -1463,18 +1558,33 @@ class ElicitResult(Result): action: Literal["accept", "decline", "cancel"] """ The user action in response to the elicitation. - - "accept": User submitted the form/confirmed the action + - "accept": User submitted the form/confirmed the action (or consented to URL navigation) - "decline": User explicitly declined the action - "cancel": User dismissed without making an explicit choice """ content: dict[str, str | int | float | bool | list[str] | None] | None = None """ - The submitted form data, only present when action is "accept". - Contains values matching the requested schema. + The submitted form data, only present when action is "accept" in form mode. + Contains values matching the requested schema. Values can be strings, integers, + booleans, or arrays of strings. + For URL mode, this field is omitted. """ +class ElicitationRequiredErrorData(BaseModel): + """Error data for URLElicitationRequiredError. + + Servers return this when a request cannot be processed until one or more + URL mode elicitations are completed. + """ + + elicitations: list[ElicitRequestURLParams] + """List of URL mode elicitations that must be completed.""" + + model_config = ConfigDict(extra="allow") + + class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): pass @@ -1492,6 +1602,7 @@ class ServerNotification( | ResourceListChangedNotification | ToolListChangedNotification | PromptListChangedNotification + | ElicitCompleteNotification ] ): pass diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index d032bdcd6e..609be9873a 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -17,12 +17,16 @@ from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, + create_client_registration_request, create_oauth_metadata_request, extract_field_from_www_auth, extract_resource_metadata_from_www_auth, extract_scope_from_www_auth, get_client_metadata_scopes, handle_registration_response, + is_valid_client_metadata_url, + should_use_client_metadata_url, ) from mcp.shared.auth import ( OAuthClientInformationFull, @@ -945,6 +949,49 @@ def text(self): assert "Registration failed: 400" in str(exc_info.value) +class TestCreateClientRegistrationRequest: + """Test client registration request creation.""" + + def test_uses_registration_endpoint_from_metadata(self): + """Test that registration URL comes from metadata when available.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_to_default_register_endpoint_when_no_metadata(self): + """Test that registration uses fallback URL when auth_server_metadata is None.""" + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(None, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_when_metadata_has_no_registration_endpoint(self): + """Test fallback when metadata exists but lacks registration_endpoint.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + # No registration_endpoint + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + class TestAuthFlow: """Test the auth flow in httpx.""" @@ -1783,3 +1830,296 @@ def test_extract_field_from_www_auth_invalid_cases( result = extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}" + + +class TestCIMD: + """Test Client ID Metadata Document (CIMD) support.""" + + @pytest.mark.parametrize( + "url,expected", + [ + # Valid CIMD URLs + ("https://example.com/client", True), + ("https://example.com/client-metadata.json", True), + ("https://example.com/path/to/client", True), + ("https://example.com:8443/client", True), + # Invalid URLs - HTTP (not HTTPS) + ("http://example.com/client", False), + # Invalid URLs - root path + ("https://example.com", False), + ("https://example.com/", False), + # Invalid URLs - None or empty + (None, False), + ("", False), + # Invalid URLs - malformed (triggers urlparse exception) + ("http://[::1/foo/", False), + ], + ) + def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): + """Test CIMD URL validation.""" + assert is_valid_client_metadata_url(url) == expected + + def test_should_use_client_metadata_url_when_server_supports(self): + """Test that CIMD is used when server supports it and URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True + + def test_should_not_use_client_metadata_url_when_server_does_not_support(self): + """Test that CIMD is not used when server doesn't support it.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=False, + ) + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False + + def test_should_not_use_client_metadata_url_when_not_provided(self): + """Test that CIMD is not used when no URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(oauth_metadata, None) is False + + def test_should_not_use_client_metadata_url_when_no_metadata(self): + """Test that CIMD is not used when OAuth metadata is None.""" + assert should_use_client_metadata_url(None, "https://example.com/client") is False + + def test_create_client_info_from_metadata_url(self): + """Test creating client info from CIMD URL.""" + client_info = create_client_info_from_metadata_url( + "https://example.com/client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + assert client_info.client_id == "https://example.com/client" + assert client_info.token_endpoint_auth_method == "none" + assert client_info.redirect_uris == [AnyUrl("http://localhost:3030/callback")] + assert client_info.client_secret is None + + def test_oauth_provider_with_valid_client_metadata_url( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider initialization with valid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + assert provider.context.client_metadata_url == "https://example.com/client" + + def test_oauth_provider_with_invalid_client_metadata_url_raises_error( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider raises error for invalid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + with pytest.raises(ValueError) as exc_info: + OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="http://example.com/client", # HTTP instead of HTTPS + ) + assert "HTTPS URL with a non-root pathname" in str(exc_info.value) + + @pytest.mark.anyio + async def test_auth_flow_uses_cimd_when_server_supports( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow uses CIMD URL as client_id when server supports it.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + + # OAuth metadata discovery + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"client_id_metadata_document_supported": true}' + ), + request=oauth_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Should skip DCR and go directly to token exchange + token_request = await auth_flow.asend(oauth_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + + # Verify client_id is the CIMD URL + content = token_request.content.decode() + assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content + + # Verify client info was set correctly + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "https://example.com/client" + assert provider.context.client_info.token_endpoint_auth_method == "none" + + # Complete the flow + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer test_token" + + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow falls back to DCR when server doesn't support CIMD.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + + # OAuth metadata discovery - server does NOT support CIMD + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_request, + ) + + # Should proceed to DCR instead of skipping it + registration_request = await auth_flow.asend(oauth_response) + assert registration_request.method == "POST" + assert str(registration_request.url) == "https://auth.example.com/register" + + # Complete the flow to avoid generator cleanup issues + registration_response = httpx.Response( + 201, + content=b'{"client_id": "dcr_client_id", "redirect_uris": ["http://localhost:3030/callback"]}', + request=registration_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + token_request = await auth_flow.asend(registration_response) + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 359fea6197..597b291785 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,6 +7,7 @@ import pytest from pydantic import BaseModel, Field +from mcp import types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession @@ -288,6 +289,7 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: # First verify that defaults are present in the JSON schema sent to clients async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): # Verify the schema includes defaults + assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requestedSchema props = schema["properties"] diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index fdbb04694c..c108ee7cb3 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1358,6 +1358,26 @@ def prompt_fn(name: str) -> str: # pragma: no cover await client.get_prompt("prompt_fn") +class TestContextCloseSSEStream: + """Tests for the Context.close_sse_stream property.""" + + @pytest.mark.anyio + async def test_close_sse_stream_none_without_streamable_http(self): + """Test that close_sse_stream is None when not using streamable HTTP transport.""" + mcp = FastMCP() + result_holder: list[bool] = [] + + @mcp.tool() + async def check_callback(ctx: Context[ServerSession, None]) -> str: + # Without streamable HTTP transport, close_sse_stream should be None + result_holder.append(ctx.close_sse_stream is None) + return "done" + + async with client_session(mcp._mcp_server) as client: + await client.call_tool("check_callback", {}) + assert result_holder[0] is True + + def test_streamable_http_no_redirect() -> None: """Test that streamable HTTP routes are correctly configured.""" mcp = FastMCP() diff --git a/tests/server/fastmcp/test_url_elicitation.py b/tests/server/fastmcp/test_url_elicitation.py new file mode 100644 index 0000000000..a4d3b2e643 --- /dev/null +++ b/tests/server/fastmcp/test_url_elicitation.py @@ -0,0 +1,394 @@ +"""Test URL mode elicitation feature (SEP 1036).""" + +import anyio +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_url_elicitation_accept(): + """Test URL mode elicitation with user acceptance.""" + mcp = FastMCP(name="URLElicitationServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def request_api_key(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Please provide your API key to continue.", + url="https://example.com/api_key_setup", + elicitation_id="test-elicitation-001", + ) + # Test only checks accept path + return f"User {result.action}" + + # Create elicitation callback that accepts URL mode + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + assert params.url == "https://example.com/api_key_setup" + assert params.elicitationId == "test-elicitation-001" + assert params.message == "Please provide your API key to continue." + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("request_api_key", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User accept" + + +@pytest.mark.anyio +async def test_url_elicitation_decline(): + """Test URL mode elicitation with user declining.""" + mcp = FastMCP(name="URLElicitationDeclineServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Authorize access to your files.", + url="https://example.com/oauth/authorize", + elicitation_id="oauth-001", + ) + # Test only checks decline path + return f"User {result.action} authorization" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("oauth_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User decline authorization" + + +@pytest.mark.anyio +async def test_url_elicitation_cancel(): + """Test URL mode elicitation with user cancelling.""" + mcp = FastMCP(name="URLElicitationCancelServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def payment_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Complete payment to proceed.", + url="https://example.com/payment", + elicitation_id="payment-001", + ) + # Test only checks cancel path + return f"User {result.action} payment" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="cancel") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("payment_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User cancel payment" + + +@pytest.mark.anyio +async def test_url_elicitation_helper_function(): + """Test the elicit_url helper function.""" + from mcp.server.elicitation import elicit_url + + mcp = FastMCP(name="URLElicitationHelperServer") + + @mcp.tool(description="Tool using elicit_url helper") + async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Set up your credentials", + url="https://example.com/setup", + elicitation_id="setup-001", + ) + # Test only checks accept path - return the type name + return type(result).__name__ + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("setup_credentials", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "AcceptedUrlElicitation" + + +@pytest.mark.anyio +async def test_url_no_content_in_response(): + """Test that URL mode elicitation responses don't include content field.""" + mcp = FastMCP(name="URLContentCheckServer") + + @mcp.tool(description="Check URL response format") + async def check_url_response(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Test message", + url="https://example.com/test", + elicitation_id="test-001", + ) + + # URL mode responses should not have content + assert result.content is None + return f"Action: {result.action}, Content: {result.content}" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify that this is URL mode + assert params.mode == "url" + assert isinstance(params, types.ElicitRequestURLParams) + # URL params have url and elicitationId, not requestedSchema + assert params.url == "https://example.com/test" + assert params.elicitationId == "test-001" + # Return without content - this is correct for URL mode + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("check_url_response", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Content: None" in result.content[0].text + + +@pytest.mark.anyio +async def test_form_mode_still_works(): + """Ensure form mode elicitation still works after SEP 1036.""" + from pydantic import BaseModel, Field + + mcp = FastMCP(name="FormModeBackwardCompatServer") + + class NameSchema(BaseModel): + name: str = Field(description="Your name") + + @mcp.tool(description="Test form mode") + async def ask_name(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="What is your name?", schema=NameSchema) + # Test only checks accept path with data + assert result.action == "accept" + assert result.data is not None + return f"Hello, {result.data.name}!" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify form mode parameters + assert params.mode == "form" + assert isinstance(params, types.ElicitRequestFormParams) + # Form params have requestedSchema, not url/elicitationId + assert params.requestedSchema is not None + return ElicitResult(action="accept", content={"name": "Alice"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("ask_name", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello, Alice!" + + +@pytest.mark.anyio +async def test_elicit_complete_notification(): + """Test that elicitation completion notifications can be sent and received.""" + mcp = FastMCP(name="ElicitCompleteServer") + + # Track if the notification was sent + notification_sent = False + + @mcp.tool(description="Tool that sends completion notification") + async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: + nonlocal notification_sent + + # Simulate an async operation (e.g., user completing auth in browser) + elicitation_id = "complete-test-001" + + # Send completion notification + await ctx.session.send_elicit_complete(elicitation_id) + notification_sent = True + + return "Elicitation completed" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="accept") # pragma: no cover + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("trigger_elicitation", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Elicitation completed" + + # Give time for notification to be processed + await anyio.sleep(0.1) + + # Verify the notification was sent + assert notification_sent + + +@pytest.mark.anyio +async def test_url_elicitation_required_error_code(): + """Test that the URL_ELICITATION_REQUIRED error code is correct.""" + # Verify the error code matches the specification (SEP 1036) + assert types.URL_ELICITATION_REQUIRED == -32042, ( + "URL_ELICITATION_REQUIRED error code must be -32042 per SEP 1036 specification" + ) + + +@pytest.mark.anyio +async def test_elicit_url_typed_results(): + """Test that elicit_url returns properly typed result objects.""" + from mcp.server.elicitation import elicit_url + + mcp = FastMCP(name="TypedResultsServer") + + @mcp.tool(description="Test declined result") + async def test_decline(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Test decline", + url="https://example.com/decline", + elicitation_id="decline-001", + ) + + if isinstance(result, DeclinedElicitation): + return "Declined" + return "Not declined" # pragma: no cover + + @mcp.tool(description="Test cancelled result") + async def test_cancel(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Test cancel", + url="https://example.com/cancel", + elicitation_id="cancel-001", + ) + + if isinstance(result, CancelledElicitation): + return "Cancelled" + return "Not cancelled" # pragma: no cover + + # Test declined result + async def decline_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=decline_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("test_decline", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Declined" + + # Test cancelled result + async def cancel_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="cancel") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=cancel_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("test_cancel", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Cancelled" + + +@pytest.mark.anyio +async def test_deprecated_elicit_method(): + """Test the deprecated elicit() method for backward compatibility.""" + from pydantic import BaseModel, Field + + mcp = FastMCP(name="DeprecatedElicitServer") + + class EmailSchema(BaseModel): + email: str = Field(description="Email address") + + @mcp.tool(description="Test deprecated elicit method") + async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: + # Use the deprecated elicit() method which should call elicit_form() + result = await ctx.session.elicit( + message="Enter your email", + requestedSchema=EmailSchema.model_json_schema(), + ) + + if result.action == "accept" and result.content: + return f"Email: {result.content.get('email', 'none')}" + return "No email provided" # pragma: no cover + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify this is form mode + assert params.mode == "form" + assert params.requestedSchema is not None + return ElicitResult(action="accept", content={"email": "test@example.com"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("use_deprecated_elicit", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Email: test@example.com" + + +@pytest.mark.anyio +async def test_ctx_elicit_url_convenience_method(): + """Test the ctx.elicit_url() convenience method (vs ctx.session.elicit_url()).""" + mcp = FastMCP(name="CtxElicitUrlServer") + + @mcp.tool(description="A tool that uses ctx.elicit_url() directly") + async def direct_elicit_url(ctx: Context[ServerSession, None]) -> str: + # Use ctx.elicit_url() directly instead of ctx.session.elicit_url() + result = await ctx.elicit_url( + message="Test the convenience method", + url="https://example.com/test", + elicitation_id="ctx-test-001", + ) + return f"Result: {result.action}" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + assert params.elicitationId == "ctx-test-001" + return ElicitResult(action="accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + result = await client_session.call_tool("direct_elicit_url", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Result: accept" diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py new file mode 100644 index 0000000000..8845dfe781 --- /dev/null +++ b/tests/shared/test_exceptions.py @@ -0,0 +1,159 @@ +"""Tests for MCP exception classes.""" + +import pytest + +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData + + +class TestUrlElicitationRequiredError: + """Tests for UrlElicitationRequiredError exception class.""" + + def test_create_with_single_elicitation(self) -> None: + """Test creating error with a single elicitation.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.code == URL_ELICITATION_REQUIRED + assert error.error.message == "URL elicitation required" + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitationId == "test-123" + + def test_create_with_multiple_elicitations(self) -> None: + """Test creating error with multiple elicitations uses plural message.""" + elicitations = [ + ElicitRequestURLParams( + mode="url", + message="Auth 1", + url="https://example.com/auth1", + elicitationId="test-1", + ), + ElicitRequestURLParams( + mode="url", + message="Auth 2", + url="https://example.com/auth2", + elicitationId="test-2", + ), + ] + error = UrlElicitationRequiredError(elicitations) + + assert error.error.message == "URL elicitations required" # Plural + assert len(error.elicitations) == 2 + + def test_custom_message(self) -> None: + """Test creating error with a custom message.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation], message="Custom message") + + assert error.error.message == "Custom message" + + def test_from_error_data(self) -> None: + """Test reconstructing error from ErrorData.""" + error_data = ErrorData( + code=URL_ELICITATION_REQUIRED, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Auth required", + "url": "https://example.com/auth", + "elicitationId": "test-123", + } + ] + }, + ) + + error = UrlElicitationRequiredError.from_error(error_data) + + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitationId == "test-123" + assert error.elicitations[0].url == "https://example.com/auth" + + def test_from_error_data_wrong_code(self) -> None: + """Test that from_error raises ValueError for wrong error code.""" + error_data = ErrorData( + code=-32600, # Wrong code + message="Some other error", + data={}, + ) + + with pytest.raises(ValueError, match="Expected error code"): + UrlElicitationRequiredError.from_error(error_data) + + def test_serialization_roundtrip(self) -> None: + """Test that error can be serialized and reconstructed.""" + original = UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + ] + ) + + # Simulate serialization over wire + error_data = original.error + + # Reconstruct + reconstructed = UrlElicitationRequiredError.from_error(error_data) + + assert reconstructed.elicitations[0].elicitationId == original.elicitations[0].elicitationId + assert reconstructed.elicitations[0].url == original.elicitations[0].url + assert reconstructed.elicitations[0].message == original.elicitations[0].message + + def test_error_data_contains_elicitations(self) -> None: + """Test that error data contains properly serialized elicitations.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Please authenticate", + url="https://example.com/oauth", + elicitationId="oauth-flow-1", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.data is not None + assert "elicitations" in error.error.data + elicit_data = error.error.data["elicitations"][0] + assert elicit_data["mode"] == "url" + assert elicit_data["message"] == "Please authenticate" + assert elicit_data["url"] == "https://example.com/oauth" + assert elicit_data["elicitationId"] == "oauth-flow-1" + + def test_inherits_from_mcp_error(self) -> None: + """Test that UrlElicitationRequiredError inherits from McpError.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert isinstance(error, McpError) + assert isinstance(error, Exception) + + def test_exception_message(self) -> None: + """Test that exception message is set correctly.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + # The exception's string representation should match the message + assert str(error) == "URL elicitation required" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3b70c19dc7..56495f18f6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -22,7 +22,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -39,7 +39,7 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import ClientMessageMetadata +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool from tests.test_helpers import wait_for_server @@ -115,9 +115,10 @@ async def replay_events_after( # pragma: no cover # Test server implementation that follows MCP protocol class ServerTest(Server): # pragma: no cover - def __init__(self): + def __init__(self, session_manager_ref: list[StreamableHTTPSessionManager] | None = None): super().__init__(SERVER_NAME) self._lock = None # Will be initialized in async context + self._session_manager_ref = session_manager_ref or [] @self.read_resource() async def handle_read_resource(uri: AnyUrl) -> str | bytes: @@ -163,6 +164,11 @@ async def handle_list_tools() -> list[Tool]: description="A tool that releases the lock", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="tool_with_server_disconnect", + description="A tool that triggers server-initiated SSE disconnect", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -254,6 +260,37 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] self._lock.set() return [TextContent(type="text", text="Lock released")] + elif name == "tool_with_server_disconnect": + # Send first notification + await ctx.session.send_log_message( + level="info", + data="First notification before disconnect", + logger="disconnect_tool", + related_request_id=ctx.request_id, + ) + + # Trigger server-initiated SSE disconnect + if self._session_manager_ref: + session_manager = self._session_manager_ref[0] + request = ctx.request + if isinstance(request, Request): + session_id = request.headers.get("mcp-session-id") + if session_id: + await session_manager.close_sse_stream(session_id, ctx.request_id) + + # Wait a bit for client to reconnect + await anyio.sleep(0.2) + + # Send second notification after disconnect + await ctx.session.send_log_message( + level="info", + data="Second notification after disconnect", + logger="disconnect_tool", + related_request_id=ctx.request_id, + ) + + return [TextContent(type="text", text="Completed with disconnect")] + return [TextContent(type="text", text=f"Called {name}")] @@ -266,8 +303,11 @@ def create_app( is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. """ - # Create server instance - server = ServerTest() + # Create a reference holder for the session manager + session_manager_ref: list[StreamableHTTPSessionManager] = [] + + # Create server instance with session manager reference + server = ServerTest(session_manager_ref=session_manager_ref) # Create the session manager security_settings = TransportSecuritySettings( @@ -280,6 +320,9 @@ def create_app( security_settings=security_settings, ) + # Store session manager reference for server to access + session_manager_ref.append(session_manager) + # Create an ASGI application that uses the session manager app = Starlette( debug=True, @@ -882,7 +925,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session: """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 7 assert tools.tools[0].name == "test_tool" # Call the tool @@ -919,7 +962,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 7 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -948,7 +991,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 7 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -1019,7 +1062,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 7 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1085,7 +1128,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 7 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1606,3 +1649,333 @@ async def bad_client(): assert isinstance(result, InitializeResult) tools = await session.list_tools() assert tools.tools + + +@pytest.mark.anyio +async def test_reconnection_delay_with_server_retry(): + """Test _get_next_reconnection_delay uses server-provided retry value.""" + from mcp.client.streamable_http import ( + StreamableHTTPReconnectionOptions, + StreamableHTTPTransport, + ) + + transport = StreamableHTTPTransport( + "http://localhost:8000", + reconnection_options=StreamableHTTPReconnectionOptions( + initial_reconnection_delay=1.0, + max_reconnection_delay=30.0, + reconnection_delay_grow_factor=2.0, + max_retries=5, + ), + ) + + # Without server retry, should use exponential backoff + delay_0 = transport._get_next_reconnection_delay(0) + assert delay_0 == 1.0 # initial_delay * 2^0 = 1.0 + + delay_1 = transport._get_next_reconnection_delay(1) + assert delay_1 == 2.0 # initial_delay * 2^1 = 2.0 + + delay_2 = transport._get_next_reconnection_delay(2) + assert delay_2 == 4.0 # initial_delay * 2^2 = 4.0 + + # Should cap at max_reconnection_delay + delay_large = transport._get_next_reconnection_delay(10) + assert delay_large == 30.0 # capped at max + + # Set server-provided retry value + transport._server_retry_seconds = 5.0 + + # Should now use server-provided value regardless of attempt + assert transport._get_next_reconnection_delay(0) == 5.0 + assert transport._get_next_reconnection_delay(5) == 5.0 + assert transport._get_next_reconnection_delay(100) == 5.0 + + +@pytest.mark.anyio +async def test_create_priming_event_with_event_store(): + """Test _create_priming_event generates correct event when event store configured.""" + event_store = SimpleEventStore() + + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + event_store=event_store, + retry_interval=5000, # 5 seconds in ms + ) + + # Create priming event + priming_event = await transport._create_priming_event("stream-123") + + assert priming_event is not None + assert "id" in priming_event + assert priming_event["id"] == "1" # First event ID from SimpleEventStore + assert priming_event["data"] == "" # Empty data for priming + assert priming_event["retry"] == 5000 + + +@pytest.mark.anyio +async def test_create_priming_event_without_retry_interval(): + """Test _create_priming_event without retry interval configured.""" + event_store = SimpleEventStore() + + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + event_store=event_store, + # No retry_interval + ) + + priming_event = await transport._create_priming_event("stream-456") + + assert priming_event is not None + assert "id" in priming_event + assert priming_event["data"] == "" + assert "retry" not in priming_event # No retry field + + +@pytest.mark.anyio +async def test_create_priming_event_without_event_store(): + """Test _create_priming_event returns None without event store.""" + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + # No event_store + ) + + priming_event = await transport._create_priming_event("stream-789") + + assert priming_event is None + + +@pytest.mark.anyio +async def test_close_sse_stream(): + """Test close_sse_stream closes the stream and cleans up.""" + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + ) + + # Manually add a stream to _request_streams + send_stream, recv_stream = anyio.create_memory_object_stream[EventMessage](0) + transport._request_streams["request-123"] = (send_stream, recv_stream) + + assert "request-123" in transport._request_streams + + # Close the stream + await transport.close_sse_stream("request-123") + + # Stream should be removed + assert "request-123" not in transport._request_streams + + +@pytest.mark.anyio +async def test_close_sse_stream_nonexistent(): + """Test close_sse_stream handles nonexistent stream gracefully.""" + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + ) + + # Should not raise even if stream doesn't exist + await transport.close_sse_stream("nonexistent-stream") + + +@pytest.mark.anyio +async def test_close_sse_stream_already_closed(): + """Test close_sse_stream handles already-closed streams gracefully.""" + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + ) + + # Manually add a stream and close it before calling close_sse_stream + send_stream, recv_stream = anyio.create_memory_object_stream[EventMessage](0) + transport._request_streams["request-456"] = (send_stream, recv_stream) + + # Close streams manually first + await send_stream.aclose() + await recv_stream.aclose() + + # Should handle gracefully without error + await transport.close_sse_stream("request-456") + + # Stream should still be removed from dict + assert "request-456" not in transport._request_streams + + +@pytest.mark.anyio +async def test_resume_stream_without_session_id(): + """Test resume_stream returns early without session ID.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport("http://localhost:8000") + assert transport.session_id is None + + # Create a dummy stream writer with type annotation + read_stream_writer, read_stream_reader = anyio.create_memory_object_stream[SessionMessage | Exception](0) + + async with httpx.AsyncClient() as client: + # Should return early without making request + await transport.resume_stream( + client, + read_stream_writer, + "event-123", + ) + + # Clean up streams to avoid resource warnings + await read_stream_writer.aclose() + await read_stream_reader.aclose() + + # No errors should occur + + +@pytest.mark.anyio +async def test_resume_stream_with_405_response(basic_server: None, basic_server_url: str): + """Test resume_stream handles 405 Method Not Allowed gracefully.""" + from mcp.client.streamable_http import StreamableHTTPTransport + + transport = StreamableHTTPTransport(f"{basic_server_url}/mcp") + + # First establish a session via initialization + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + transport.session_id = get_session_id() + + # Now try to resume with the session - server might return 405 + read_stream_writer, read_stream_reader = anyio.create_memory_object_stream[SessionMessage | Exception]( + 0 + ) # pragma: no cover - xdist coverage bug on 3.11 + + async with httpx.AsyncClient() as client: + # This should handle the 405 gracefully + await transport.resume_stream( + client, + read_stream_writer, + "nonexistent-event-id", + ) + + # Clean up streams to avoid resource warnings + await read_stream_writer.aclose() + await read_stream_reader.aclose() + + +@pytest.mark.anyio +async def test_reconnection_options_dataclass(): + """Test StreamableHTTPReconnectionOptions defaults.""" + from mcp.client.streamable_http import StreamableHTTPReconnectionOptions + + options = StreamableHTTPReconnectionOptions() + + assert options.initial_reconnection_delay == 1.0 + assert options.max_reconnection_delay == 30.0 + assert options.reconnection_delay_grow_factor == 1.5 + assert options.max_retries == 2 + + +@pytest.mark.anyio +async def test_streamablehttp_client_with_reconnection_options(basic_server: None, basic_server_url: str): + """Test streamablehttp_client accepts reconnection_options parameter.""" + from mcp.client.streamable_http import StreamableHTTPReconnectionOptions + + options = StreamableHTTPReconnectionOptions( + initial_reconnection_delay=0.5, + max_reconnection_delay=10.0, + reconnection_delay_grow_factor=1.2, + max_retries=3, + ) + + async with streamablehttp_client( + f"{basic_server_url}/mcp", + reconnection_options=options, + ) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + +@pytest.mark.anyio +async def test_streamablehttp_client_auto_reconnection(event_server: tuple[SimpleEventStore, str]): + """Test automatic client reconnection when server closes SSE stream mid-operation.""" + _, server_url = event_server + + # Track notifications received via logging callback + notifications_received: list[str] = [] + + async def logging_callback(params: types.LoggingMessageNotificationParams) -> None: + """Called when a log message notification is received from the server.""" + if params.data: # pragma: no branch + notifications_received.append(str(params.data)) + + # Configure client with reconnection options (fast delays for testing) + reconnection_options = StreamableHTTPReconnectionOptions( + initial_reconnection_delay=0.1, + max_reconnection_delay=1.0, + reconnection_delay_grow_factor=1.2, + max_retries=5, + ) + + async with streamablehttp_client( + f"{server_url}/mcp", + reconnection_options=reconnection_options, + ) as (read_stream, write_stream, get_session_id): + async with ClientSession( + read_stream, + write_stream, + logging_callback=logging_callback, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + session_id = get_session_id() + assert session_id is not None + + # Call the tool that triggers server-initiated disconnect + tool_result = await session.call_tool("tool_with_server_disconnect", {}) + + # Verify the tool completed successfully + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert tool_result.content[0].text == "Completed with disconnect" + + # Verify we received all notifications (before and after disconnect) + assert len(notifications_received) >= 2, ( + f"Expected at least 2 notifications, got {len(notifications_received)}: {notifications_received}" + ) + assert any("before disconnect" in n for n in notifications_received), ( + f"Missing 'before disconnect' notification in: {notifications_received}" + ) + assert any("after disconnect" in n for n in notifications_received), ( + f"Missing 'after disconnect' notification in: {notifications_received}" + ) + + +def test_create_close_sse_stream_callback_without_event_store(): + """Test that _create_close_sse_stream_callback returns None without event store.""" + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + event_store=None, # No event store + ) + callback = transport._create_close_sse_stream_callback("test-request-id") + assert callback is None + + +@pytest.mark.anyio +async def test_create_close_sse_stream_callback_with_event_store(): + """Test that _create_close_sse_stream_callback returns a working callback with event store.""" + event_store = SimpleEventStore() + transport = StreamableHTTPServerTransport( + mcp_session_id="test-session", + event_store=event_store, + ) + + callback = transport._create_close_sse_stream_callback("test-request-id") + assert callback is not None + + # The callback should call close_sse_stream which returns False for non-existent stream + result = await callback(retry_interval=1000) + assert result is False # No stream to close