diff --git a/pyproject.toml b/pyproject.toml index 9058cb6..354c062 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,10 @@ classifiers = [ ] version = "0.2.2" requires-python = ">=3.11" -dependencies = ["mcp"] +dependencies = [ + "mcp", + "uvicorn>=0.34.0", +] [build-system] requires = ["setuptools"] diff --git a/src/mcp_proxy/sse_server.py b/src/mcp_proxy/sse_server.py new file mode 100644 index 0000000..7be2350 --- /dev/null +++ b/src/mcp_proxy/sse_server.py @@ -0,0 +1,77 @@ +"""Create a local SSE server that proxies requests to a stdio MCP server.""" + +from dataclasses import dataclass +from typing import Literal + +import uvicorn +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.routing import Mount, Route + +from .proxy_server import create_proxy_server + + +@dataclass +class SseServerSettings: + """Settings for the server.""" + + bind_host: str = "127.0.0.1" + port: int = 8000 + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + + +def create_starlette_app(mcp_server: Server, debug: bool | None = None) -> Starlette: + """Create a Starlette application that can server the provied mcp server with SSE.""" + sse = SseServerTransport("/messages/") + + async def handle_sse(request: Request) -> None: + async with sse.connect_sse( + request.scope, + request.receive, + request._send, # noqa: SLF001 + ) as (read_stream, write_stream): + await mcp_server.run( + read_stream, + write_stream, + mcp_server.create_initialization_options(), + ) + + return Starlette( + debug=debug, + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + +async def run_sse_server( + stdio_params: StdioServerParameters, + sse_settings: SseServerSettings, +) -> None: + """Run the stdio client and expose an SSE server. + + Args: + stdio_params: The parameters for the stdio client that spawns a stdio server. + sse_settings: The settings for the SSE server that accepts incoming requests. + + """ + async with stdio_client(stdio_params) as streams, ClientSession(*streams) as session: + mcp_server = await create_proxy_server(session) + + # Bind SSE request handling to MCP server + starlette_app = await create_starlette_app(mcp_server, sse_settings.log_level == "DEBUG") + + # Configure HTTP server + config = uvicorn.Config( + starlette_app, + host=sse_settings.bind_host, + port=sse_settings.port, + log_level=sse_settings.log_level.lower(), + ) + http_server = uvicorn.Server(config) + await http_server.serve() diff --git a/tests/test_sse_server.py b/tests/test_sse_server.py new file mode 100644 index 0000000..291ef8f --- /dev/null +++ b/tests/test_sse_server.py @@ -0,0 +1,60 @@ +"""Tests for the sse server.""" + +import asyncio +import contextlib + +import uvicorn +from mcp import types +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.server import Server + +from mcp_proxy.sse_server import create_starlette_app + + +class BackgroundServer(uvicorn.Server): + """A test server that runs in a background thread.""" + + def install_signal_handlers(self) -> None: + """Do not install signal handlers.""" + + @contextlib.asynccontextmanager + async def run_in_background(self) -> None: + """Run the server in a background thread.""" + task = asyncio.create_task(self.serve()) + try: + while not self.started: # noqa: ASYNC110 + await asyncio.sleep(1e-3) + yield + finally: + task.cancel() + self.shutdown() + + @property + def url(self) -> str: + """Return the url of the started server.""" + hostport = next( + iter([socket.getsockname() for server in self.servers for socket in server.sockets]), + ) + return f"http://{hostport[0]}:{hostport[1]}" + + +async def test_create_starlette_app() -> None: + """Test basic glue code for the SSE transport and a fake MCP server.""" + server = Server("prompt-server") + + @server.list_prompts() + async def list_prompts() -> list[types.Prompt]: + return [types.Prompt(name="prompt1")] + + app = create_starlette_app(server) + + config = uvicorn.Config(app, port=0, log_level="info") + server = BackgroundServer(config) + async with server.run_in_background(): + mcp_url = f"{server.url}/sse" + async with sse_client(url=mcp_url) as streams, ClientSession(*streams) as session: + await session.initialize() + response = await session.list_prompts() + assert len(response.prompts) == 1 + assert response.prompts[0].name == "prompt1" diff --git a/uv.lock b/uv.lock index 91afb04..a6ea7cf 100644 --- a/uv.lock +++ b/uv.lock @@ -33,6 +33,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, ] +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -177,6 +189,7 @@ version = "0.2.2" source = { editable = "." } dependencies = [ { name = "mcp" }, + { name = "uvicorn" }, ] [package.dev-dependencies] @@ -187,7 +200,10 @@ dev = [ ] [package.metadata] -requires-dist = [{ name = "mcp" }] +requires-dist = [ + { name = "mcp" }, + { name = "uvicorn", specifier = ">=0.34.0" }, +] [package.metadata.requires-dev] dev = [ @@ -350,3 +366,16 @@ sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec3 wheels = [ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, ] + +[[package]] +name = "uvicorn" +version = "0.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, +]