diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 41145e49f6..3e64fae415 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -80,14 +80,39 @@ class SseServerTransport: def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: """ Creates a new SSE server transport, which will direct the client to POST - messages to the relative or absolute URL given. + messages to the relative path given. Args: - endpoint: The relative or absolute URL for POST messages. + endpoint: A relative path where messages should be posted + (e.g., "/messages/"). security_settings: Optional security settings for DNS rebinding protection. + + Note: + We use relative paths instead of full URLs for several reasons: + 1. Security: Prevents cross-origin requests by ensuring clients only connect + to the same origin they established the SSE connection with + 2. Flexibility: The server can be mounted at any path without needing to + know its full URL + 3. Portability: The same endpoint configuration works across different + environments (development, staging, production) + + Raises: + ValueError: If the endpoint is a full URL instead of a relative path """ super().__init__() + + # Validate that endpoint is a relative path and not a full URL + if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: + raise ValueError( + f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), \ + expecting a relative path(e.g., '/messages/')." + ) + + # Ensure endpoint starts with a forward slash + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + self._endpoint = endpoint self._read_stream_writers = {} self._security = TransportSecurityMiddleware(security_settings) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 8e1912e9bd..41821e6806 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -473,3 +473,31 @@ def test_sse_message_id_coercion(): json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' msg = types.JSONRPCMessage.model_validate_json(json_message) assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))) + + +@pytest.mark.parametrize( + "endpoint, expected_result", + [ + # Valid endpoints - should normalize and work + ("/messages/", "/messages/"), + ("messages/", "/messages/"), + ("/", "/"), + # Invalid endpoints - should raise ValueError + ("http://example.com/messages/", ValueError), + ("//example.com/messages/", ValueError), + ("ftp://example.com/messages/", ValueError), + ("/messages/?param=value", ValueError), + ("/messages/#fragment", ValueError), + ], +) +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): + """Test that SseServerTransport properly validates and normalizes endpoints.""" + if isinstance(expected_result, type) and issubclass(expected_result, Exception): + # Test invalid endpoints that should raise an exception + with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"): + SseServerTransport(endpoint) + else: + # Test valid endpoints that should normalize correctly + sse = SseServerTransport(endpoint) + assert sse._endpoint == expected_result + assert sse._endpoint.startswith("/")