11from contextlib import AsyncExitStack
22from types import TracebackType
3- from typing import Literal , cast
3+ from typing import Literal , TypedDict , cast
44
55from langchain_core .tools import BaseTool
66from mcp import ClientSession , StdioServerParameters
7+ from mcp .client .sse import sse_client
78from mcp .client .stdio import stdio_client
89
910from langchain_mcp_adapters .tools import load_mcp_tools
1011
12+ DEFAULT_ENCODING = "utf-8"
13+ DEFAULT_ENCODING_ERROR_HANDLER = "strict"
14+
15+
16+ class StdioConnection (TypedDict ):
17+ transport : Literal ["stdio" ]
18+
19+ command : str
20+ """The executable to run to start the server."""
21+
22+ args : list [str ]
23+ """Command line arguments to pass to the executable."""
24+
25+ env : dict [str , str ] | None
26+ """The environment to use when spawning the process."""
27+
28+ encoding : str
29+ """The text encoding used when sending/receiving messages to the server."""
30+
31+ encoding_error_handler : Literal ["strict" , "ignore" , "replace" ]
32+ """
33+ The text encoding error handler.
34+
35+ See https://docs.python.org/3/library/codecs.html#codec-base-classes for
36+ explanations of possible values
37+ """
38+
39+
40+ class SSEConnection (TypedDict ):
41+ transport : Literal ["sse" ]
42+
43+ url : str
44+ """The URL of the SSE endpoint to connect to."""
45+
1146
1247class MultiServerMCPClient :
1348 """Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""
1449
15- def __init__ (self ) -> None :
50+ def __init__ (self , connections : dict [str , StdioConnection | SSEConnection ] = None ) -> None :
51+ """Initialize a MultiServerMCPClient with MCP servers connections.
52+
53+ Args:
54+ connections: A dictionary mapping server names to connection configurations.
55+ Each configuration can be either a StdioConnection or SSEConnection.
56+ If None, no initial connections are established.
57+
58+ Example:
59+
60+ ```python
61+ async with MultiServerMCPClient(
62+ {
63+ "math": {
64+ "command": "python",
65+ # Make sure to update to the full absolute path to your math_server.py file
66+ "args": ["/path/to/math_server.py"],
67+ "transport": "stdio",
68+ },
69+ "weather": {
70+ # make sure you start your weather server on port 8000
71+ "url": "http://localhost:8000/sse",
72+ "transport": "sse",
73+ }
74+ }
75+ ) as client:
76+ all_tools = client.get_tools()
77+ ...
78+ ```
79+ """
80+ self .connections = connections
1681 self .exit_stack = AsyncExitStack ()
1782 self .sessions : dict [str , ClientSession ] = {}
1883 self .server_name_to_tools : dict [str , list [BaseTool ]] = {}
1984
85+ async def _initialize_session_and_load_tools (
86+ self , server_name : str , session : ClientSession
87+ ) -> None :
88+ """Initialize a session and load tools from it.
89+
90+ Args:
91+ server_name: Name to identify this server connection
92+ session: The ClientSession to initialize
93+ """
94+ # Initialize the session
95+ await session .initialize ()
96+ self .sessions [server_name ] = session
97+
98+ # Load tools from this server
99+ server_tools = await load_mcp_tools (session )
100+ self .server_name_to_tools [server_name ] = server_tools
101+
20102 async def connect_to_server (
103+ self ,
104+ server_name : str ,
105+ * ,
106+ transport : Literal ["stdio" , "sse" ] = "stdio" ,
107+ ** kwargs ,
108+ ) -> None :
109+ """Connect to an MCP server using either stdio or SSE.
110+
111+ This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
112+ based on the provided transport parameter.
113+
114+ Args:
115+ server_name: Name to identify this server connection
116+ transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
117+ **kwargs: Additional arguments to pass to the specific connection method
118+
119+ Raises:
120+ ValueError: If transport is not recognized
121+ ValueError: If required parameters for the specified transport are missing
122+ """
123+ if transport == "sse" :
124+ if "url" not in kwargs :
125+ raise ValueError ("'url' parameter is required for SSE connection" )
126+ await self .connect_to_server_via_sse (server_name , url = kwargs ["url" ])
127+ elif transport == "stdio" :
128+ if "command" not in kwargs :
129+ raise ValueError ("'command' parameter is required for stdio connection" )
130+ if "args" not in kwargs :
131+ raise ValueError ("'args' parameter is required for stdio connection" )
132+ await self .connect_to_server_via_stdio (
133+ server_name ,
134+ command = kwargs ["command" ],
135+ args = kwargs ["args" ],
136+ env = kwargs .get ("env" ),
137+ encoding = kwargs .get ("encoding" , DEFAULT_ENCODING ),
138+ encoding_error_handler = kwargs .get (
139+ "encoding_error_handler" , DEFAULT_ENCODING_ERROR_HANDLER
140+ ),
141+ )
142+ else :
143+ raise ValueError (f"Unsupported transport: { transport } . Must be 'stdio' or 'sse'" )
144+
145+ async def connect_to_server_via_stdio (
21146 self ,
22147 server_name : str ,
23148 * ,
24149 command : str ,
25150 args : list [str ],
26151 env : dict [str , str ] | None = None ,
27- encoding : str = "utf-8" ,
28- encoding_error_handler : Literal ["strict" , "ignore" , "replace" ] = "strict" ,
152+ encoding : str = DEFAULT_ENCODING ,
153+ encoding_error_handler : Literal [
154+ "strict" , "ignore" , "replace"
155+ ] = DEFAULT_ENCODING_ERROR_HANDLER ,
29156 ) -> None :
30- """Connect to a specific MCP server"""
157+ """Connect to a specific MCP server using stdio
158+
159+ Args:
160+ server_name: Name to identify this server connection
161+ command: Command to execute
162+ args: Arguments for the command
163+ env: Environment variables for the command
164+ encoding: Character encoding
165+ encoding_error_handler: How to handle encoding errors
166+ """
31167 server_params = StdioServerParameters (
32168 command = command ,
33169 args = args ,
@@ -44,13 +180,29 @@ async def connect_to_server(
44180 await self .exit_stack .enter_async_context (ClientSession (read , write )),
45181 )
46182
47- # Initialize the session
48- await session .initialize ()
49- self .sessions [server_name ] = session
183+ await self ._initialize_session_and_load_tools (server_name , session )
50184
51- # Load tools from this server
52- server_tools = await load_mcp_tools (session )
53- self .server_name_to_tools [server_name ] = server_tools
185+ async def connect_to_server_via_sse (
186+ self ,
187+ server_name : str ,
188+ * ,
189+ url : str ,
190+ ) -> None :
191+ """Connect to a specific MCP server using SSE
192+
193+ Args:
194+ server_name: Name to identify this server connection
195+ url: URL of the SSE server
196+ """
197+ # Create and store the connection
198+ sse_transport = await self .exit_stack .enter_async_context (sse_client (url ))
199+ read , write = sse_transport
200+ session = cast (
201+ ClientSession ,
202+ await self .exit_stack .enter_async_context (ClientSession (read , write )),
203+ )
204+
205+ await self ._initialize_session_and_load_tools (server_name , session )
54206
55207 def get_tools (self ) -> list [BaseTool ]:
56208 """Get a list of all tools from all connected servers."""
@@ -60,7 +212,23 @@ def get_tools(self) -> list[BaseTool]:
60212 return all_tools
61213
62214 async def __aenter__ (self ) -> "MultiServerMCPClient" :
63- return self
215+ try :
216+ connections = self .connections or {}
217+ for server_name , connection in connections .items ():
218+ connection_dict = connection .copy ()
219+ transport = connection_dict .pop ("transport" )
220+ if transport == "stdio" :
221+ await self .connect_to_server_via_stdio (server_name , ** connection_dict )
222+ elif transport == "sse" :
223+ await self .connect_to_server_via_sse (server_name , ** connection_dict )
224+ else :
225+ raise ValueError (
226+ f"Unsupported transport: { transport } . Must be 'stdio' or 'sse'"
227+ )
228+ return self
229+ except Exception :
230+ await self .exit_stack .aclose ()
231+ raise
64232
65233 async def __aexit__ (
66234 self ,
0 commit comments