diff --git a/Cargo.lock b/Cargo.lock index a07ca9a8e9..bfb2c9069e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -304,6 +304,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-server" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab" +dependencies = [ + "bytes", + "fs-err", + "http", + "http-body", + "hyper", + "hyper-util", + "tokio", + "tower-service", +] + [[package]] name = "axum_static" version = "1.7.1" @@ -1379,6 +1395,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fd-lock" version = "4.0.4" @@ -1447,6 +1469,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -1454,7 +1485,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -1468,6 +1499,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] 
name = "foreign-types-shared" version = "0.3.1" @@ -1483,6 +1520,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683" +dependencies = [ + "autocfg", + "tokio", +] + [[package]] name = "futf" version = "0.1.5" @@ -2020,6 +2067,22 @@ dependencies = [ "webpki-roots 1.0.0", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.14" @@ -2613,7 +2676,7 @@ dependencies = [ "bitflags 2.9.1", "block", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "log", "objc", "paste", @@ -2906,6 +2969,7 @@ name = "mistralrs-server" version = "0.6.0" dependencies = [ "anyhow", + "async-trait", "axum 0.8.4", "clap", "ctrlc", @@ -2916,6 +2980,7 @@ dependencies = [ "mistralrs-server-core", "once_cell", "regex", + "rust-mcp-sdk", "rustyline", "serde", "serde_json", @@ -3024,6 +3089,23 @@ dependencies = [ "version_check", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -3261,6 +3343,50 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +[[package]] +name = "openssl" +version = "0.10.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -3400,6 +3526,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "png" version = "0.17.16" @@ -3911,11 +4043,13 @@ dependencies = [ "http-body-util", "hyper", "hyper-rustls", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -3927,6 +4061,7 @@ dependencies = [ 
"serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls", "tokio-util", "tower", @@ -4000,6 +4135,71 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust-mcp-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512fe25018087e27b9abbe1806fc83e4741c7b8d0dc5f7fa8c4f0f3b389ef658" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn 2.0.101", +] + +[[package]] +name = "rust-mcp-schema" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c9966340f5104a8d22b6c2db8901f8626a0f737820a385db3ffbb29b1f6ae0f" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "rust-mcp-sdk" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edbc126903df8ba1cb99a56b18b95875b34a24f81d05575ebb5eddc4c3293322" +dependencies = [ + "async-trait", + "axum 0.8.4", + "axum-server", + "futures", + "hyper", + "rust-mcp-macros", + "rust-mcp-schema", + "rust-mcp-transport", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-stream", + "tracing", + "uuid 1.17.0", +] + +[[package]] +name = "rust-mcp-transport" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1195b57bfe37111460b44e1624a6552c7378749195709d104d28fc5b1ef59f9c" +dependencies = [ + "async-trait", + "bytes", + "futures", + "reqwest", + "rust-mcp-schema", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-stream", + "tracing", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -4147,6 +4347,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + 
"windows-sys 0.59.0", +] + [[package]] name = "schemars" version = "0.8.22" @@ -4192,6 +4401,29 @@ dependencies = [ "tendril", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.9.1", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "selectors" version = "0.26.0" @@ -4826,6 +5058,19 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" +[[package]] +name = "tempfile" +version = "3.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "tendril" version = "0.4.3" @@ -5002,6 +5247,16 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rayon" version = "2.1.0" @@ -5022,6 +5277,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.26.2" @@ -5513,6 +5779,12 @@ dependencies = [ "uuid 0.8.2", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/Cargo.toml b/Cargo.toml index d74bc37edf..c25ccc0f74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -160,6 +160,7 @@ mime_guess = "2.0.5" include_dir = "0.7.4" http = "1.3.1" hyper = "1.6.0" +rust-mcp-sdk = { version = "0.4.2", default-features = false, features = ["server", "hyper-server", "2025_03_26"] } bindgen_cuda = { git = "https://github.com/guoqingbao/bindgen_cuda.git", version = "0.1.6" } rubato = "0.16.2" rustfft = "6.3.0" diff --git a/README.md b/README.md index 2c0c5c0fb1..6560dcd829 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis - [Rust API](https://ericlbuehler.github.io/mistral.rs/mistralrs/) & [Python API](mistralrs-pyo3/API.md) - [Automatic device mapping](docs/DEVICE_MAPPING.md) (multi-GPU, CPU) - [Chat templates](docs/CHAT_TOK.md) & tokenizer auto-detection + - [MCP protocol](docs/MCP.md) for structured, realtime tool calls 2. 
**Performance** - CPU acceleration (MKL, AVX, [NEON](docs/DEVICE_MAPPING.md#arm-neon), [Accelerate](docs/DEVICE_MAPPING.md#apple-accelerate)) @@ -184,6 +185,16 @@ OpenAI API compatible API server - [Example](examples/server/chat.py) - [Use or extend the server in other axum projects](https://ericlbuehler.github.io/mistral.rs/mistralrs_server_core/) +### MCP Protocol + +Serve the same models over the open [MCP](docs/MCP.md) (Model Context Protocol) in parallel to the HTTP API: + +```bash +./mistralrs-server --mcp-port 4321 plain -m Qwen/Qwen3-4B +``` + +See the [docs](docs/MCP.md) for feature flags, examples and limitations. + ### Llama Index integration diff --git a/docs/HTTP.md b/docs/HTTP.md index 628d2968aa..56898d6473 100644 --- a/docs/HTTP.md +++ b/docs/HTTP.md @@ -4,6 +4,9 @@ Mistral.rs provides a lightweight OpenAI API compatible HTTP server based on [ax The API consists of the following endpoints. They can be viewed in your browser interactively by going to `http://localhost:/docs`. +> ℹ️ Besides the HTTP endpoints described below `mistralrs-server` can also expose the same functionality via the **MCP protocol**. +> Enable it with `--mcp-port ` and see [MCP.md](MCP.md) for details. + ## Additional object keys To support additional features, we have extended the completion and chat completion request objects. Both have the same keys added: diff --git a/docs/MCP.md b/docs/MCP.md new file mode 100644 index 0000000000..a471d49df1 --- /dev/null +++ b/docs/MCP.md @@ -0,0 +1,202 @@ +# MCP protocol support + +`mistralrs-server` can serve **MCP (Model Context Protocol)** traffic next to the regular OpenAI-compatible HTTP interface! + +MCP is an open, tool-based protocol that lets clients interact with models through structured *tool calls* instead of free-form HTTP routes. + +Under the hood the server uses [`rust-mcp-sdk`](https://crates.io/crates/rust-mcp-sdk) and exposes tools based on the supported modalities of the loaded model. 
+ +Exposed tools: + +| Tool | Minimum `input` -> `output` modalities | Description | +| -- | -- | -- | +| `chat` | `Text` -> `Text` | Wraps the OpenAI `/v1/chat/completions` endpoint. | + + +--- + +## ToC +- [MCP protocol support](#mcp-protocol-support) + - [ToC](#toc) + - [Running](#running) + - [Check if it's working](#check-if-its-working) + - [Example clients](#example-clients) + - [Rust](#rust) + - [Python](#python) + - [HTTP](#http) + - [Limitations](#limitations) + +--- + +## Running + +Start the normal HTTP server (`--port`, OpenAI compatible HTTP API) and add the `--mcp-port` flag to spin up an MCP server (Streamable HTTP) on a separate port: + +```bash +./target/release/mistralrs-server \ + --port 1234 \ + --mcp-port 4321 \ + plain -m mistralai/Mistral-7B-Instruct-v0.3 +``` + +## Check if it's working + +Run this `curl` command to check the available tools: + +``` +curl -X POST http://localhost:4321/mcp \ +-H "Content-Type: application/json" \ +-d '{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} +}' +``` + +## Example clients + +### Rust + +```rust +use anyhow::Result; +use rust_mcp_sdk::{ + mcp_client::client_runtime, + schema::{ + CallToolRequestParams, ClientCapabilities, CreateMessageRequest, + Implementation, InitializeRequestParams, Message, LATEST_PROTOCOL_VERSION, + }, + ClientSseTransport, ClientSseTransportOptions, +}; + +struct Handler; +#[async_trait::async_trait] +impl rust_mcp_sdk::mcp_client::ClientHandler for Handler {} + +#[tokio::main] +async fn main() -> Result<()> { + let transport = ClientSseTransport::new( + "http://localhost:4321/mcp", + ClientSseTransportOptions::default(), + )?; + + let details = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { name: "mcp-client".into(), version: "0.1".into() }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + let client = client_runtime::create_client(details, transport, Handler); + 
client.clone().start().await?; + + let req = CreateMessageRequest { + model: "mistralai/Mistral-7B-Instruct-v0.3".into(), + messages: vec![Message::user("Explain Rust ownership.")], + ..Default::default() + }; + + let result = client + .call_tool(CallToolRequestParams::new("chat", req.into())) + .await?; + + println!("{}", result.content[0].as_text_content()?.text); + client.shut_down().await?; + Ok(()) +} +``` + +### Python + +```py +import asyncio +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +SERVER_URL = "http://localhost:4321/mcp" + +async def main() -> None: + async with streamablehttp_client(SERVER_URL) as (read, write, _): + async with ClientSession(read, write) as session: + + # --- INITIALIZE --- + init_result = await session.initialize() + print("Server info:", init_result.serverInfo) + + # --- LIST TOOLS --- + tools = await session.list_tools() + print("Available tools:", [t.name for t in tools.tools]) + + # --- CALL TOOL --- + resp = await session.call_tool( + "chat", + arguments={ + "messages": [ + {"role": "user", "content": "Hello MCP 👋"}, + {"role": "assistant", "content": "Hi there!"} + ], + "maxTokens": 50, + "temperature": 0.7, + }, + ) + # resp.content is a list[CallToolResultContentItem]; extract text parts + text = "\n".join(c.text for c in resp.content if c.type == "text") + print("Model replied:", text) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +### HTTP + +**Call a tool:** +```bash +curl -X POST http://localhost:4321/mcp \ +-H "Content-Type: application/json" \ +-d '{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "chat", + "arguments": { + "messages": [ + { "role": "system", "content": "You are a helpful assistant." }, + { "role": "user", "content": "Hello, what’s the time?" 
} + ], + "maxTokens": 50, + "temperature": 0.7 + } + } +}' +``` + +**Initialize:** +```bash +curl -X POST http://localhost:4321/mcp \ +-H "Content-Type: application/json" \ +-d '{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} +}' +``` + +**List tools:** +```bash +curl -X POST http://localhost:4321/mcp \ +-H "Content-Type: application/json" \ +-d '{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} +}' +``` + +## Limitations + +- Streaming requests are not implemented. +- No authentication layer is provided – run the MCP port behind a reverse proxy if you need auth. + +Contributions to extend MCP coverage (streaming, more tools, auth hooks) are welcome! diff --git a/docs/README.md b/docs/README.md index 51159e2ccd..ecbd30d51d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -38,6 +38,7 @@ - [Sampling](SAMPLING.md) - [TOML selector](TOML_SELECTOR.md) - [Tool calling](TOOL_CALLING.md) +- [MCP protocol](MCP.md) ## Cross-device inference - [Device mapping](DEVICE_MAPPING.md) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index c9a253b9b4..7a35228839 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -94,8 +94,8 @@ pub use pipeline::{ MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader, - TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig, - UQFF_MULTI_FILE_DELIMITER, + SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType, + VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER, }; pub use request::{ ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat, diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 9de42976f1..8272f3559a 100644 --- 
a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -77,7 +77,7 @@ pub use crate::kv_cache::{ Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType, }; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub enum SupportedModality { Text, Audio, diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml index a29b94a287..eb818e29e1 100644 --- a/mistralrs-server/Cargo.toml +++ b/mistralrs-server/Cargo.toml @@ -29,6 +29,8 @@ serde.workspace = true serde_json.workspace = true tokio.workspace = true tracing.workspace = true +rust-mcp-sdk.workspace = true +async-trait.workspace = true [features] cuda = ["mistralrs-core/cuda", "mistralrs-server-core/cuda"] @@ -43,3 +45,4 @@ accelerate = ["mistralrs-core/accelerate", "mistralrs-server-core/accelerate"] mkl = ["mistralrs-core/mkl", "mistralrs-server-core/mkl"] nccl = ["mistralrs-core/nccl", "mistralrs-server-core/nccl"] ring = ["mistralrs-core/ring", "mistralrs-server-core/ring"] +mcp-server = ["rust-mcp-sdk/server", "rust-mcp-sdk/hyper-server"] diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index b4da9b0544..bb2f3d4f27 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -1,6 +1,8 @@ use anyhow::Result; use clap::Parser; use mistralrs_core::{initialize_logging, ModelSelected, TokenSource}; +use rust_mcp_sdk::schema::LATEST_PROTOCOL_VERSION; +use tokio::join; use tracing::info; use mistralrs_server_core::{ @@ -10,6 +12,7 @@ use mistralrs_server_core::{ mod interactive_mode; use interactive_mode::interactive_mode; +mod mcp_server; #[derive(Parser)] #[command(version, about, long_about = None)] @@ -24,7 +27,7 @@ struct Args { /// Port to serve on. #[arg(short, long)] - port: Option, + port: Option, /// Log all responses and requests to this file #[clap(long, short)] @@ -134,6 +137,10 @@ struct Args { /// Enable thinking for interactive mode and models that support it. 
#[arg(long = "enable-thinking")] enable_thinking: bool, + + /// Port to serve MCP protocol on + #[arg(long)] + mcp_port: Option, } fn parse_token_source(s: &str) -> Result { @@ -185,27 +192,54 @@ async fn main() -> Result<()> { return Ok(()); } - // Needs to be after the .build call as that is where the daemon waits. - let setting_server = if !args.interactive_mode { - let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?"); - let ip = args.serve_ip.unwrap_or_else(|| "0.0.0.0".to_string()); + if !args.interactive_mode && args.port.is_none() && args.mcp_port.is_none() { + anyhow::bail!("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port` or `--mcp-port`?") + } - // Create listener early to validate address before model loading - let listener = tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?; - Some((listener, ip, port)) + let mcp_port = if let Some(port) = args.mcp_port { + let host = args + .serve_ip + .clone() + .unwrap_or_else(|| "0.0.0.0".to_string()); + info!("MCP server listening on http://{host}:{port}/mcp."); + info!("MCP protocol version is {}.", LATEST_PROTOCOL_VERSION); + let mcp_server = mcp_server::create_http_mcp_server(mistralrs.clone(), host, port); + + tokio::spawn(async move { + if let Err(e) = mcp_server.await { + eprintln!("MCP server error: {e}"); + } + }) } else { - None + tokio::spawn(async {}) }; - let app = MistralRsServerRouterBuilder::new() - .with_mistralrs(mistralrs) - .build() - .await?; + let oai_port = if let Some(port) = args.port { + let ip = args + .serve_ip + .clone() + .unwrap_or_else(|| "0.0.0.0".to_string()); - if let Some((listener, ip, port)) = setting_server { - info!("Serving on http://{ip}:{}.", port); - axum::serve(listener, app).await?; + // Create listener early to validate address before model loading + let listener = 
tokio::net::TcpListener::bind(format!("{ip}:{port}")).await?; + + let app = MistralRsServerRouterBuilder::new() + .with_mistralrs(mistralrs) + .build() + .await?; + + info!("OpenAI-compatible server listening on http://{ip}:{port}."); + + tokio::spawn(async move { + if let Err(e) = axum::serve(listener, app).await { + eprintln!("OpenAI server error: {e}"); + } + }) + } else { + tokio::spawn(async {}) }; + let (_, _) = join!(oai_port, mcp_port); + Ok(()) } diff --git a/mistralrs-server/src/mcp_server.rs b/mistralrs-server/src/mcp_server.rs new file mode 100644 index 0000000000..8c29ffd4fd --- /dev/null +++ b/mistralrs-server/src/mcp_server.rs @@ -0,0 +1,370 @@ +use async_trait::async_trait; +use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router}; +use mistralrs_core::SupportedModality; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::io; +use std::sync::Arc; +use tokio::net::TcpListener; + +use mistralrs_server_core::{ + chat_completion::{create_response_channel, parse_request}, + types::SharedMistralRsState, +}; + +// Import your existing types +use rust_mcp_sdk::schema::{ + schema_utils::CallToolError, CallToolResult, CallToolResultContentItem, Implementation, + InitializeResult, ListToolsResult, ServerCapabilities, ServerCapabilitiesTools, TextContent, + Tool, ToolInputSchema, LATEST_PROTOCOL_VERSION, +}; + +mod errors { + #![allow(dead_code)] + + /// JSON-RPC error codes based on MCPEx.Protocol.Errors + pub const PARSE_ERROR: i32 = -32700; + pub const INVALID_REQUEST: i32 = -32600; + pub const METHOD_NOT_FOUND: i32 = -32601; + pub const INVALID_PARAMS: i32 = -32602; + pub const INTERNAL_ERROR: i32 = -32603; +} + +// JSON-RPC types +#[derive(serde::Deserialize)] +struct JsonRpcRequest { + jsonrpc: String, + id: Option, + method: String, + params: Option, +} + +#[derive(serde::Serialize)] +struct JsonRpcResponse { + jsonrpc: String, + id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + 
result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(serde::Serialize)] +struct JsonRpcError { + code: i32, + message: String, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, +} + +// Keep your existing McpTool trait and ChatTool implementation +#[async_trait] +pub trait McpTool: Send + Sync { + fn name(&self) -> &str; + fn description(&self) -> Option<&str>; + fn input_schema(&self) -> &ToolInputSchema; + + fn as_tool_record(&self) -> Tool { + Tool { + name: self.name().to_string(), + description: self.description().map(|s| s.to_string()), + input_schema: self.input_schema().clone(), + annotations: None, + } + } + + async fn call( + &self, + args: serde_json::Value, + state: &SharedMistralRsState, + ) -> std::result::Result; +} + +pub struct ChatTool { + input_schema: ToolInputSchema, +} + +impl ChatTool { + pub fn new() -> Self { + let required = vec!["messages".to_string()]; + + let mut properties: HashMap> = HashMap::new(); + properties.insert( + "messages".to_string(), + json!({ + "type": "array", + "description": "Conversation messages so far", + "items": { + "type": "object", + "required": ["role", "content"], + "properties": { + "role": { "type": "string", "enum": ["user", "assistant", "system"] }, + "content": { "type": "string" } + } + } + }) + .as_object() + .unwrap() + .clone(), + ); + properties.insert( + "maxTokens".to_string(), + json!({ + "type": "integer", + "description": "Maximum number of tokens to generate" + }) + .as_object() + .unwrap() + .clone(), + ); + properties.insert( + "temperature".to_string(), + json!({ + "type": "number", + "description": "Sampling temperature between 0 and 1", + "minimum": 0.0, + "maximum": 1.0 + }) + .as_object() + .unwrap() + .clone(), + ); + properties.insert( + "systemPrompt".to_string(), + json!({ + "type": "string", + "description": "Optional system prompt to prepend to the conversation" + }) + .as_object() + .unwrap() + .clone(), + ); + + let 
input_schema = ToolInputSchema::new(required, Some(properties)); + Self { input_schema } + } +} + +#[async_trait] +impl McpTool for ChatTool { + fn name(&self) -> &str { + "chat" + } + + fn description(&self) -> Option<&str> { + Some("Send a chat completion request with messages and other hyperparameters.") + } + + fn input_schema(&self) -> &ToolInputSchema { + &self.input_schema + } + + async fn call( + &self, + args: serde_json::Value, + state: &SharedMistralRsState, + ) -> std::result::Result { + // Translate to the internal ChatCompletionRequest. + let chat_req: mistralrs_server_core::openai::ChatCompletionRequest = + serde_json::from_value(args).map_err(CallToolError::new)?; + + // Execute the request using existing helper utilities. + let (tx, mut rx) = create_response_channel(None); + let (request, _is_streaming) = parse_request(chat_req, state.clone(), tx) + .await + .map_err(|e| CallToolError::new(io::Error::other(e.to_string())))?; + + mistralrs_server_core::chat_completion::send_request(state, request) + .await + .map_err(|e| CallToolError::new(io::Error::other(e.to_string())))?; + + match rx.recv().await { + Some(mistralrs_core::Response::Done(resp)) => { + let content = resp + .choices + .iter() + .filter_map(|c| c.message.content.clone()) + .collect::>() + .join("\n"); + + Ok(CallToolResult { + content: vec![CallToolResultContentItem::TextContent(TextContent::new( + content, None, + ))], + is_error: None, + meta: None, + }) + } + Some(mistralrs_core::Response::ModelError(msg, _)) => { + Err(CallToolError::new(io::Error::other(msg))) + } + Some(_) | None => Err(CallToolError::new(io::Error::other("no response"))), + } + } +} + +const MCP_INSTRUCTIONS: &str = r#" +This server provides LLM text and multimodal model inference. 
You can use the following tools: +- `chat` for sending a chat completion request with a model message history +"#; + +// HTTP MCP Handler +pub struct HttpMcpHandler { + pub state: SharedMistralRsState, + tools: HashMap>, + server_info: InitializeResult, +} + +impl HttpMcpHandler { + pub fn new(state: SharedMistralRsState) -> Self { + let modalities = &state.config().modalities; + + let mut tools: HashMap> = HashMap::new(); + if modalities.input.contains(&SupportedModality::Text) + && modalities.output.contains(&SupportedModality::Text) + { + tools.insert("chat".to_string(), Arc::new(ChatTool::new())); + } + + let server_info = InitializeResult { + server_info: Implementation { + name: "mistralrs".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + capabilities: ServerCapabilities { + tools: Some(ServerCapabilitiesTools { list_changed: None }), + ..Default::default() + }, + meta: None, + instructions: Some(MCP_INSTRUCTIONS.to_string()), + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + }; + + Self { + state, + tools, + server_info, + } + } + + async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse { + if request.jsonrpc != "2.0" { + return JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { + code: errors::INVALID_REQUEST, + message: "Expected jsonrpc to be 2.0".to_string(), + data: None, + }), + }; + } + + match request.method.as_str() { + "initialize" => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(serde_json::to_value(&self.server_info).unwrap()), + error: None, + }, + "ping" => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(json!({})), + error: None, + }, + "tools/list" => { + let tools: Vec = self.tools.values().map(|t| t.as_tool_record()).collect(); + let result = ListToolsResult { + tools, + meta: None, + next_cursor: None, + }; + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: 
request.id, + result: Some(serde_json::to_value(result).unwrap()), + error: None, + } + } + "tools/call" => { + let params = request.params.unwrap_or(json!({})); + + // Extract tool name and arguments from params + let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let args = params.get("arguments").cloned().unwrap_or(json!({})); + + match self.tools.get(tool_name) { + Some(tool) => match tool.call(args, &self.state).await { + Ok(result) => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(serde_json::to_value(result).unwrap()), + error: None, + }, + Err(e) => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { + code: errors::INTERNAL_ERROR, + message: format!("Tool execution error: {}", e), + data: None, + }), + }, + }, + None => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { + code: errors::METHOD_NOT_FOUND, + message: format!("Unknown tool: {}", tool_name), + data: None, + }), + }, + } + } + _ => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { + code: errors::METHOD_NOT_FOUND, + message: format!("Method not found: {}", request.method), + data: None, + }), + }, + } + } +} + +// Axum handler +async fn handle_jsonrpc( + State(handler): State>, + Json(request): Json, +) -> Result, StatusCode> { + let response = handler.handle_request(request).await; + Ok(Json(response)) +} + +// Create HTTP MCP server - this replaces your old create_mcp_server function +pub async fn create_http_mcp_server( + state: SharedMistralRsState, + host: String, + port: u16, +) -> Result<(), Box> { + let handler = Arc::new(HttpMcpHandler::new(state)); + + let app = Router::new() + .route("/mcp", post(handle_jsonrpc)) + .with_state(handler); + + let addr = format!("{}:{}", host, port); + let listener = TcpListener::bind(&addr).await?; + + 
axum::serve(listener, app).await?; + + Ok(()) +}