Skip to content

Commit b769816

Browse files
vbardassmails
authored andcommitted
allow passing session kwargs (langchain-ai#66)
1 parent 3cf8fe7 commit b769816

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

langchain_mcp_adapters/client.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from contextlib import AsyncExitStack
3+
from pathlib import Path
34
from types import TracebackType
45
from typing import Any, Literal, Optional, TypedDict, cast
56

@@ -12,8 +13,10 @@
1213
from langchain_mcp_adapters.prompts import load_mcp_prompt
1314
from langchain_mcp_adapters.tools import load_mcp_tools
1415

16+
EncodingErrorHandler = Literal["strict", "ignore", "replace"]
17+
1518
DEFAULT_ENCODING = "utf-8"
16-
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
19+
DEFAULT_ENCODING_ERROR_HANDLER: EncodingErrorHandler = "strict"
1720

1821
DEFAULT_HTTP_TIMEOUT = 5
1922
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
@@ -31,17 +34,23 @@ class StdioConnection(TypedDict):
3134
env: dict[str, str] | None
3235
"""The environment to use when spawning the process."""
3336

37+
cwd: str | Path | None
38+
"""The working directory to use when spawning the process."""
39+
3440
encoding: str
3541
"""The text encoding used when sending/receiving messages to the server."""
3642

37-
encoding_error_handler: Literal["strict", "ignore", "replace"]
43+
encoding_error_handler: EncodingErrorHandler
3844
"""
3945
The text encoding error handler.
4046
4147
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
4248
explanations of possible values
4349
"""
4450

51+
session_kwargs: dict[str, Any] | None
52+
"""Additional keyword arguments to pass to the ClientSession"""
53+
4554

4655
class SSEConnection(TypedDict):
4756
transport: Literal["sse"]
@@ -58,6 +67,9 @@ class SSEConnection(TypedDict):
5867
sse_read_timeout: float
5968
"""SSE read timeout"""
6069

70+
session_kwargs: dict[str, Any] | None
71+
"""Additional keyword arguments to pass to the ClientSession"""
72+
6173

6274
class MultiServerMCPClient:
6375
"""Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""
@@ -146,6 +158,7 @@ async def connect_to_server(
146158
headers=kwargs.get("headers"),
147159
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
148160
sse_read_timeout=kwargs.get("sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT),
161+
session_kwargs=kwargs.get("session_kwargs"),
149162
)
150163
elif transport == "stdio":
151164
if "command" not in kwargs:
@@ -161,6 +174,7 @@ async def connect_to_server(
161174
encoding_error_handler=kwargs.get(
162175
"encoding_error_handler", DEFAULT_ENCODING_ERROR_HANDLER
163176
),
177+
session_kwargs=kwargs.get("session_kwargs"),
164178
)
165179
else:
166180
raise ValueError(f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'")
@@ -176,6 +190,7 @@ async def connect_to_server_via_stdio(
176190
encoding_error_handler: Literal[
177191
"strict", "ignore", "replace"
178192
] = DEFAULT_ENCODING_ERROR_HANDLER,
193+
session_kwargs: dict[str, Any] | None = None,
179194
) -> None:
180195
"""Connect to a specific MCP server using stdio
181196
@@ -186,6 +201,7 @@ async def connect_to_server_via_stdio(
186201
env: Environment variables for the command
187202
encoding: Character encoding
188203
encoding_error_handler: How to handle encoding errors
204+
session_kwargs: Additional keyword arguments to pass to the ClientSession
189205
"""
190206
# NOTE: execution commands (e.g., `uvx` / `npx`) require PATH envvar to be set.
191207
# To address this, we automatically inject existing PATH envvar into the `env` value,
@@ -205,9 +221,10 @@ async def connect_to_server_via_stdio(
205221
# Create and store the connection
206222
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
207223
read, write = stdio_transport
224+
session_kwargs = session_kwargs or {}
208225
session = cast(
209226
ClientSession,
210-
await self.exit_stack.enter_async_context(ClientSession(read, write)),
227+
await self.exit_stack.enter_async_context(ClientSession(read, write, **session_kwargs)),
211228
)
212229

213230
await self._initialize_session_and_load_tools(server_name, session)
@@ -220,6 +237,7 @@ async def connect_to_server_via_sse(
220237
headers: dict[str, Any] | None = None,
221238
timeout: float = DEFAULT_HTTP_TIMEOUT,
222239
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
240+
session_kwargs: dict[str, Any] | None = None,
223241
) -> None:
224242
"""Connect to a specific MCP server using SSE
225243
@@ -229,15 +247,17 @@ async def connect_to_server_via_sse(
229247
headers: HTTP headers to send to the SSE endpoint
230248
timeout: HTTP timeout
231249
sse_read_timeout: SSE read timeout
250+
session_kwargs: Additional keyword arguments to pass to the ClientSession
232251
"""
233252
# Create and store the connection
234253
sse_transport = await self.exit_stack.enter_async_context(
235254
sse_client(url, headers, timeout, sse_read_timeout)
236255
)
237256
read, write = sse_transport
257+
session_kwargs = session_kwargs or {}
238258
session = cast(
239259
ClientSession,
240-
await self.exit_stack.enter_async_context(ClientSession(read, write)),
260+
await self.exit_stack.enter_async_context(ClientSession(read, write, **session_kwargs)),
241261
)
242262

243263
await self._initialize_session_and_load_tools(server_name, session)
@@ -260,18 +280,8 @@ async def __aenter__(self) -> "MultiServerMCPClient":
260280
try:
261281
connections = self.connections or {}
262282
for server_name, connection in connections.items():
263-
connection_dict = connection.copy()
264-
transport = connection_dict.pop("transport")
265-
if transport == "stdio":
266-
# connection_dict is a StdioConnection (with "transport" popped)
267-
await self.connect_to_server_via_stdio(server_name, **connection_dict) # type: ignore
268-
elif transport == "sse":
269-
# connection_dict is a SSEConnection (with "transport" popped)
270-
await self.connect_to_server_via_sse(server_name, **connection_dict) # type: ignore
271-
else:
272-
raise ValueError(
273-
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
274-
)
283+
await self.connect_to_server(server_name, **connection)
284+
275285
return self
276286
except Exception:
277287
await self.exit_stack.aclose()

0 commit comments

Comments
 (0)