11import os
22from contextlib import AsyncExitStack
3+ from pathlib import Path
34from types import TracebackType
45from typing import Any , Literal , Optional , TypedDict , cast
56
1213from langchain_mcp_adapters .prompts import load_mcp_prompt
1314from langchain_mcp_adapters .tools import load_mcp_tools
1415
16+ EncodingErrorHandler = Literal ["strict" , "ignore" , "replace" ]
17+
1518DEFAULT_ENCODING = "utf-8"
16- DEFAULT_ENCODING_ERROR_HANDLER = "strict"
19+ DEFAULT_ENCODING_ERROR_HANDLER : EncodingErrorHandler = "strict"
1720
1821DEFAULT_HTTP_TIMEOUT = 5
1922DEFAULT_SSE_READ_TIMEOUT = 60 * 5
@@ -31,17 +34,23 @@ class StdioConnection(TypedDict):
3134 env : dict [str , str ] | None
3235 """The environment to use when spawning the process."""
3336
37+ cwd : str | Path | None
38+ """The working directory to use when spawning the process."""
39+
3440 encoding : str
3541 """The text encoding used when sending/receiving messages to the server."""
3642
37- encoding_error_handler : Literal [ "strict" , "ignore" , "replace" ]
43+ encoding_error_handler : EncodingErrorHandler
3844 """
3945 The text encoding error handler.
4046
4147 See https://docs.python.org/3/library/codecs.html#codec-base-classes for
4248 explanations of possible values
4349 """
4450
51+ session_kwargs : dict [str , Any ] | None
52+ """Additional keyword arguments to pass to the ClientSession"""
53+
4554
4655class SSEConnection (TypedDict ):
4756 transport : Literal ["sse" ]
@@ -58,6 +67,9 @@ class SSEConnection(TypedDict):
5867 sse_read_timeout : float
5968 """SSE read timeout"""
6069
70+ session_kwargs : dict [str , Any ] | None
71+ """Additional keyword arguments to pass to the ClientSession"""
72+
6173
6274class MultiServerMCPClient :
6375 """Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""
@@ -146,6 +158,7 @@ async def connect_to_server(
146158 headers = kwargs .get ("headers" ),
147159 timeout = kwargs .get ("timeout" , DEFAULT_HTTP_TIMEOUT ),
148160 sse_read_timeout = kwargs .get ("sse_read_timeout" , DEFAULT_SSE_READ_TIMEOUT ),
161+ session_kwargs = kwargs .get ("session_kwargs" ),
149162 )
150163 elif transport == "stdio" :
151164 if "command" not in kwargs :
@@ -161,6 +174,7 @@ async def connect_to_server(
161174 encoding_error_handler = kwargs .get (
162175 "encoding_error_handler" , DEFAULT_ENCODING_ERROR_HANDLER
163176 ),
177+ session_kwargs = kwargs .get ("session_kwargs" ),
164178 )
165179 else :
166180 raise ValueError (f"Unsupported transport: { transport } . Must be 'stdio' or 'sse'" )
@@ -176,6 +190,7 @@ async def connect_to_server_via_stdio(
176190 encoding_error_handler : Literal [
177191 "strict" , "ignore" , "replace"
178192 ] = DEFAULT_ENCODING_ERROR_HANDLER ,
193+ session_kwargs : dict [str , Any ] | None = None ,
179194 ) -> None :
180195 """Connect to a specific MCP server using stdio
181196
@@ -186,6 +201,7 @@ async def connect_to_server_via_stdio(
186201 env: Environment variables for the command
187202 encoding: Character encoding
188203 encoding_error_handler: How to handle encoding errors
204+ session_kwargs: Additional keyword arguments to pass to the ClientSession
189205 """
190206 # NOTE: execution commands (e.g., `uvx` / `npx`) require PATH envvar to be set.
191207 # To address this, we automatically inject existing PATH envvar into the `env` value,
@@ -205,9 +221,10 @@ async def connect_to_server_via_stdio(
205221 # Create and store the connection
206222 stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
207223 read , write = stdio_transport
224+ session_kwargs = session_kwargs or {}
208225 session = cast (
209226 ClientSession ,
210- await self .exit_stack .enter_async_context (ClientSession (read , write )),
227+ await self .exit_stack .enter_async_context (ClientSession (read , write , ** session_kwargs )),
211228 )
212229
213230 await self ._initialize_session_and_load_tools (server_name , session )
@@ -220,6 +237,7 @@ async def connect_to_server_via_sse(
220237 headers : dict [str , Any ] | None = None ,
221238 timeout : float = DEFAULT_HTTP_TIMEOUT ,
222239 sse_read_timeout : float = DEFAULT_SSE_READ_TIMEOUT ,
240+ session_kwargs : dict [str , Any ] | None = None ,
223241 ) -> None :
224242 """Connect to a specific MCP server using SSE
225243
@@ -229,15 +247,17 @@ async def connect_to_server_via_sse(
229247 headers: HTTP headers to send to the SSE endpoint
230248 timeout: HTTP timeout
231249 sse_read_timeout: SSE read timeout
250+ session_kwargs: Additional keyword arguments to pass to the ClientSession
232251 """
233252 # Create and store the connection
234253 sse_transport = await self .exit_stack .enter_async_context (
235254 sse_client (url , headers , timeout , sse_read_timeout )
236255 )
237256 read , write = sse_transport
257+ session_kwargs = session_kwargs or {}
238258 session = cast (
239259 ClientSession ,
240- await self .exit_stack .enter_async_context (ClientSession (read , write )),
260+ await self .exit_stack .enter_async_context (ClientSession (read , write , ** session_kwargs )),
241261 )
242262
243263 await self ._initialize_session_and_load_tools (server_name , session )
@@ -260,18 +280,8 @@ async def __aenter__(self) -> "MultiServerMCPClient":
260280 try :
261281 connections = self .connections or {}
262282 for server_name , connection in connections .items ():
263- connection_dict = connection .copy ()
264- transport = connection_dict .pop ("transport" )
265- if transport == "stdio" :
266- # connection_dict is a StdioConnection (with "transport" popped)
267- await self .connect_to_server_via_stdio (server_name , ** connection_dict ) # type: ignore
268- elif transport == "sse" :
269- # connection_dict is a SSEConnection (with "transport" popped)
270- await self .connect_to_server_via_sse (server_name , ** connection_dict ) # type: ignore
271- else :
272- raise ValueError (
273- f"Unsupported transport: { transport } . Must be 'stdio' or 'sse'"
274- )
283+ await self .connect_to_server (server_name , ** connection )
284+
275285 return self
276286 except Exception :
277287 await self .exit_stack .aclose ()
0 commit comments