From 3441e8270ed1529179cbd4956d3006da852cbfc9 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 13:31:03 +0000 Subject: [PATCH 01/10] Add TDD stubs and failing tests for SSE polling (close_sse_stream) This commit adds the API stubs and failing tests for the server-side disconnect feature that enables SSE polling. When implemented, this will allow servers to disconnect SSE streams without terminating them, triggering client reconnection for polling patterns. API stubs added: - CloseSSEStreamCallback type in message.py - close_sse_stream field in ServerMessageMetadata and RequestContext - close_sse_stream() stub in StreamableHTTPServerTransport - close_sse_stream() stub in FastMCP Context - retry_interval parameter in session manager and transport Tests added (all expected to fail until implementation): - test_streamablehttp_client_receives_priming_event - test_server_close_sse_stream_via_context - test_streamablehttp_client_auto_reconnects - test_streamablehttp_client_respects_retry_interval - test_streamablehttp_sse_polling_full_cycle - test_streamablehttp_events_replayed_after_disconnect Github-Issue:#1699 --- src/mcp/server/fastmcp/server.py | 19 ++ src/mcp/server/streamable_http.py | 26 ++ src/mcp/server/streamable_http_manager.py | 6 + src/mcp/shared/context.py | 2 + src/mcp/shared/message.py | 5 + tests/shared/test_streamable_http.py | 333 +++++++++++++++++++++- 6 files changed, 380 insertions(+), 11 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2e596c9f9a..02566b772a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1282,6 +1282,25 @@ def session(self): """Access to the underlying session for advanced usage.""" return self.request_context.session + async def close_sse_stream(self) -> None: + """Close the SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the current request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + The callback is only available when event_store is configured. + + Raises: + NotImplementedError: Feature not yet implemented. + """ + raise NotImplementedError("close_sse_stream not yet implemented") + # Convenience methods for common log levels async def debug(self, message: str, **extra: Any) -> None: """Send a debug log message.""" diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index d6ccfd5a82..2f28ee756f 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -140,6 +140,7 @@ def __init__( is_json_response_enabled: bool = False, event_store: EventStore | None = None, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -153,6 +154,10 @@ def __init__( resumability will be enabled, allowing clients to reconnect and resume messages. security_settings: Optional security settings for DNS rebinding protection. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE + retry field. 
When set, the server will send a retry field in + SSE priming events to control client reconnection timing for + polling behavior. Only used when event_store is provided. Raises: ValueError: If the session ID contains invalid characters. @@ -164,6 +169,7 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._security = TransportSecurityMiddleware(security_settings) + self._retry_interval = retry_interval self._request_streams: dict[ RequestId, tuple[ @@ -178,6 +184,26 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated + def close_sse_stream(self, request_id: RequestId) -> None: + """Close SSE connection for a specific request without terminating the stream. + + This method closes the HTTP connection for the specified request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + client will reconnect after the retry interval specified in the priming event. + + Args: + request_id: The request ID whose SSE stream should be closed. + + Note: + This is a no-op if there is no active stream for the request ID. + Requires event_store to be configured for events to be stored during + the disconnect. + """ + raise NotImplementedError("close_sse_stream not yet implemented") + def _create_error_response( self, error_message: str, diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 04c7de2d7b..50d2aefa29 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -51,6 +51,9 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. + security_settings: Optional transport security settings. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE + retry field. Used for SSE polling behavior. 
""" def __init__( @@ -60,12 +63,14 @@ def __init__( json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings + self.retry_interval = retry_interval # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -226,6 +231,7 @@ async def _handle_stateful_request( is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) security_settings=self.security_settings, + retry_interval=self.retry_interval, ) assert http_transport.mcp_session_id is not None diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5f..aaeadd70b6 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -3,6 +3,7 @@ from typing_extensions import TypeVar +from mcp.shared.message import CloseSSEStreamCallback from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams @@ -18,3 +19,4 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): session: SessionT lifespan_context: LifespanContextT request: RequestT | None = None + close_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 4b6df23eb6..7104e52a5d 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -14,6 +14,9 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] +# Callback type for closing SSE streams without terminating +CloseSSEStreamCallback = Callable[[], Awaitable[None]] + @dataclass class ClientMessageMetadata: @@ -30,6 +33,8 @@ class ServerMessageMetadata: related_request_id: RequestId | None = None # Request-specific context (e.g., headers, auth info) request_context: object | None = None + # Callback to close SSE stream without terminating + close_sse_stream: CloseSSEStreamCallback | None = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8e8884270e..4eda459f35 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,6 +7,7 @@ import json import multiprocessing import socket +import time from collections.abc import Generator from typing import Any @@ -164,6 +165,16 @@ async def handle_list_tools() -> list[Tool]: description="A tool that releases the lock", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -255,17 +266,68 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] self._lock.set() return [TextContent(type="text", text="Lock released")] + elif name == "tool_with_stream_close": + # Send notification before closing + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream (triggers client reconnect) + assert ctx.close_sse_stream is 
not None + await ctx.close_sse_stream() + # Continue processing (events stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="Done")] + + elif name == "tool_with_multiple_notifications_and_close": + # Send notification1 + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Send notification2, notification3 (stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="All notifications sent")] + return [TextContent(type="text", text=f"Called {name}")] def create_app( - is_json_response_enabled: bool = False, event_store: EventStore | None = None + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, ) -> Starlette: # pragma: no cover """Create a Starlette application for testing using the session manager. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. + retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance server = ServerTest() @@ -279,6 +341,7 @@ def create_app( event_store=event_store, json_response=is_json_response_enabled, security_settings=security_settings, + retry_interval=retry_interval, ) # Create an ASGI application that uses the session manager @@ -294,7 +357,10 @@ def create_app( def run_server( - port: int, is_json_response_enabled: bool = False, event_store: EventStore | None = None + port: int, + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, ) -> None: # pragma: no cover """Run the test server. @@ -302,9 +368,10 @@ def run_server( port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. + retry_interval: Retry interval in milliseconds for SSE polling. 
""" - app = create_app(is_json_response_enabled, event_store) + app = create_app(is_json_response_enabled, event_store, retry_interval) # Configure server config = uvicorn.Config( app=app, @@ -379,10 +446,10 @@ def event_server_port() -> int: def event_server( event_server_port: int, event_store: SimpleEventStore ) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store enabled.""" + """Start a server with event store and retry_interval enabled.""" proc = multiprocessing.Process( target=run_server, - kwargs={"port": event_server_port, "event_store": event_store}, + kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, daemon=True, ) proc.start() @@ -883,7 +950,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session: """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -920,7 +987,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 8 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -949,7 +1016,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 8 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -962,7 +1029,6 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j async def test_streamablehttp_client_get_stream(basic_server: None, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" import mcp.types as types - from mcp.shared.session import RequestResponder notifications_received: list[types.ServerNotification] = [] @@ -1020,7 +1086,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 8 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1086,7 +1152,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 8 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1633,3 +1699,248 @@ async def test_handle_sse_event_skips_empty_data(): finally: await write_stream.aclose() await read_stream.aclose() + + +@pytest.mark.anyio +async def test_streamablehttp_client_receives_priming_event( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client should receive priming event (resumption token update) on POST SSE stream.""" + _, server_url = event_server + + captured_resumption_tokens: list[str] = [] + + async def on_resumption_token_update(token: str) -> None: + captured_resumption_tokens.append(token) + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # 
Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None + + +@pytest.mark.anyio +async def test_server_close_sse_stream_via_context( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Server tool can call ctx.close_sse_stream() to close connection.""" + _, server_url = event_server + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call tool that closes stream mid-operation + # This should NOT raise NotImplementedError when fully implemented + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamablehttp_client_auto_reconnects( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" + _, server_url = event_server + captured_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + return + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.LoggingMessageNotification): + captured_notifications.append(str(message.root.params.data)) + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. 
Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamablehttp_client_respects_retry_interval( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client MUST respect retry field, waiting specified ms before reconnecting.""" + _, server_url = event_server + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time + + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + # The elapsed time should include at least the retry interval + # if reconnection occurred. This test may be flaky depending on + # implementation details, but demonstrates the expected behavior. + # Note: This assertion may need adjustment based on actual implementation + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + + +@pytest.mark.anyio +async def test_streamablehttp_sse_polling_full_cycle( + event_server: tuple[SimpleEventStore, str], +) -> None: + """End-to-end test: server closes stream, client reconnects, receives all events.""" + _, server_url = event_server + all_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + return + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.LoggingMessageNotification): + all_notifications.append(str(message.root.params.data)) + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. 
Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamablehttp_events_replayed_after_disconnect( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Events sent while client is disconnected should be replayed on reconnect.""" + _, server_url = event_server + notification_data: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + return + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.LoggingMessageNotification): + notification_data.append(str(message.root.params.data)) + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" + + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" From 1dfe97e30b104c79c501b3a9bec6946c69600948 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 14:12:11 +0000 Subject: [PATCH 02/10] Implement SSE priming events for resumability (SEP-1699) Server now sends a priming event (SSE event with ID but empty data) at the start of POST SSE streams when an EventStore is configured. This enables clients to reconnect with Last-Event-ID even if the server closes the connection before sending any actual data. 
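For illustration, the priming event is an ordinary SSE event that
carries an ID but no payload; on the wire it looks like this (the
event ID shown is arbitrary):

    id: 42
    data:

A client that later reconnects echoes this ID in the Last-Event-ID
header, letting the server replay any events stored after it.
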
Changes: - EventStore.store_event now accepts JSONRPCMessage | None (None for priming) - Server sends priming event before processing messages in sse_writer - Client calls resumption callback for empty-data events that have an ID --- src/mcp/client/streamable_http.py | 5 ++++- src/mcp/server/streamable_http.py | 15 +++++++++++++-- tests/shared/test_streamable_http.py | 4 ++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 1b32c022ee..47362a4338 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -160,8 +160,11 @@ async def _handle_sse_event( ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": - # Skip empty data (keep-alive pings) + # Handle priming events (empty data with ID) for resumability if not sse.data: + # Call resumption callback for priming events that have an ID + if sse.id and resumption_callback: + await resumption_callback(sse.id) return False try: message = JSONRPCMessage.model_validate_json(sse.data) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 2f28ee756f..299812f9a3 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -87,13 +87,13 @@ class EventStore(ABC): """ @abstractmethod - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: """ Stores an event for later retrieval. Args: stream_id: ID of the stream the event belongs to - message: The JSON-RPC message to store + message: The JSON-RPC message to store, or None for priming events Returns: The generated event ID for the stored event @@ -489,6 +489,17 @@ async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: + # Send priming event if event_store is configured + # This sends an event with ID but empty data, enabling + # the client to reconnect with Last-Event-ID if needed + if self._event_store: + priming_event_id = await self._event_store.store_event( + request_id, + None, # Priming event has no payload + ) + priming_event = {"id": priming_event_id, "data": ""} + await sse_stream_writer.send(priming_event) + # Process messages from the request-specific stream async for event_message in request_stream_reader: # Build the event data diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 4eda459f35..8efc541b20 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -77,10 +77,10 @@ class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" def __init__(self): - self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: # pragma: no cover + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: # pragma: no cover """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) From 21db2a671e17e09257489a95eb7bfe60c7f39698 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 14:24:24 +0000 Subject: [PATCH 03/10] Implement close_sse_stream and 
client auto-reconnect (SEP-1699) Server now supports closing SSE streams mid-operation via close_sse_stream(), which triggers client reconnection. Client automatically reconnects when the stream closes after receiving a priming event. Changes: - Server transport: Implement close_sse_stream() to close SSE writer - Server transport: Create callback and pass via ServerMessageMetadata - Lowlevel server: Thread close_sse_stream callback to RequestContext - FastMCP Context: Wire close_sse_stream() to call the callback - Client: Track priming events and auto-reconnect with Last-Event-ID --- src/mcp/client/streamable_http.py | 65 +++++++++++++++++++++++++++++-- src/mcp/server/fastmcp/server.py | 6 +-- src/mcp/server/lowlevel/server.py | 5 ++- src/mcp/server/streamable_http.py | 29 ++++++++++++-- 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 47362a4338..38922d315e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -329,9 +329,15 @@ async def _handle_sse_response( is_initialization: bool = False, ) -> None: """Handle SSE response from the server.""" + last_event_id: str | None = None + try: event_source = EventSource(response) async for sse in event_source.aiter_sse(): # pragma: no branch + # Track last event ID for potential reconnection + if sse.id: + last_event_id = sse.id + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -342,10 +348,63 @@ async def _handle_sse_response( # break the loop if is_complete: await response.aclose() - break + return # Normal completion, no reconnect needed + except Exception as e: + logger.debug(f"SSE stream ended: {e}") + + # Stream ended without response - reconnect if we have priming event + if last_event_id is not None: + await self._handle_reconnection(ctx, last_event_id) + + async def _handle_reconnection( + self, + ctx: RequestContext, + last_event_id: str, + ) -> None: + """Reconnect with Last-Event-ID to resume stream after server disconnect.""" + headers = self._prepare_request_headers(ctx.headers) + headers[LAST_EVENT_ID] = last_event_id + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): + original_request_id = ctx.session_message.message.root.id + + try: + async with aconnect_sse( + ctx.client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("Reconnection GET SSE connection established") + + # Track for potential further reconnection + reconnect_last_event_id: str | None = last_event_id + + async for sse in event_source.aiter_sse(): + if sse.id: + reconnect_last_event_id = sse.id + + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + await event_source.response.aclose() + return + + # Stream ended again without response - reconnect again + if reconnect_last_event_id is not None: + await self._handle_reconnection(ctx, reconnect_last_event_id) except Exception as e: - logger.exception("Error reading SSE stream:") # pragma: no cover - await ctx.read_stream_writer.send(e) # pragma: no cover + logger.debug(f"Reconnection failed: {e}") + # Try to reconnect again if we still have an event ID + await self._handle_reconnection(ctx, 
last_event_id) async def _handle_unexpected_content_type( self, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 02566b772a..921d8f540a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1295,11 +1295,9 @@ async def close_sse_stream(self) -> None: Note: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. - - Raises: - NotImplementedError: Feature not yet implemented. """ - raise NotImplementedError("close_sse_stream not yet implemented") + if self._request_context and self._request_context.close_sse_stream: + await self._request_context.close_sse_stream() # Convenience methods for common log levels async def debug(self, message: str, **extra: Any) -> None: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a0617036f9..8a64b463cf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -680,12 +680,14 @@ async def _handle_request( token = None try: - # Extract request context from message metadata + # Extract request context and close_sse_stream from message metadata request_data = None + close_sse_stream_cb = None if message.message_metadata is not None and isinstance( message.message_metadata, ServerMessageMetadata ): # pragma: no cover request_data = message.message_metadata.request_context + close_sse_stream_cb = message.message_metadata.close_sse_stream # Set our global state that can be retrieved via # app.get_request_context() @@ -696,6 +698,7 @@ async def _handle_request( session, lifespan_context, request=request_data, + close_sse_stream=close_sse_stream_cb, ) ) response = await handler(req) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 299812f9a3..09e7432938 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -177,6 +177,7 @@ def __init__( MemoryObjectReceiveStream[EventMessage], ], ] = {} + self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False @property @@ -202,7 +203,26 @@ def close_sse_stream(self, request_id: RequestId) -> None: Requires event_store to be configured for events to be stored during the disconnect. 
""" - raise NotImplementedError("close_sse_stream not yet implemented") + writer = self._sse_stream_writers.pop(request_id, None) + if writer: + writer.close() + + def _create_session_message( + self, + message: JSONRPCMessage, + request: Request, + request_id: RequestId, + ) -> SessionMessage: + """Create a session message with metadata including close_sse_stream callback.""" + + async def close_stream_callback() -> None: + self.close_sse_stream(request_id) + + metadata = ServerMessageMetadata( + request_context=request, + close_sse_stream=close_stream_callback, + ) + return SessionMessage(message, metadata=metadata) def _create_error_response( self, @@ -485,6 +505,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + # Store writer reference so close_sse_stream() can close it + self._sse_stream_writers[request_id] = sse_stream_writer + async def sse_writer(): # Get the request ID from the incoming request message try: @@ -516,6 +539,7 @@ async def sse_writer(): logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") + self._sse_stream_writers.pop(request_id, None) await self._clean_up_memory_streams(request_id) # Create and start EventSourceResponse @@ -539,8 +563,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - metadata = ServerMessageMetadata(request_context=request) - session_message = SessionMessage(message, metadata=metadata) + session_message = self._create_session_message(message, request, request_id) await writer.send(session_message) except Exception: logger.exception("SSE response error") From 13e43e39b18bb0e30a82df9ed8daa9a4e21886ef Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 14:42:54 +0000 Subject: [PATCH 04/10] Implement retry interval support for SSE polling (SEP-1699) Server now sends the retry field in SSE priming events when retry_interval is configured. Client respects this field and waits the specified interval before reconnecting. 
Changes: - Server: Add retry field to priming event when retry_interval is set - Server: Extract _send_priming_event() helper method - Client: Track retry interval from SSE events - Client: Wait for retry interval before reconnecting --- src/mcp/client/streamable_http.py | 21 +++++++++++++++++---- src/mcp/server/streamable_http.py | 30 ++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 38922d315e..b252d2a256 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -330,6 +330,7 @@ async def _handle_sse_response( ) -> None: """Handle SSE response from the server.""" last_event_id: str | None = None + retry_interval_ms: int | None = None try: event_source = EventSource(response) @@ -338,6 +339,10 @@ async def _handle_sse_response( if sse.id: last_event_id = sse.id + # Track retry interval from server + if sse.retry is not None: + retry_interval_ms = sse.retry + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -352,16 +357,21 @@ async def _handle_sse_response( except Exception as e: logger.debug(f"SSE stream ended: {e}") - # Stream ended without response - reconnect if we have priming event + # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: - await self._handle_reconnection(ctx, last_event_id) + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) async def _handle_reconnection( self, ctx: RequestContext, last_event_id: str, + retry_interval_ms: int | None = None, ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" + # Wait for retry interval if specified by server + if retry_interval_ms is not None: + await anyio.sleep(retry_interval_ms / 1000.0) + headers = self._prepare_request_headers(ctx.headers) headers[LAST_EVENT_ID] = last_event_id @@ -383,10 +393,13 @@ async def _handle_reconnection( # Track for potential further reconnection reconnect_last_event_id: str | None = last_event_id + reconnect_retry_ms = retry_interval_ms async for sse in event_source.aiter_sse(): if sse.id: reconnect_last_event_id = sse.id + if sse.retry is not None: + reconnect_retry_ms = sse.retry is_complete = await self._handle_sse_event( sse, @@ -400,11 +413,11 @@ async def _handle_reconnection( # Stream ended again without response - reconnect again if reconnect_last_event_id is not None: - await self._handle_reconnection(ctx, reconnect_last_event_id) + await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms) except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID - await self._handle_reconnection(ctx, last_event_id) + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) async def _handle_unexpected_content_type( self, diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 09e7432938..b962f1131c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -15,6 +15,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from http import HTTPStatus +from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -224,6 +225,23 @@ async def close_stream_callback() -> None: ) return SessionMessage(message, metadata=metadata) + async def _send_priming_event( + self, + request_id: 
RequestId, + sse_stream_writer: MemoryObjectSendStream[dict[str, Any]], + ) -> None: + """Send priming event for SSE resumability if event_store is configured.""" + if not self._event_store: + return + priming_event_id = await self._event_store.store_event( + str(request_id), # Convert RequestId to StreamId (str) + None, # Priming event has no payload + ) + priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""} + if self._retry_interval is not None: + priming_event["retry"] = self._retry_interval + await sse_stream_writer.send(priming_event) + def _create_error_response( self, error_message: str, @@ -512,16 +530,8 @@ async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: - # Send priming event if event_store is configured - # This sends an event with ID but empty data, enabling - # the client to reconnect with Last-Event-ID if needed - if self._event_store: - priming_event_id = await self._event_store.store_event( - request_id, - None, # Priming event has no payload - ) - priming_event = {"id": priming_event_id, "data": ""} - await sse_stream_writer.send(priming_event) + # Send priming event for SSE resumability + await self._send_priming_event(request_id, sse_stream_writer) # Process messages from the request-specific stream async for event_message in request_stream_reader: From e893de911b20af34b4bd0d12c9d64a087cf9cb4a Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 14:54:04 +0000 Subject: [PATCH 05/10] Add reconnection backoff and max retry protection (SEP-1699) Prevents potential DDOS when server doesn't provide retry interval. Changes: - Always wait before reconnecting (server retry value or 1s default) - Track failed attempts only - successful reconnections reset counter - Bail after 2 consecutive failures --- src/mcp/client/streamable_http.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index b252d2a256..ce8ce12c94 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -42,6 +42,10 @@ MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" LAST_EVENT_ID = "last-event-id" + +# Reconnection defaults +DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry +MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up CONTENT_TYPE = "content-type" ACCEPT = "accept" @@ -366,11 +370,17 @@ async def _handle_reconnection( ctx: RequestContext, last_event_id: str, retry_interval_ms: int | None = None, + attempt: int = 0, ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" - # Wait for retry interval if specified by server - if retry_interval_ms is not None: - await anyio.sleep(retry_interval_ms / 1000.0) + # Bail if max retries exceeded + if attempt >= MAX_RECONNECTION_ATTEMPTS: + logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") + return + + # Always wait - use server value or default + delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + await anyio.sleep(delay_ms / 1000.0) headers = self._prepare_request_headers(ctx.headers) headers[LAST_EVENT_ID] = last_event_id @@ -411,13 +421,15 @@ async def _handle_reconnection( await event_source.response.aclose() return - # Stream ended again without response - reconnect again + # Stream ended again 
without response - reconnect again (reset attempt counter) if reconnect_last_event_id is not None: - await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms) + await self._handle_reconnection( + ctx, reconnect_last_event_id, reconnect_retry_ms, 0 + ) except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID - await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) async def _handle_unexpected_content_type( self, From 316135392061b6555debbb7b2667ad780103d495 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 15:03:05 +0000 Subject: [PATCH 06/10] Add SSE polling demo server and client examples (SEP-1699) Demonstrates the SSE polling pattern with close_sse_stream(): - Server: process_batch tool that checkpoints periodically - Client: auto-reconnects transparently with Last-Event-ID - Shows priming events, retry interval, and event replay --- examples/clients/sse-polling-client/README.md | 30 +++ .../mcp_sse_polling_client/__init__.py | 1 + .../mcp_sse_polling_client/main.py | 105 +++++++++++ .../clients/sse-polling-client/pyproject.toml | 36 ++++ examples/servers/sse-polling-demo/README.md | 36 ++++ .../mcp_sse_polling_demo/__init__.py | 1 + .../mcp_sse_polling_demo/__main__.py | 6 + .../mcp_sse_polling_demo/event_store.py | 100 ++++++++++ .../mcp_sse_polling_demo/server.py | 177 ++++++++++++++++++ .../servers/sse-polling-demo/pyproject.toml | 36 ++++ 10 files changed, 528 insertions(+) create mode 100644 examples/clients/sse-polling-client/README.md create mode 100644 examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py create mode 100644 examples/clients/sse-polling-client/mcp_sse_polling_client/main.py create mode 100644 examples/clients/sse-polling-client/pyproject.toml create mode 100644 examples/servers/sse-polling-demo/README.md create mode 100644 examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py create mode 100644 examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py create mode 100644 examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py create mode 100644 examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py create mode 100644 examples/servers/sse-polling-demo/pyproject.toml diff --git a/examples/clients/sse-polling-client/README.md b/examples/clients/sse-polling-client/README.md new file mode 100644 index 0000000000..68fdf1c35d --- /dev/null +++ b/examples/clients/sse-polling-client/README.md @@ -0,0 +1,30 @@ +# MCP SSE Polling Demo Client + +Demonstrates client-side auto-reconnect for the SSE polling pattern (SEP-1699). 
+ +## Features + +- Connects to SSE polling demo server +- Automatically reconnects when server closes SSE stream +- Resumes from Last-Event-ID to avoid missing messages +- Respects server-provided retry interval + +## Usage + +```bash +# First start the server: +uv run mcp-sse-polling-demo --port 3000 + +# Then run this client: +uv run mcp-sse-polling-client --url http://localhost:3000/mcp + +# Custom options: +uv run mcp-sse-polling-client --url http://localhost:3000/mcp --items 20 --checkpoint-every 5 +``` + +## Options + +- `--url`: Server URL (default: http://localhost:3000/mcp) +- `--items`: Number of items to process (default: 10) +- `--checkpoint-every`: Checkpoint interval (default: 3) +- `--log-level`: Logging level (default: DEBUG) diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py new file mode 100644 index 0000000000..ee69b32c96 --- /dev/null +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py @@ -0,0 +1 @@ +"""SSE Polling Demo Client - demonstrates auto-reconnect for long-running tasks.""" diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py new file mode 100644 index 0000000000..1defd8eaa4 --- /dev/null +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -0,0 +1,105 @@ +""" +SSE Polling Demo Client + +Demonstrates the client-side auto-reconnect for SSE polling pattern. + +This client connects to the SSE Polling Demo server and calls process_batch, +which triggers periodic server-side stream closes. The client automatically +reconnects using Last-Event-ID and resumes receiving messages. + +Run with: + # First start the server: + uv run mcp-sse-polling-demo --port 3000 + + # Then run this client: + uv run mcp-sse-polling-client --url http://localhost:3000/mcp +""" + +import asyncio +import logging + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +logger = logging.getLogger(__name__) + + +async def run_demo(url: str, items: int, checkpoint_every: int) -> None: + """Run the SSE polling demo.""" + print(f"\n{'=' * 60}") + print("SSE Polling Demo Client") + print(f"{'=' * 60}") + print(f"Server URL: {url}") + print(f"Processing {items} items with checkpoints every {checkpoint_every}") + print(f"{'=' * 60}\n") + + async with streamablehttp_client(url) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the connection + print("Initializing connection...") + await session.initialize() + print("Connected!\n") + + # List available tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}\n") + + # Call the process_batch tool + print(f"Calling process_batch(items={items}, checkpoint_every={checkpoint_every})...\n") + print("-" * 40) + + result = await session.call_tool( + "process_batch", + { + "items": items, + "checkpoint_every": checkpoint_every, + }, + ) + + print("-" * 40) + if result.content: + content = result.content[0] + text = getattr(content, "text", str(content)) + print(f"\nResult: {text}") + else: + print("\nResult: No content") + print(f"{'=' * 60}\n") + + +@click.command() +@click.option( + "--url", + default="http://localhost:3000/mcp", + help="Server URL", +) +@click.option( + "--items", + default=10, + help="Number of items to process", +) +@click.option( 
+ "--checkpoint-every", + default=3, + help="Checkpoint interval", +) +@click.option( + "--log-level", + default="INFO", + help="Logging level", +) +def main(url: str, items: int, checkpoint_every: int, log_level: str) -> None: + """Run the SSE Polling Demo client.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + # Suppress noisy HTTP client logging + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + + asyncio.run(run_demo(url, items, checkpoint_every)) + + +if __name__ == "__main__": + main() diff --git a/examples/clients/sse-polling-client/pyproject.toml b/examples/clients/sse-polling-client/pyproject.toml new file mode 100644 index 0000000000..ae896708d4 --- /dev/null +++ b/examples/clients/sse-polling-client/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-sse-polling-client" +version = "0.1.0" +description = "Demo client for SSE polling with auto-reconnect" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "sse", "polling", "client"] +license = { text = "MIT" } +dependencies = ["click>=8.2.0", "mcp"] + +[project.scripts] +mcp-sse-polling-client = "mcp_sse_polling_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_sse_polling_client"] + +[tool.pyright] +include = ["mcp_sse_polling_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/sse-polling-demo/README.md b/examples/servers/sse-polling-demo/README.md new file mode 100644 index 0000000000..e9d4446e1f --- /dev/null +++ b/examples/servers/sse-polling-demo/README.md @@ -0,0 +1,36 @@ +# MCP SSE Polling Demo Server + +Demonstrates the SSE polling pattern with server-initiated stream close for long-running tasks (SEP-1699). 
+ +## Features + +- Priming events (automatic with EventStore) +- Server-initiated stream close via `close_sse_stream()` callback +- Client auto-reconnect with Last-Event-ID +- Progress notifications during long-running tasks +- Configurable retry interval + +## Usage + +```bash +# Start server on default port +uv run mcp-sse-polling-demo --port 3000 + +# Custom retry interval (milliseconds) +uv run mcp-sse-polling-demo --port 3000 --retry-interval 100 +``` + +## Tool: process_batch + +Processes items with periodic checkpoints that trigger SSE stream closes: + +- `items`: Number of items to process (1-100, default: 10) +- `checkpoint_every`: Close stream after this many items (1-20, default: 3) + +## Client + +Use the companion `mcp-sse-polling-client` to test: + +```bash +uv run mcp-sse-polling-client --url http://localhost:3000/mcp +``` diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py new file mode 100644 index 0000000000..46af2fdeed --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py @@ -0,0 +1 @@ +"""SSE Polling Demo Server - demonstrates close_sse_stream for long-running tasks.""" diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py new file mode 100644 index 0000000000..23cfc85e11 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py @@ -0,0 +1,6 @@ +"""Entry point for the SSE Polling Demo server.""" + +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py new file mode 100644 index 0000000000..75f98cdd49 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py @@ -0,0 +1,100 @@ +""" +In-memory event store for demonstrating resumability functionality. + +This is a simple implementation intended for examples and testing, +not for production use where a persistent storage solution would be more appropriate. +""" + +import logging +from collections import deque +from dataclasses import dataclass +from uuid import uuid4 + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +@dataclass +class EventEntry: + """Represents an event entry in the event store.""" + + event_id: EventId + stream_id: StreamId + message: JSONRPCMessage | None # None for priming events + + +class InMemoryEventStore(EventStore): + """ + Simple in-memory implementation of the EventStore interface for resumability. + This is primarily intended for examples and testing, not for production use + where a persistent storage solution would be more appropriate. + + This implementation keeps only the last N events per stream for memory efficiency. + """ + + def __init__(self, max_events_per_stream: int = 100): + """Initialize the event store. 
+ + Args: + max_events_per_stream: Maximum number of events to keep per stream + """ + self.max_events_per_stream = max_events_per_stream + # for maintaining last N events per stream + self.streams: dict[StreamId, deque[EventEntry]] = {} + # event_id -> EventEntry for quick lookup + self.event_index: dict[EventId, EventEntry] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + """Stores an event with a generated event ID. + + Args: + stream_id: ID of the stream the event belongs to + message: The message to store, or None for priming events + """ + event_id = str(uuid4()) + event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) + + # Get or create deque for this stream + if stream_id not in self.streams: + self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) + + # If deque is full, the oldest event will be automatically removed + # We need to remove it from the event_index as well + if len(self.streams[stream_id]) == self.max_events_per_stream: + oldest_event = self.streams[stream_id][0] + self.event_index.pop(oldest_event.event_id, None) + + # Add new event + self.streams[stream_id].append(event_entry) + self.event_index[event_id] = event_entry + + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replays events that occurred after the specified event ID.""" + if last_event_id not in self.event_index: + logger.warning(f"Event ID {last_event_id} not found in store") + return None + + # Get the stream and find events after the last one + last_event = self.event_index[last_event_id] + stream_id = last_event.stream_id + stream_events = self.streams.get(last_event.stream_id, deque()) + + # Events in deque are already in chronological order + found_last = False + for event in stream_events: + if found_last: + # Skip priming events (None messages) during replay + if event.message is not None: + await send_callback(EventMessage(event.message, event.event_id)) + elif event.event_id == last_event_id: + found_last = True + + return stream_id diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py new file mode 100644 index 0000000000..e4bdcaa396 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -0,0 +1,177 @@ +""" +SSE Polling Demo Server + +Demonstrates the SSE polling pattern with close_sse_stream() for long-running tasks. 
+ +Features demonstrated: +- Priming events (automatic with EventStore) +- Server-initiated stream close via close_sse_stream callback +- Client auto-reconnect with Last-Event-ID +- Progress notifications during long-running tasks + +Run with: + uv run mcp-sse-polling-demo --port 3000 +""" + +import contextlib +import logging +from collections.abc import AsyncIterator +from typing import Any + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +from .event_store import InMemoryEventStore + +logger = logging.getLogger(__name__) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Create the lowlevel server + app = Server("sse-polling-demo") + + @app.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + """Handle tool calls.""" + ctx = app.request_context + + if name == "process_batch": + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) + + if items < 1 or items > 100: + return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] + if checkpoint_every < 1 or checkpoint_every > 20: + return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + + # Report progress + await ctx.session.send_log_message( + level="info", + data=f"[{i}/{items}] Processing item {i}", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: + await ctx.session.send_log_message( + level="info", + data=f"Checkpoint at item {i} - closing SSE stream for polling", + logger="process_batch", + related_request_id=ctx.request_id, + ) + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return [ + types.TextContent( + type="text", + text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", + ) + ] + + return [types.TextContent(type="text", text=f"Unknown tool: {name}")] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + """List available tools.""" + return [ + types.Tool( + name="process_batch", + description=( + "Process a batch of items with periodic checkpoints. " + "Demonstrates SSE polling where server closes stream periodically." 
+ ), + inputSchema={ + "type": "object", + "properties": { + "items": { + "type": "integer", + "description": "Number of items to process (1-100)", + "default": 10, + }, + "checkpoint_every": { + "type": "integer", + "description": "Close stream after this many items (1-20)", + "default": 3, + }, + }, + }, + ) + ] + + # Create event store for resumability + event_store = InMemoryEventStore() + + # Create session manager with event store and retry interval + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, + retry_interval=retry_interval, + ) + + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + logger.info(f"SSE Polling Demo server started on port {port}") + logger.info("Try: POST /mcp with tools/call for 'process_batch'") + yield + logger.info("Server shutting down...") + + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/servers/sse-polling-demo/pyproject.toml b/examples/servers/sse-polling-demo/pyproject.toml new file mode 100644 index 0000000000..f7ad89217c --- /dev/null +++ b/examples/servers/sse-polling-demo/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-sse-polling-demo" +version = "0.1.0" +description = "Demo server showing SSE polling with close_sse_stream for long-running tasks" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "sse", "polling", "streamable", "http"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.2.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-sse-polling-demo = "mcp_sse_polling_demo.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_sse_polling_demo"] + +[tool.pyright] +include = ["mcp_sse_polling_demo"] +venvPath = "." 
+venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] From fdcd8f56ac08f0f9d49a565f363d9d82e36d6383 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 16:39:42 +0000 Subject: [PATCH 07/10] Fix multiple close_sse_stream support and add reconnection logging (SEP-1699) - Register SSE writer in _replay_events() so subsequent close_sse_stream() calls work - Send priming event on each reconnection - Handle ClosedResourceError gracefully in both POST and GET SSE writers - Add disconnect/reconnect logging at INFO level for visibility - Add test for multiple reconnections during long-running tool calls - Remove pragma from store_event (now covered by tests) --- src/mcp/client/streamable_http.py | 4 +- src/mcp/server/streamable_http.py | 19 ++++++ tests/shared/test_streamable_http.py | 95 ++++++++++++++++++++++++++-- 3 files changed, 111 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ce8ce12c94..4427cc5323 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -363,6 +363,7 @@ async def _handle_sse_response( # Stream ended without response - reconnect if we received an event with ID if last_event_id is not None: + logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) async def _handle_reconnection( @@ -399,7 +400,7 @@ async def _handle_reconnection( timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() - logger.debug("Reconnection GET SSE connection established") + logger.info("Reconnected to SSE stream") # Track for potential further reconnection reconnect_last_event_id: str | None = last_event_id @@ -423,6 +424,7 @@ async def _handle_reconnection( # Stream ended again without response - reconnect again (reset attempt counter) if reconnect_last_event_id is not None: + logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection( ctx, reconnect_last_event_id, reconnect_retry_ms, 0 ) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index b962f1131c..834e613fac 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -208,6 +208,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: if writer: writer.close() + # Also close and remove request streams + if request_id in self._request_streams: + send_stream, receive_stream = self._request_streams.pop(request_id) + send_stream.close() + receive_stream.close() + def _create_session_message( self, message: JSONRPCMessage, @@ -545,6 +551,9 @@ async def sse_writer(): JSONRPCResponse | JSONRPCError, ): break + except anyio.ClosedResourceError: + # Expected when close_sse_stream() is called + logger.debug("SSE stream closed by close_sse_stream()") except Exception: logger.exception("Error in SSE writer") finally: @@ -848,6 +857,13 @@ async def send_event(event_message: EventMessage) -> None: # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: + # Register SSE writer so close_sse_stream() can close it + self._sse_stream_writers[stream_id] = sse_stream_writer + + # Send priming event for this new connection + await self._send_priming_event(stream_id, sse_stream_writer) + + # Create new 
request streams for this connection self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) msg_reader = self._request_streams[stream_id][1] @@ -857,6 +873,9 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) + except anyio.ClosedResourceError: + # Expected when close_sse_stream() is called + logger.debug("Replay SSE stream closed by close_sse_stream()") except Exception: logger.exception("Error in replay sender") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8efc541b20..6842711103 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -80,7 +80,7 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: # pragma: no cover + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) @@ -175,6 +175,17 @@ async def handle_list_tools() -> list[Tool]: description="Tool that sends notification1, closes stream, sends notification2, notification3", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + inputSchema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, + }, + }, + ), ] @self.call_tool() @@ -314,6 +325,25 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] ) return [TextContent(type="text", text="All notifications sent")] + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) + + if ctx.close_sse_stream: + await ctx.close_sse_stream() + + await anyio.sleep(sleep_time) + + return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + return [TextContent(type="text", text=f"Called {name}")] @@ -950,7 +980,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session: """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 8 + assert len(tools.tools) == 9 assert tools.tools[0].name == "test_tool" # Call the tool @@ -987,7 +1017,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 8 + assert len(tools.tools) == 9 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -1016,7 +1046,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 8 + assert len(tools.tools) == 9 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -1086,7 +1116,7 @@ async def 
test_streamablehttp_client_session_termination(basic_server: None, bas # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 8 + assert len(tools.tools) == 9 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1152,7 +1182,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 8 + assert len(tools.tools) == 9 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1944,3 +1974,56 @@ async def message_handler( assert result.content[0].type == "text" assert isinstance(result.content[0], TextContent) assert result.content[0].text == "All notifications sent" + + +@pytest.mark.anyio +async def test_streamablehttp_multiple_reconnections( + event_server: tuple[SimpleEventStore, str], +): + """Verify multiple close_sse_stream() calls each trigger a client reconnect. + + Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure + client has time to reconnect before the next checkpoint. + + With 3 checkpoints, we expect 8 resumption tokens: + - 1 priming (initial POST connection) + - 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2) + - 3 priming (one per reconnect after each close) + - 1 response + """ + _, server_url = event_server + resumption_tokens: list[str] = [] + + async def on_resumption_token(token: str) -> None: + resumption_tokens.append(token) + + async with streamablehttp_client(f"{server_url}/mcp") as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="tool_with_multiple_stream_closes", + # retry_interval=500ms, so sleep 600ms to ensure reconnect completes + arguments={"checkpoints": 3, "sleep_time": 0.6}, + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text + + # 4 priming + 3 notifications + 1 response = 8 tokens + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) From 7d3674ab663257a7c4e43b305b924882e9fa2a0b Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 17:24:52 +0000 Subject: [PATCH 08/10] Add retry_interval to FastMCP and test_reconnection to everything-server (SEP-1699) - Add retry_interval parameter to FastMCP for SSE polling control - Add InMemoryEventStore and test_reconnection tool to everything-server - Enables SSE polling conformance test to pass (server-sse-polling scenario) --- examples/clients/sse-polling-client/README.md | 2 +- .../mcp_everything_server/server.py | 57 +++++++++++++++++++ .../mcp_simple_streamablehttp/event_store.py | 8 ++- src/mcp/client/streamable_http.py | 9 +-- src/mcp/server/fastmcp/server.py | 3 + tests/shared/test_streamable_http.py | 4 +- 6 files changed, 72 insertions(+), 11 deletions(-) diff --git a/examples/clients/sse-polling-client/README.md 
b/examples/clients/sse-polling-client/README.md
index 68fdf1c35d..78449aa832 100644
--- a/examples/clients/sse-polling-client/README.md
+++ b/examples/clients/sse-polling-client/README.md
@@ -24,7 +24,7 @@ uv run mcp-sse-polling-client --url http://localhost:3000/mcp --items 20 --check
 
 ## Options
 
-- `--url`: Server URL (default: http://localhost:3000/mcp)
+- `--url`: Server URL (default: `http://localhost:3000/mcp`)
 - `--items`: Number of items to process (default: 10)
 - `--checkpoint-every`: Checkpoint interval (default: 3)
 - `--log-level`: Logging level (default: DEBUG)
diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py
index eb632b4d63..e37bfa1317 100644
--- a/examples/servers/everything-server/mcp_everything_server/server.py
+++ b/examples/servers/everything-server/mcp_everything_server/server.py
@@ -14,6 +14,7 @@
 from mcp.server.fastmcp import Context, FastMCP
 from mcp.server.fastmcp.prompts.base import UserMessage
 from mcp.server.session import ServerSession
+from mcp.server.streamable_http import EventCallback, EventMessage, EventStore
 from mcp.types import (
     AudioContent,
     Completion,
@@ -21,6 +22,7 @@
     CompletionContext,
     EmbeddedResource,
     ImageContent,
+    JSONRPCMessage,
     PromptReference,
     ResourceTemplateReference,
     SamplingMessage,
@@ -31,6 +33,43 @@
 logger = logging.getLogger(__name__)
 
+# Type aliases for event store
+StreamId = str
+EventId = str
+
+
+class InMemoryEventStore(EventStore):
+    """Simple in-memory event store for SSE resumability testing."""
+
+    def __init__(self) -> None:
+        self._events: list[tuple[StreamId, EventId, JSONRPCMessage | None]] = []
+        self._event_id_counter = 0
+
+    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId:
+        """Store an event and return its ID."""
+        self._event_id_counter += 1
+        event_id = str(self._event_id_counter)
+        self._events.append((stream_id, event_id, message))
+        return event_id
+
+    async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None:
+        """Replay events after the specified ID."""
+        target_stream_id = None
+        for stream_id, event_id, _ in self._events:
+            if event_id == last_event_id:
+                target_stream_id = stream_id
+                break
+        if target_stream_id is None:
+            return None
+        last_event_id_int = int(last_event_id)
+        for stream_id, event_id, message in self._events:
+            if stream_id == target_stream_id and int(event_id) > last_event_id_int:
+                # Skip priming events (None message)
+                if message is not None:
+                    await send_callback(EventMessage(message, event_id))
+        return target_stream_id
+
+
 # Test data
 TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
 TEST_AUDIO_BASE64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA="
@@ -39,8 +78,13 @@
 resource_subscriptions: set[str] = set()
 watched_resource_content = "Watched resource content"
 
+# Create event store for SSE resumability (SEP-1699)
+event_store = InMemoryEventStore()
+
 mcp = FastMCP(
     name="mcp-conformance-test-server",
+    event_store=event_store,
+    retry_interval=100,  # 100ms retry interval for SSE polling
 )
@@ -263,6 +307,19 @@ def test_error_handling() -> str:
     raise RuntimeError("This tool intentionally returns an error for testing")
+
+
+@mcp.tool()
+async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
+    """Tests SSE polling by closing stream mid-call (SEP-1699)"""
+    await ctx.info("Before disconnect")
+
+    await ctx.close_sse_stream()
+
+    
await asyncio.sleep(0.2) # Wait for client to reconnect + + await ctx.info("After reconnect") + return "Reconnection test completed" + + # Resources @mcp.resource("test://static-text") def static_text_resource() -> str: diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index ee52cdbe77..0c3081ed64 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -24,7 +24,7 @@ class EventEntry: event_id: EventId stream_id: StreamId - message: JSONRPCMessage + message: JSONRPCMessage | None class InMemoryEventStore(EventStore): @@ -48,7 +48,7 @@ def __init__(self, max_events_per_stream: int = 100): # event_id -> EventEntry for quick lookup self.event_index: dict[EventId, EventEntry] = {} - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: """Stores an event with a generated event ID.""" event_id = str(uuid4()) event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) @@ -88,7 +88,9 @@ async def replay_events_after( found_last = False for event in stream_events: if found_last: - await send_callback(EventMessage(event.message, event.event_id)) + # Skip priming events (None message) + if event.message is not None: + await send_callback(EventMessage(event.message, event.event_id)) elif event.event_id == last_event_id: found_last = True diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 4427cc5323..091a81b593 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -403,7 +403,7 @@ async def _handle_reconnection( logger.info("Reconnected to SSE stream") # Track for potential further reconnection - reconnect_last_event_id: str | None = last_event_id + reconnect_last_event_id: str = last_event_id reconnect_retry_ms = retry_interval_ms async for sse in event_source.aiter_sse(): @@ -423,11 +423,8 @@ async def _handle_reconnection( return # Stream ended again without response - reconnect again (reset attempt counter) - if reconnect_last_event_id is not None: - logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection( - ctx, reconnect_last_event_id, reconnect_retry_ms, 0 - ) + logger.info("SSE stream disconnected, reconnecting...") + await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) except Exception as e: logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 921d8f540a..c3902df0e7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -153,6 +153,7 @@ def __init__( # noqa: PLR0913 auth_server_provider: (OAuthAuthorizationServerProvider[Any, Any, Any] | None) = None, token_verifier: TokenVerifier | None = None, event_store: EventStore | None = None, + retry_interval: int | None = None, *, tools: list[Tool] | None = None, debug: bool = False, @@ -221,6 +222,7 @@ def __init__( # noqa: PLR0913 if auth_server_provider and not token_verifier: # pragma: no cover self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._event_store = event_store + self._retry_interval = retry_interval 
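+        # Consumed by streamable_http_app() below, which passes it through to
+        # StreamableHTTPSessionManager.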
self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies self._session_manager: StreamableHTTPSessionManager | None = None @@ -940,6 +942,7 @@ def streamable_http_app(self) -> Starlette: self._session_manager = StreamableHTTPSessionManager( app=self._mcp_server, event_store=self._event_store, + retry_interval=self._retry_interval, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting security_settings=self.settings.transport_security, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 6842711103..9065f5ce67 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -110,7 +110,9 @@ async def replay_events_after( # pragma: no cover # Replay only events from the same stream with ID > last_event_id for stream_id, event_id, message in self._events: if stream_id == target_stream_id and int(event_id) > last_event_id_int: - await send_callback(EventMessage(message, event_id)) + # Skip priming events (None message) + if message is not None: + await send_callback(EventMessage(message, event_id)) return target_stream_id From f54b18a5f91b5cbab501668bd145a47245fe842e Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 27 Nov 2025 17:38:05 +0000 Subject: [PATCH 09/10] update uv.lock & add pragmas --- src/mcp/client/streamable_http.py | 12 ++--- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/server/streamable_http.py | 6 +-- tests/shared/test_streamable_http.py | 30 ++++++------ uv.lock | 68 ++++++++++++++++++++++++++++ 5 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 091a81b593..4b4b4ee0d5 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -358,11 +358,11 @@ async def _handle_sse_response( if is_complete: await response.aclose() return # Normal completion, no reconnect needed - except Exception as e: + except Exception as e: # pragma: no cover logger.debug(f"SSE stream ended: {e}") # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: + if last_event_id is not None: # pragma: no branch logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) @@ -375,7 +375,7 @@ async def _handle_reconnection( ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: + if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") return @@ -388,7 +388,7 @@ async def _handle_reconnection( # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): + if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.root.id try: @@ -407,7 +407,7 @@ async def _handle_reconnection( reconnect_retry_ms = retry_interval_ms async for sse in event_source.aiter_sse(): - if sse.id: + if sse.id: # pragma: no branch reconnect_last_event_id = sse.id if sse.retry is not None: reconnect_retry_ms = sse.retry @@ -425,7 +425,7 @@ async def _handle_reconnection( # Stream ended again without response - reconnect again (reset attempt counter) logger.info("SSE stream disconnected, 
reconnecting...") await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) - except Exception as e: + except Exception as e: # pragma: no cover logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c3902df0e7..3b2ad35424 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1299,7 +1299,7 @@ async def close_sse_stream(self) -> None: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. """ - if self._request_context and self._request_context.close_sse_stream: + if self._request_context and self._request_context.close_sse_stream: # pragma: no cover await self._request_context.close_sse_stream() # Convenience methods for common log levels diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 834e613fac..16fbb4156d 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -186,7 +186,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: + def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -214,7 +214,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: send_stream.close() receive_stream.close() - def _create_session_message( + def _create_session_message( # pragma: no cover self, message: JSONRPCMessage, request: Request, @@ -231,7 +231,7 @@ async def close_stream_callback() -> None: ) return SessionMessage(message, metadata=metadata) - async def _send_priming_event( + async def _send_priming_event( # pragma: no cover self, request_id: RequestId, sse_stream_writer: MemoryObjectSendStream[dict[str, Any]], diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9065f5ce67..03208f6380 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -80,7 +80,9 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: + async def store_event( # pragma: no cover + self, stream_id: StreamId, message: types.JSONRPCMessage | None + ) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) @@ -1817,10 +1819,10 @@ async def test_streamablehttp_client_auto_reconnects( async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): - return - if isinstance(message, types.ServerNotification): - if isinstance(message.root, types.LoggingMessageNotification): + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.root.params.data)) 
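+    # Notifications reach the message handler no matter which HTTP connection
+    # delivered them, so captured_notifications accumulates entries across the
+    # reconnects triggered by close_sse_stream().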
async with streamablehttp_client(f"{server_url}/mcp") as ( @@ -1893,10 +1895,10 @@ async def test_streamablehttp_sse_polling_full_cycle( async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): - return - if isinstance(message, types.ServerNotification): - if isinstance(message.root, types.LoggingMessageNotification): + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.root.params.data)) async with streamablehttp_client(f"{server_url}/mcp") as ( @@ -1941,10 +1943,10 @@ async def test_streamablehttp_events_replayed_after_disconnect( async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): - return - if isinstance(message, types.ServerNotification): - if isinstance(message.root, types.LoggingMessageNotification): + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.root.params.data)) async with streamablehttp_client(f"{server_url}/mcp") as ( @@ -2025,7 +2027,7 @@ async def on_resumption_token(token: str) -> None: assert "Completed 3 checkpoints" in result.content[0].text # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( + assert len(resumption_tokens) == 8, ( # pragma: no cover f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " f"got {len(resumption_tokens)}: {resumption_tokens}" ) diff --git a/uv.lock b/uv.lock index d1363aef41..c8e0fce1da 100644 --- a/uv.lock +++ b/uv.lock @@ -17,6 +17,8 @@ members = [ "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", "mcp-snippets", + "mcp-sse-polling-client", + "mcp-sse-polling-demo", "mcp-structured-output-lowlevel", ] @@ -1240,6 +1242,72 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "mcp", editable = "." }] +[[package]] +name = "mcp-sse-polling-client" +version = "0.1.0" +source = { editable = "examples/clients/sse-polling-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.2.0" }, + { name = "mcp", editable = "." 
}, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-sse-polling-demo" +version = "0.1.0" +source = { editable = "examples/servers/sse-polling-demo" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.2.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-structured-output-lowlevel" version = "0.1.0" From 465c541c1743024bb574145dc6984319fe986867 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Mon, 1 Dec 2025 15:03:57 +0000 Subject: [PATCH 10/10] Add GET stream auto-reconnection and close_standalone_sse_stream API Client changes: - handle_get_stream() now auto-reconnects after server closes connection - Respects retry_interval from server for reconnection timing - Sends Last-Event-ID header on reconnection if available Server changes: - Added close_standalone_sse_stream() method to StreamableHTTPServerTransport - Exposed via RequestContext and FastMCP Context for tools to trigger GET stream closure and client reconnection Test changes: - Added test_standalone_get_stream_reconnection to verify reconnection works - Updated tool count assertions (9 -> 10 for new test tool) - Removed SEP-1699 references from source code comments --- src/mcp/client/streamable_http.py | 63 +++++++++++++------ src/mcp/server/fastmcp/server.py | 15 +++++ src/mcp/server/lowlevel/server.py | 3 + src/mcp/server/streamable_http.py | 23 +++++++ src/mcp/shared/context.py | 1 + src/mcp/shared/message.py | 4 +- tests/shared/test_streamable_http.py | 91 ++++++++++++++++++++++++++-- 7 files changed, 176 insertions(+), 24 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 4b4b4ee0d5..fa0524e6ef 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -206,28 +206,55 @@ async def handle_get_stream( client: httpx.AsyncClient, read_stream_writer: StreamWriter, ) -> None: - """Handle GET stream for server-initiated messages.""" - try: - if not self.session_id: - return + """Handle GET stream for server-initiated messages with auto-reconnect.""" + last_event_id: str | None = None + retry_interval_ms: int | None = None + attempt: int = 0 - headers = self._prepare_request_headers(self.request_headers) + while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch + try: + if not self.session_id: + return - async with aconnect_sse( - client, - "GET", - self.url, - headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("GET SSE connection established") + headers = self._prepare_request_headers(self.request_headers) + if last_event_id: + headers[LAST_EVENT_ID] = last_event_id # pragma: no cover - async for sse in event_source.aiter_sse(): - await 
self._handle_sse_event(sse, read_stream_writer)
+
+                async with aconnect_sse(
+                    client,
+                    "GET",
+                    self.url,
+                    headers=headers,
+                    timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
+                ) as event_source:
+                    event_source.response.raise_for_status()
+                    logger.debug("GET SSE connection established")
+
+                    async for sse in event_source.aiter_sse():
+                        # Track last event ID for reconnection
+                        if sse.id:
+                            last_event_id = sse.id  # pragma: no cover
+                        # Track retry interval from server
+                        if sse.retry is not None:
+                            retry_interval_ms = sse.retry  # pragma: no cover
+
+                        await self._handle_sse_event(sse, read_stream_writer)
+
+                # Stream ended normally (server closed) - reset attempt counter
+                attempt = 0
 
-        except Exception as exc:
-            logger.debug(f"GET stream error (non-fatal): {exc}")  # pragma: no cover
+            except Exception as exc:  # pragma: no cover
+                logger.debug(f"GET stream error: {exc}")
+                attempt += 1
+
+                if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
+                    logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
+                    return
+
+            # Wait before reconnecting
+            delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
+            logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
+            await anyio.sleep(delay_ms / 1000.0)
 
     async def _handle_resumption_request(self, ctx: RequestContext) -> None:
         """Handle a resumption request using GET with SSE."""
diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py
index 3b2ad35424..6fb0ec45ce 100644
--- a/src/mcp/server/fastmcp/server.py
+++ b/src/mcp/server/fastmcp/server.py
@@ -1302,6 +1302,21 @@ async def close_sse_stream(self) -> None:
         if self._request_context and self._request_context.close_sse_stream:  # pragma: no cover
             await self._request_context.close_sse_stream()
 
+    async def close_standalone_sse_stream(self) -> None:
+        """Close the standalone GET SSE stream to trigger client reconnection.
+
+        This method closes the HTTP connection for the standalone GET stream used
+        for unsolicited server-to-client notifications. The client SHOULD reconnect
+        with Last-Event-ID to resume receiving notifications.
+
+        Note:
+            This is a no-op if not using StreamableHTTP transport with event_store.
+            The client's GET stream handler reconnects automatically with
+            Last-Event-ID after the server closes the connection.
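+
+        Example (sketch, inside a tool running over StreamableHTTP with an
+        event store configured):
+            await ctx.close_standalone_sse_stream()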
+        """
+        if self._request_context and self._request_context.close_standalone_sse_stream:  # pragma: no cover
+            await self._request_context.close_standalone_sse_stream()
+
     # Convenience methods for common log levels
     async def debug(self, message: str, **extra: Any) -> None:
         """Send a debug log message."""
diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py
index 8a64b463cf..c3ff2fa165 100644
--- a/src/mcp/server/lowlevel/server.py
+++ b/src/mcp/server/lowlevel/server.py
@@ -683,11 +683,13 @@ async def _handle_request(
         # Extract request context and close_sse_stream from message metadata
         request_data = None
         close_sse_stream_cb = None
+        close_standalone_sse_stream_cb = None
         if message.message_metadata is not None and isinstance(
             message.message_metadata, ServerMessageMetadata
         ):  # pragma: no cover
             request_data = message.message_metadata.request_context
             close_sse_stream_cb = message.message_metadata.close_sse_stream
+            close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
 
         # Set our global state that can be retrieved via
         # app.get_request_context()
@@ -699,6 +701,7 @@
                 lifespan_context,
                 request=request_data,
                 close_sse_stream=close_sse_stream_cb,
+                close_standalone_sse_stream=close_standalone_sse_stream_cb,
             )
         )
         response = await handler(req)
diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py
index 16fbb4156d..22167377ba 100644
--- a/src/mcp/server/streamable_http.py
+++ b/src/mcp/server/streamable_http.py
@@ -214,6 +214,25 @@ def close_sse_stream(self, request_id: RequestId) -> None:  # pragma: no cover
             send_stream.close()
             receive_stream.close()
 
+    def close_standalone_sse_stream(self) -> None:  # pragma: no cover
+        """Close the standalone GET SSE stream, triggering client reconnection.
+
+        This method closes the HTTP connection for the standalone GET stream used
+        for unsolicited server-to-client notifications. The client SHOULD reconnect
+        with Last-Event-ID to resume receiving notifications.
+
+        Use this to implement polling behavior for the notification stream -
+        client will reconnect after the retry interval specified in the priming event.
+
+        Note:
+            This is a no-op if there is no active standalone SSE stream.
+            Requires event_store to be configured for events to be stored during
+            the disconnect.
+            Client auto-reconnection for this stream is covered by
+            test_standalone_get_stream_reconnection.
+ """ + self.close_sse_stream(GET_STREAM_KEY) + def _create_session_message( # pragma: no cover self, message: JSONRPCMessage, @@ -225,9 +244,13 @@ def _create_session_message( # pragma: no cover async def close_stream_callback() -> None: self.close_sse_stream(request_id) + async def close_standalone_stream_callback() -> None: + self.close_standalone_sse_stream() + metadata = ServerMessageMetadata( request_context=request, close_sse_stream=close_stream_callback, + close_standalone_sse_stream=close_standalone_stream_callback, ) return SessionMessage(message, metadata=metadata) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index aaeadd70b6..c3b1fd82d0 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -20,3 +20,4 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): lifespan_context: LifespanContextT request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None + close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 7104e52a5d..81503eaaa7 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -33,8 +33,10 @@ class ServerMessageMetadata: related_request_id: RequestId | None = None # Request-specific context (e.g., headers, auth info) request_context: object | None = None - # Callback to close SSE stream without terminating + # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None + # Callback to close the standalone GET SSE stream (for unsolicited notifications) + close_standalone_sse_stream: CloseSSEStreamCallback | None = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 03208f6380..b4bbfd61d6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -190,6 +190,11 @@ async def handle_list_tools() -> list[Tool]: }, }, ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -348,6 +353,26 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + elif name == "tool_with_standalone_stream_close": + # Test for GET stream reconnection + # 1. Send unsolicited notification via GET stream (no related_request_id) + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1")) + + # Small delay to ensure notification is flushed before closing + await anyio.sleep(0.1) + + # 2. Close the standalone GET stream + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() + + # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) + await anyio.sleep(1.5) + + # 4. 
Send another notification on the new GET stream connection + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2")) + + return [TextContent(type="text", text="Standalone stream close test done")] + return [TextContent(type="text", text=f"Called {name}")] @@ -984,7 +1009,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session: """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 10 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1021,7 +1046,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 10 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -1050,7 +1075,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 10 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -1120,7 +1145,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 10 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -1186,7 +1211,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 9 + assert len(tools.tools) == 10 headers: dict[str, str] = {} # pragma: no cover if captured_session_id: # pragma: no cover @@ -2031,3 +2056,59 @@ async def on_resumption_token(token: str) -> None: f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " f"got {len(resumption_tokens)}: {resumption_tokens}" ) + + +@pytest.mark.anyio +async def test_standalone_get_stream_reconnection(basic_server: None, basic_server_url: str) -> None: + """ + Test that standalone GET stream automatically reconnects after server closes it. + + Verifies: + 1. Client receives notification 1 via GET stream + 2. Server closes GET stream + 3. Client reconnects with Last-Event-ID + 4. Client receives notification 2 on new connection + """ + server_url = basic_server_url + received_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch + received_notifications.append(str(message.root.params.uri)) + + async with streamablehttp_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. 
Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1/" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2/" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + )
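
Taken together, the patches above give servers close_sse_stream() and close_standalone_sse_stream(), and give clients transparent reconnection. A minimal polling client might look like the following sketch; it assumes the sse-polling-demo server from this series is running at http://localhost:3000/mcp, and it relies on streamablehttp_client to handle the reconnect-with-Last-Event-ID cycle internally, so the caller simply observes an uninterrupted stream of notifications.

"""Sketch: poll the sse-polling-demo server's process_batch tool."""

import anyio

import mcp.types as types
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.session import RequestResponder


async def main() -> None:
    notifications: list[str] = []

    async def message_handler(
        message: RequestResponder[types.ServerRequest, types.ClientResult]
        | types.ServerNotification
        | Exception,
    ) -> None:
        # Progress logs keep arriving across server-initiated disconnects.
        if isinstance(message, types.ServerNotification) and isinstance(
            message.root, types.LoggingMessageNotification
        ):
            notifications.append(str(message.root.params.data))

    async with streamablehttp_client("http://localhost:3000/mcp") as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
            await session.initialize()
            # Each checkpoint closes the SSE stream server-side; the client
            # transparently reconnects and resumes via Last-Event-ID.
            result = await session.call_tool("process_batch", {"items": 6, "checkpoint_every": 2})
            print(result.content)
            print(f"Notifications observed: {len(notifications)}")


if __name__ == "__main__":
    anyio.run(main)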