Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mcp/client/streamable_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

async def main():
try:
# To access RAGFlow server in `host` mode, you need to attach `api_key` for each request to indicate identification.
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"api_key": "ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
# Or follow the requirements of OAuth 2.1 Section 5 with Authorization header
# async with streamablehttp_client("http://localhost:9382/mcp/", headers={"Authorization": "Bearer ragflow-fixS-TicrohljzFkeLLWIaVhW7XlXPXIUW5solFor6o"}) as (read_stream, write_stream, _):
async with streamablehttp_client("http://localhost:9382/mcp/") as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
Expand Down
169 changes: 113 additions & 56 deletions mcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any

import click
import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from strenum import StrEnum

import mcp.types as types
from mcp.server.lowlevel import Server


class LaunchMode(StrEnum):
SELF_HOST = "self-host"
Expand Down Expand Up @@ -68,10 +68,6 @@ def __init__(self, base_url: str, version="v1"):
self.api_url = f"{self.base_url}/api/{self.version}"
self._async_client = None

def bind_api_key(self, api_key: str):
self.api_key = api_key
self.authorization_header = {"Authorization": f"Bearer {self.api_key}"}

async def _get_client(self):
if self._async_client is None:
self._async_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
Expand All @@ -82,16 +78,18 @@ async def close(self):
await self._async_client.aclose()
self._async_client = None

async def _post(self, path, json=None, stream=False, files=None):
if not self.api_key:
async def _post(self, path, json=None, stream=False, files=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.post(url=self.api_url + path, json=json, headers=self.authorization_header)
res = await client.post(url=self.api_url + path, json=json, headers={"Authorization": f"Bearer {api_key}"})
return res

async def _get(self, path, params=None):
async def _get(self, path, params=None, api_key: str = ""):
if not api_key:
return None
client = await self._get_client()
res = await client.get(url=self.api_url + path, params=params, headers=self.authorization_header)
res = await client.get(url=self.api_url + path, params=params, headers={"Authorization": f"Bearer {api_key}"})
return res

def _is_cache_valid(self, ts):
Expand Down Expand Up @@ -129,8 +127,18 @@ def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list)
self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp())
self._document_metadata_cache.move_to_end(dataset_id)

async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
async def list_datasets(
self,
*,
api_key: str,
page: int = 1,
page_size: int = 1000,
orderby: str = "create_time",
desc: bool = True,
id: str | None = None,
name: str | None = None,
):
res = await self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

Expand All @@ -145,6 +153,8 @@ async def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str

async def retrieval(
self,
*,
api_key: str,
dataset_ids,
document_ids=None,
question="",
Expand All @@ -162,7 +172,7 @@ async def retrieval(

# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await self.list_datasets()
dataset_list_str = await self.list_datasets(api_key=api_key)
dataset_ids = []

# Parse the dataset list to extract IDs
Expand All @@ -189,7 +199,7 @@ async def retrieval(
"document_ids": document_ids,
}
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = await self._post("/retrieval", json=data_json)
res = await self._post("/retrieval", json=data_json, api_key=api_key)
if not res or res.status_code != 200:
raise Exception([types.TextContent(type="text", text="Cannot process this operation.")])

Expand All @@ -199,7 +209,7 @@ async def retrieval(
chunks = []

# Cache document metadata and dataset information
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, force_refresh=force_refresh)
document_cache, dataset_cache = await self._get_document_metadata_cache(dataset_ids, api_key=api_key, force_refresh=force_refresh)

# Process chunks with enhanced field mapping including per-chunk metadata
for chunk_data in data.get("chunks", []):
Expand Down Expand Up @@ -228,7 +238,7 @@ async def retrieval(

raise Exception([types.TextContent(type="text", text=res.get("message"))])

async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
async def _get_document_metadata_cache(self, dataset_ids, *, api_key: str, force_refresh=False):
"""Cache document metadata for all documents in the specified datasets"""
document_cache = {}
dataset_cache = {}
Expand All @@ -238,7 +248,7 @@ async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
dataset_meta = None if force_refresh else self._get_cached_dataset_metadata(dataset_id)
if not dataset_meta:
# First get dataset info for name
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1})
dataset_res = await self._get("/datasets", {"id": dataset_id, "page_size": 1}, api_key=api_key)
if dataset_res and dataset_res.status_code == 200:
dataset_data = dataset_res.json()
if dataset_data.get("code") == 0 and dataset_data.get("data"):
Expand All @@ -255,7 +265,9 @@ async def _get_document_metadata_cache(self, dataset_ids, force_refresh=False):
doc_id_meta_list = []
docs = {}
while page:
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}")
docs_res = await self._get(f"/datasets/{dataset_id}/documents?page={page}", api_key=api_key)
if not docs_res:
break
docs_data = docs_res.json()
if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"):
for doc in docs_data["data"]["docs"]:
Expand Down Expand Up @@ -335,9 +347,59 @@ async def sse_lifespan(server: Server) -> AsyncIterator[dict]:


app = Server("ragflow-mcp-server", lifespan=sse_lifespan)
AUTH_TOKEN_STATE_KEY = "ragflow_auth_token"


def _to_text(value: Any) -> str:
if isinstance(value, bytes):
return value.decode(errors="ignore")
return str(value)


def _extract_token_from_headers(headers: Any) -> str | None:
    """Pull a RAGFlow auth token out of a header mapping.

    An ``Authorization: Bearer <token>`` header (looked up under both str
    and bytes key spellings) takes precedence; otherwise the assorted
    ``api_key`` / ``x-api-key`` header variants are tried. Returns ``None``
    when no usable token is found or *headers* is not a mapping.
    """
    if not headers or not hasattr(headers, "get"):
        return None

    def stripped_values(names):
        # Yield each present, non-empty header value as stripped text,
        # preserving the lookup order of *names*.
        for name in names:
            raw = headers.get(name)
            if raw:
                yield _to_text(raw).strip()

    # Bearer token in an Authorization header wins.
    for text in stripped_values(("authorization", "Authorization", b"authorization", b"Authorization")):
        if text.lower().startswith("bearer "):
            bearer = text[len("bearer "):].strip()
            if bearer:
                return bearer

    # Fall back to api-key style headers.
    for text in stripped_values(("api_key", "x-api-key", "Api-Key", "X-API-Key", b"api_key", b"x-api-key", b"Api-Key", b"X-API-Key")):
        if text:
            return text

    return None


def _extract_token_from_request(request: Any) -> str | None:
    """Resolve the RAGFlow auth token for a Starlette request.

    Checks ``request.state`` first (where the auth middleware may have
    stashed the token under ``AUTH_TOKEN_STATE_KEY``), then falls back to
    parsing the request headers. A token recovered from the headers is
    memoized back onto ``request.state`` so repeated lookups skip the
    header parsing.

    Returns ``None`` when *request* is ``None`` or no token is present.
    """
    if request is None:
        return None

    state = getattr(request, "state", None)
    if state is not None:
        token = getattr(state, AUTH_TOKEN_STATE_KEY, None)
        if token:
            return token

    token = _extract_token_from_headers(getattr(request, "headers", None))
    if token and state is not None:
        # Cache for subsequent extractions on the same request.
        setattr(state, AUTH_TOKEN_STATE_KEY, token)

    return token


def with_api_key(required: bool = True):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
Expand All @@ -347,26 +409,14 @@ async def wrapper(*args, **kwargs):
raise ValueError("Get RAGFlow Context failed")

connector = ragflow_ctx.conn
api_key = HOST_API_KEY

if MODE == LaunchMode.HOST:
headers = ctx.session._init_options.capabilities.experimental.get("headers", {})
token = None

# lower case here, because of Starlette conversion
auth = headers.get("authorization", "")
if auth.startswith("Bearer "):
token = auth.removeprefix("Bearer ").strip()
elif "api_key" in headers:
token = headers["api_key"]

if required and not token:
api_key = _extract_token_from_request(getattr(ctx, "request", None)) or ""
if required and not api_key:
raise ValueError("RAGFlow API key or Bearer token is required.")

connector.bind_api_key(token)
else:
connector.bind_api_key(HOST_API_KEY)

return await func(*args, connector=connector, **kwargs)
return await func(*args, connector=connector, api_key=api_key, **kwargs)

return wrapper

Expand All @@ -375,8 +425,8 @@ async def wrapper(*args, **kwargs):

@app.list_tools()
@with_api_key(required=True)
async def list_tools(*, connector) -> list[types.Tool]:
dataset_description = await connector.list_datasets()
async def list_tools(*, connector: RAGFlowConnector, api_key: str) -> list[types.Tool]:
dataset_description = await connector.list_datasets(api_key=api_key)

return [
types.Tool(
Expand Down Expand Up @@ -446,7 +496,13 @@ async def list_tools(*, connector) -> list[types.Tool]:

@app.call_tool()
@with_api_key(required=True)
async def call_tool(name: str, arguments: dict, *, connector) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
async def call_tool(
name: str,
arguments: dict,
*,
connector: RAGFlowConnector,
api_key: str,
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
if name == "ragflow_retrieval":
document_ids = arguments.get("document_ids", [])
dataset_ids = arguments.get("dataset_ids", [])
Expand All @@ -462,7 +518,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text

# If no dataset_ids provided or empty list, get all available dataset IDs
if not dataset_ids:
dataset_list_str = await connector.list_datasets()
dataset_list_str = await connector.list_datasets(api_key=api_key)
dataset_ids = []

# Parse the dataset list to extract IDs
Expand All @@ -477,6 +533,7 @@ async def call_tool(name: str, arguments: dict, *, connector) -> list[types.Text
continue

return await connector.retrieval(
api_key=api_key,
dataset_ids=dataset_ids,
document_ids=document_ids,
question=question,
Expand Down Expand Up @@ -510,17 +567,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
path = scope["path"]
if path.startswith("/messages/") or path.startswith("/sse") or path.startswith("/mcp"):
headers = dict(scope["headers"])
token = None
auth_header = headers.get(b"authorization")
if auth_header and auth_header.startswith(b"Bearer "):
token = auth_header.removeprefix(b"Bearer ").strip()
elif b"api_key" in headers:
token = headers[b"api_key"]
token = _extract_token_from_headers(headers)

if not token:
response = JSONResponse({"error": "Missing or invalid authorization header"}, status_code=401)
await response(scope, receive, send)
return
scope.setdefault("state", {})[AUTH_TOKEN_STATE_KEY] = token

await self.app(scope, receive, send)

Expand All @@ -547,9 +600,8 @@ async def handle_sse(request):
# Add streamable HTTP route if enabled
streamablehttp_lifespan = None
if TRANSPORT_STREAMABLE_HTTP_ENABLED:
from starlette.types import Receive, Scope, Send

from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.types import Receive, Scope, Send

session_manager = StreamableHTTPSessionManager(
app=app,
Expand All @@ -558,8 +610,11 @@ async def handle_sse(request):
stateless=True,
)

async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)
class StreamableHTTPEntry:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await session_manager.handle_request(scope, receive, send)

streamable_http_entry = StreamableHTTPEntry()

@asynccontextmanager
async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
Expand All @@ -570,7 +625,12 @@ async def streamablehttp_lifespan(app: Starlette) -> AsyncIterator[None]:
finally:
logging.info("StreamableHTTP application shutting down...")

routes.append(Mount("/mcp", app=handle_streamable_http))
routes.extend(
[
Route("/mcp", endpoint=streamable_http_entry, methods=["GET", "POST", "DELETE"]),
Mount("/mcp", app=streamable_http_entry),
]
)

return Starlette(
debug=True,
Expand Down Expand Up @@ -631,9 +691,6 @@ def parse_bool_flag(key: str, default: bool) -> bool:
if MODE == LaunchMode.SELF_HOST and not HOST_API_KEY:
raise click.UsageError("--api-key is required when --mode is 'self-host'")

if TRANSPORT_STREAMABLE_HTTP_ENABLED and MODE == LaunchMode.HOST:
raise click.UsageError("The --host mode is not supported with streamable-http transport yet.")

if not TRANSPORT_STREAMABLE_HTTP_ENABLED and JSON_RESPONSE:
JSON_RESPONSE = False

Expand Down Expand Up @@ -690,7 +747,7 @@ def parse_bool_flag(key: str, default: bool) -> bool:
--base-url=http://127.0.0.1:9380 \
--mode=self-host --api-key=ragflow-xxxxx

2. Host mode (multi-tenant, self-host only, clients must provide Authorization headers):
2. Host mode (multi-tenant, clients must provide Authorization headers):
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 \
--base-url=http://127.0.0.1:9380 \
--mode=host
Expand Down