Skip to content

Commit 8c6fe2d

Browse files
authored
add support for SSE client (#9)
1 parent cb93152 commit 8c6fe2d

File tree

2 files changed

+201
-27
lines changed

2 files changed

+201
-27
lines changed

README.md

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def multiply(a: int, b: int) -> int:
4646
return a * b
4747

4848
if __name__ == "__main__":
49-
mcp.run()
49+
mcp.run(transport="stdio")
5050
```
5151

5252
### Client
@@ -103,7 +103,11 @@ async def get_weather(location: str) -> int:
103103
return "It's always sunny in New York"
104104

105105
if __name__ == "__main__":
106-
mcp.run()
106+
mcp.run(transport="sse")
107+
```
108+
109+
```bash
110+
python weather_server.py
107111
```
108112

109113
### Client
@@ -115,19 +119,21 @@ from langgraph.prebuilt import create_react_agent
115119
from langchain_openai import ChatOpenAI
116120
model = ChatOpenAI(model="gpt-4o")
117121

118-
async with MultiServerMCPClient() as client:
119-
await client.connect_to_server(
120-
"math",
121-
command="python",
122-
# Make sure to update to the full absolute path to your math_server.py file
123-
args=["/path/to/math_server.py"],
124-
)
125-
await client.connect_to_server(
126-
"weather",
127-
command="python",
128-
# Make sure to update to the full absolute path to your weather_server.py file
129-
args=["/path/to/weather_server.py"],
130-
)
122+
async with MultiServerMCPClient(
123+
{
124+
"math": {
125+
"command": "python",
126+
# Make sure to update to the full absolute path to your math_server.py file
127+
"args": ["/path/to/math_server.py"],
128+
"transport": "stdio",
129+
},
130+
"weather": {
131+
# make sure you start your weather server on port 8000
132+
"url": "http://localhost:8000/sse",
133+
"transport": "sse",
134+
}
135+
}
136+
) as client:
131137
agent = create_react_agent(model, client.get_tools())
132138
math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
133139
weather_response = await agent.ainvoke({"messages": "what is the weather in nyc?"})

langchain_mcp_adapters/client.py

Lines changed: 180 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,169 @@
11
from contextlib import AsyncExitStack
22
from types import TracebackType
3-
from typing import Literal, cast
3+
from typing import Literal, TypedDict, cast
44

55
from langchain_core.tools import BaseTool
66
from mcp import ClientSession, StdioServerParameters
7+
from mcp.client.sse import sse_client
78
from mcp.client.stdio import stdio_client
89

910
from langchain_mcp_adapters.tools import load_mcp_tools
1011

12+
DEFAULT_ENCODING = "utf-8"
13+
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
14+
15+
16+
class StdioConnection(TypedDict):
17+
transport: Literal["stdio"]
18+
19+
command: str
20+
"""The executable to run to start the server."""
21+
22+
args: list[str]
23+
"""Command line arguments to pass to the executable."""
24+
25+
env: dict[str, str] | None
26+
"""The environment to use when spawning the process."""
27+
28+
encoding: str
29+
"""The text encoding used when sending/receiving messages to the server."""
30+
31+
encoding_error_handler: Literal["strict", "ignore", "replace"]
32+
"""
33+
The text encoding error handler.
34+
35+
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
36+
explanations of possible values
37+
"""
38+
39+
40+
class SSEConnection(TypedDict):
41+
transport: Literal["sse"]
42+
43+
url: str
44+
"""The URL of the SSE endpoint to connect to."""
45+
1146

1247
class MultiServerMCPClient:
1348
"""Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""
1449

15-
def __init__(self) -> None:
50+
def __init__(self, connections: dict[str, StdioConnection | SSEConnection] = None) -> None:
51+
"""Initialize a MultiServerMCPClient with MCP servers connections.
52+
53+
Args:
54+
connections: A dictionary mapping server names to connection configurations.
55+
Each configuration can be either a StdioConnection or SSEConnection.
56+
If None, no initial connections are established.
57+
58+
Example:
59+
60+
```python
61+
async with MultiServerMCPClient(
62+
{
63+
"math": {
64+
"command": "python",
65+
# Make sure to update to the full absolute path to your math_server.py file
66+
"args": ["/path/to/math_server.py"],
67+
"transport": "stdio",
68+
},
69+
"weather": {
70+
# make sure you start your weather server on port 8000
71+
"url": "http://localhost:8000/sse",
72+
"transport": "sse",
73+
}
74+
}
75+
) as client:
76+
all_tools = client.get_tools()
77+
...
78+
```
79+
"""
80+
self.connections = connections
1681
self.exit_stack = AsyncExitStack()
1782
self.sessions: dict[str, ClientSession] = {}
1883
self.server_name_to_tools: dict[str, list[BaseTool]] = {}
1984

85+
async def _initialize_session_and_load_tools(
86+
self, server_name: str, session: ClientSession
87+
) -> None:
88+
"""Initialize a session and load tools from it.
89+
90+
Args:
91+
server_name: Name to identify this server connection
92+
session: The ClientSession to initialize
93+
"""
94+
# Initialize the session
95+
await session.initialize()
96+
self.sessions[server_name] = session
97+
98+
# Load tools from this server
99+
server_tools = await load_mcp_tools(session)
100+
self.server_name_to_tools[server_name] = server_tools
101+
20102
async def connect_to_server(
103+
self,
104+
server_name: str,
105+
*,
106+
transport: Literal["stdio", "sse"] = "stdio",
107+
**kwargs,
108+
) -> None:
109+
"""Connect to an MCP server using either stdio or SSE.
110+
111+
This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
112+
based on the provided transport parameter.
113+
114+
Args:
115+
server_name: Name to identify this server connection
116+
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
117+
**kwargs: Additional arguments to pass to the specific connection method
118+
119+
Raises:
120+
ValueError: If transport is not recognized
121+
ValueError: If required parameters for the specified transport are missing
122+
"""
123+
if transport == "sse":
124+
if "url" not in kwargs:
125+
raise ValueError("'url' parameter is required for SSE connection")
126+
await self.connect_to_server_via_sse(server_name, url=kwargs["url"])
127+
elif transport == "stdio":
128+
if "command" not in kwargs:
129+
raise ValueError("'command' parameter is required for stdio connection")
130+
if "args" not in kwargs:
131+
raise ValueError("'args' parameter is required for stdio connection")
132+
await self.connect_to_server_via_stdio(
133+
server_name,
134+
command=kwargs["command"],
135+
args=kwargs["args"],
136+
env=kwargs.get("env"),
137+
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
138+
encoding_error_handler=kwargs.get(
139+
"encoding_error_handler", DEFAULT_ENCODING_ERROR_HANDLER
140+
),
141+
)
142+
else:
143+
raise ValueError(f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'")
144+
145+
async def connect_to_server_via_stdio(
21146
self,
22147
server_name: str,
23148
*,
24149
command: str,
25150
args: list[str],
26151
env: dict[str, str] | None = None,
27-
encoding: str = "utf-8",
28-
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict",
152+
encoding: str = DEFAULT_ENCODING,
153+
encoding_error_handler: Literal[
154+
"strict", "ignore", "replace"
155+
] = DEFAULT_ENCODING_ERROR_HANDLER,
29156
) -> None:
30-
"""Connect to a specific MCP server"""
157+
"""Connect to a specific MCP server using stdio
158+
159+
Args:
160+
server_name: Name to identify this server connection
161+
command: Command to execute
162+
args: Arguments for the command
163+
env: Environment variables for the command
164+
encoding: Character encoding
165+
encoding_error_handler: How to handle encoding errors
166+
"""
31167
server_params = StdioServerParameters(
32168
command=command,
33169
args=args,
@@ -44,13 +180,29 @@ async def connect_to_server(
44180
await self.exit_stack.enter_async_context(ClientSession(read, write)),
45181
)
46182

47-
# Initialize the session
48-
await session.initialize()
49-
self.sessions[server_name] = session
183+
await self._initialize_session_and_load_tools(server_name, session)
50184

51-
# Load tools from this server
52-
server_tools = await load_mcp_tools(session)
53-
self.server_name_to_tools[server_name] = server_tools
185+
async def connect_to_server_via_sse(
186+
self,
187+
server_name: str,
188+
*,
189+
url: str,
190+
) -> None:
191+
"""Connect to a specific MCP server using SSE
192+
193+
Args:
194+
server_name: Name to identify this server connection
195+
url: URL of the SSE server
196+
"""
197+
# Create and store the connection
198+
sse_transport = await self.exit_stack.enter_async_context(sse_client(url))
199+
read, write = sse_transport
200+
session = cast(
201+
ClientSession,
202+
await self.exit_stack.enter_async_context(ClientSession(read, write)),
203+
)
204+
205+
await self._initialize_session_and_load_tools(server_name, session)
54206

55207
def get_tools(self) -> list[BaseTool]:
56208
"""Get a list of all tools from all connected servers."""
@@ -60,7 +212,23 @@ def get_tools(self) -> list[BaseTool]:
60212
return all_tools
61213

62214
async def __aenter__(self) -> "MultiServerMCPClient":
63-
return self
215+
try:
216+
connections = self.connections or {}
217+
for server_name, connection in connections.items():
218+
connection_dict = connection.copy()
219+
transport = connection_dict.pop("transport")
220+
if transport == "stdio":
221+
await self.connect_to_server_via_stdio(server_name, **connection_dict)
222+
elif transport == "sse":
223+
await self.connect_to_server_via_sse(server_name, **connection_dict)
224+
else:
225+
raise ValueError(
226+
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
227+
)
228+
return self
229+
except Exception:
230+
await self.exit_stack.aclose()
231+
raise
64232

65233
async def __aexit__(
66234
self,

0 commit comments

Comments
 (0)