Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def _do_run(
# Try to parse the action input to dict
if action_input and isinstance(action_input, str):
tool_args = parse_or_raise_error(action_input)
elif isinstance(action_input, dict):
elif isinstance(action_input, dict) or isinstance(action_input, list):
tool_args = action_input
action_input_str = json.dumps(action_input, ensure_ascii=False)
except json.JSONDecodeError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ async def run_tool(
if parsed_args and isinstance(parsed_args, tuple):
args = parsed_args[1]

if args is not None and isinstance(args, list) and len(args) == 0:
# Input args is empty list, just use default args
args = {}

try:
tool_result = await tool_pack.async_execute(resource_name=name, **args)
status = Status.COMPLETE.value
Expand Down
60 changes: 57 additions & 3 deletions packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import logging
import os
import ssl
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast

from mcp import ClientSession
from mcp.client.sse import sse_client

from dbgpt.util.json_utils import parse_or_raise_error

from ...util.mcp_utils import sse_client
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
from ..pack import Resource, ResourcePack
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
Expand Down Expand Up @@ -66,6 +67,8 @@ def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]:
# The position arguments is empty
args = ()
kwargs = parse_or_raise_error(input_str)
if kwargs is not None and isinstance(kwargs, list) and len(kwargs) == 0:
kwargs = {}
return args, kwargs


Expand Down Expand Up @@ -303,13 +306,47 @@ class MCPToolPack(ToolPack):
}
}
)

If you want to set the ssl verify, you can use the ssl_verify parameter:
.. code-block:: python

# Default ssl_verify is True
tools = MCPToolPack(
"https://your_ssl_domain/sse",
)

# Set the default ssl_verify to False to disable ssl verify
tools2 = MCPToolPack(
"https://your_ssl_domain/sse", default_ssl_verify=False
)

# With Custom CA file
tools3 = MCPToolPack(
"https://your_ssl_domain/sse", default_ssl_cafile="/path/to/your/ca.crt"
)

# Set the ssl_verify for each server
import ssl

tools4 = MCPToolPack(
"https://your_ssl_domain/sse",
ssl_verify={
"https://your_ssl_domain/sse": ssl.create_default_context(
cafile="/path/to/your/ca.crt"
),
},
)

"""

def __init__(
self,
mcp_servers: Union[str, List[str]],
headers: Optional[Dict[str, Dict[str, Any]]] = None,
default_headers: Optional[Dict[str, Any]] = None,
ssl_verify: Optional[Dict[str, Union[ssl.SSLContext, str, bool]]] = None,
default_ssl_verify: Union[ssl.SSLContext, str, bool] = True,
default_ssl_cafile: Optional[str] = None,
**kwargs,
):
"""Create an Auto-GPT plugin tool pack."""
Expand All @@ -320,6 +357,12 @@ def __init__(
self._default_headers = default_headers or {}
self._headers_map = headers or {}
self.server_headers_map = {}
if default_ssl_cafile and not ssl_verify and default_ssl_verify:
default_ssl_verify = ssl.create_default_context(cafile=default_ssl_cafile)

self._default_ssl_verify = default_ssl_verify
self._ssl_verify_map = ssl_verify or {}
self.server_ssl_verify_map = {}

def switch_mcp_input_schema(self, input_schema: dict):
args = {}
Expand Down Expand Up @@ -362,8 +405,14 @@ async def preload_resource(self):
for server in server_list:
server_headers = self._headers_map.get(server, self._default_headers)
self.server_headers_map[server] = server_headers
server_ssl_verify = self._ssl_verify_map.get(
server, self._default_ssl_verify
)
self.server_ssl_verify_map[server] = server_ssl_verify

async with sse_client(url=server, headers=server_headers) as (read, write):
async with sse_client(
url=server, headers=server_headers, verify=server_ssl_verify
) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
Expand All @@ -378,8 +427,13 @@ async def call_mcp_tool(
):
try:
headers_to_use = self.server_headers_map.get(server, {})
ssl_verify_to_use = self.server_ssl_verify_map.get(
server, True
)
async with sse_client(
url=server, headers=headers_to_use
url=server,
headers=headers_to_use,
verify=ssl_verify_to_use,
) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
Expand Down
147 changes: 147 additions & 0 deletions packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import ssl
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse

import anyio
import httpx
import mcp.types as types
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse

logger = logging.getLogger(__name__)


def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)


@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
verify: ssl.SSLContext | str | bool = True,
):
"""
Client transport for SSE.

`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers, verify=verify) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")

async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)

url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)

task_status.started(endpoint_url)

case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()

async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await write_stream.aclose()

endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)

try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
6 changes: 4 additions & 2 deletions packages/dbgpt-core/src/dbgpt/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dbgpt.util.i18n_utils import _

if TYPE_CHECKING:
from httpx._types import ProxiesTypes
from httpx._types import ProxiesTypes, ProxyTypes
from openai import AsyncAzureOpenAI, AsyncOpenAI

ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
Expand Down Expand Up @@ -139,6 +139,7 @@ def __init__(
api_version: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
proxy: Optional["ProxyTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "gpt-4o-mini",
context_length: Optional[int] = 8192,
Expand All @@ -160,6 +161,7 @@ def __init__(
api_key=self._resolve_env_vars(api_key),
api_version=self._resolve_env_vars(api_version),
proxies=proxies,
proxy=proxy,
full_url=kwargs.get("full_url"),
)

Expand Down Expand Up @@ -203,7 +205,7 @@ def new_client(
api_type=model_params.api_type,
api_version=model_params.api_version,
model=model_params.real_provider_model_name,
proxies=model_params.http_proxy,
proxy=model_params.http_proxy,
model_alias=model_params.real_provider_model_name,
context_length=max(model_params.context_length or 8192, 8192),
# full_url=model_params.proxy_server_url,
Expand Down
53 changes: 48 additions & 5 deletions packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def from_dict(
class MCPSSEToolPack(MCPToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
"""Initialize the MCPSSEToolPack with the given MCP servers."""
import ssl

headers = {}
# token is not supported in sse mode
servers = (
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
)
if "token" in kwargs and kwargs["token"]:
# token is not supported in sse mode
servers = (
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
)
tokens = (
kwargs["token"].split(";")
if isinstance(kwargs["token"], str)
Expand All @@ -69,7 +71,33 @@ def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
for server in servers:
headers[server] = {"Authorization": f"Bearer {token}"}
kwargs.pop("token")
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
ssl_verify = True
ssl_verify_map = {}
if "no_ssl_verify" in kwargs:
if kwargs["no_ssl_verify"] is True:
ssl_verify = False
kwargs.pop("no_ssl_verify")
if ssl_verify is True and "ssl_ca_cert" in kwargs:
ssl_ca_certs = (
kwargs["ssl_ca_cert"].split(";")
if isinstance(kwargs["ssl_ca_cert"], str)
else kwargs["ssl_ca_cert"]
)
if len(servers) == len(ssl_ca_certs):
for i, ssl_ca_cert in enumerate(ssl_ca_certs):
ssl_verify_map[servers[i]] = ssl.create_default_context(
cafile=ssl_ca_cert
)
else:
ssl_ca_cert = ssl_ca_certs[0]
for server in servers:
ssl_verify_map[server] = ssl.create_default_context(
cafile=ssl_ca_cert
)
verify = ssl_verify_map if ssl_verify_map else ssl_verify
super().__init__(
mcp_servers=mcp_servers, headers=headers, ssl_verify=verify, **kwargs
)

@classmethod
def type_alias(cls) -> str:
Expand Down Expand Up @@ -97,5 +125,20 @@ class _DynMCPSSEPackResourceParameters(MCPPackResourceParameters):
"tags": "privacy",
},
)
no_ssl_verify: bool = dataclasses.field(
default=False,
metadata={
"help": _(
"Disable SSL verification. "
"This is not recommended for production use."
),
},
)
ssl_ca_cert: Optional[str] = dataclasses.field(
default=None,
metadata={
"help": _("Path to the CA certificate file. split by ';' "),
},
)

return _DynMCPSSEPackResourceParameters