|
14 | 14 | DEFAULT_ENCODING = "utf-8" |
15 | 15 | DEFAULT_ENCODING_ERROR_HANDLER = "strict" |
16 | 16 |
|
| 17 | +DEFAULT_HTTP_TIMEOUT = 5 |
| 18 | +DEFAULT_SSE_READ_TIMEOUT = 60 * 5 |
| 19 | + |
17 | 20 |
|
18 | 21 | class StdioConnection(TypedDict): |
19 | 22 | transport: Literal["stdio"] |
@@ -45,6 +48,15 @@ class SSEConnection(TypedDict): |
45 | 48 | url: str |
46 | 49 | """The URL of the SSE endpoint to connect to.""" |
47 | 50 |
|
| 51 | + headers: dict[str, Any] | None = None |
| 52 | + """HTTP headers to send to the SSE endpoint""" |
| 53 | + |
| 54 | + timeout: float |
| 55 | + """HTTP timeout""" |
| 56 | + |
| 57 | + sse_read_timeout: float |
| 58 | + """SSE read timeout""" |
| 59 | + |
48 | 60 |
|
49 | 61 | class MultiServerMCPClient: |
50 | 62 | """Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them.""" |
@@ -125,7 +137,13 @@ async def connect_to_server( |
125 | 137 | if transport == "sse": |
126 | 138 | if "url" not in kwargs: |
127 | 139 | raise ValueError("'url' parameter is required for SSE connection") |
128 | | - await self.connect_to_server_via_sse(server_name, url=kwargs["url"]) |
| 140 | + await self.connect_to_server_via_sse( |
| 141 | + server_name, |
| 142 | + url=kwargs["url"], |
| 143 | + headers=kwargs.get("headers"), |
| 144 | + timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT), |
| 145 | + sse_read_timeout=kwargs.get("sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT), |
| 146 | + ) |
129 | 147 | elif transport == "stdio": |
130 | 148 | if "command" not in kwargs: |
131 | 149 | raise ValueError("'command' parameter is required for stdio connection") |
@@ -189,15 +207,23 @@ async def connect_to_server_via_sse( |
189 | 207 | server_name: str, |
190 | 208 | *, |
191 | 209 | url: str, |
| 210 | + headers: dict[str, Any] | None = None, |
| 211 | + timeout: float = DEFAULT_HTTP_TIMEOUT, |
| 212 | + sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT, |
192 | 213 | ) -> None: |
193 | 214 | """Connect to a specific MCP server using SSE |
194 | 215 |
|
195 | 216 | Args: |
196 | 217 | server_name: Name to identify this server connection |
197 | 218 | url: URL of the SSE server |
| 219 | + headers: HTTP headers to send to the SSE endpoint |
| 220 | + timeout: HTTP timeout |
| 221 | + sse_read_timeout: SSE read timeout |
198 | 222 | """ |
199 | 223 | # Create and store the connection |
200 | | - sse_transport = await self.exit_stack.enter_async_context(sse_client(url)) |
| 224 | + sse_transport = await self.exit_stack.enter_async_context( |
| 225 | + sse_client(url, headers, timeout, sse_read_timeout) |
| 226 | + ) |
201 | 227 | read, write = sse_transport |
202 | 228 | session = cast( |
203 | 229 | ClientSession, |
|
0 commit comments