closeGracefully() {
+ return Mono.fromRunnable(sink::complete);
+ }
+
+ @Override
+ public void close() {
+ sink.complete();
+ }
+
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Builder for creating instances of {@link DefaultMcpStreamableServerTransportProvider}.
+ *
+ * This builder provides a fluent API for configuring and creating instances of
+ * DefaultMcpStreamableServerTransportProvider with custom settings.
+ */
+ public static class Builder {
+
+ private ObjectMapper objectMapper;
+
+ private String mcpEndpoint = "/mcp";
+
+ private McpTransportContextExtractor contextExtractor = (
+ serverRequest) -> McpTransportContext.EMPTY;
+
+ private boolean disallowDelete;
+
+ private Duration keepAliveInterval;
+
+ private Builder() {
+ // used by a static method
+ }
+
+ /**
+ * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
+ * messages.
+ * @param objectMapper The ObjectMapper instance. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if objectMapper is null
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Sets the endpoint URI where clients should send their JSON-RPC messages.
+ * @param messageEndpoint The message endpoint URI. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if messageEndpoint is null
+ */
+ public Builder messageEndpoint(String messageEndpoint) {
+ Assert.notNull(messageEndpoint, "Message endpoint must not be null");
+ this.mcpEndpoint = messageEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets the context extractor that allows providing the MCP feature
+ * implementations to inspect HTTP transport level metadata that was present at
+ * HTTP request processing time. This allows to extract custom headers and other
+ * useful data for use during execution later on in the process.
+ * @param contextExtractor The contextExtractor to fill in a
+ * {@link McpTransportContext}.
+ * @return this builder instance
+ * @throws IllegalArgumentException if contextExtractor is null
+ */
+ public Builder contextExtractor(McpTransportContextExtractor contextExtractor) {
+ Assert.notNull(contextExtractor, "contextExtractor must not be null");
+ this.contextExtractor = contextExtractor;
+ return this;
+ }
+
+ /**
+ * Sets whether the session removal capability is disabled.
+ * @param disallowDelete if {@code true}, the DELETE endpoint will not be
+ * supported and sessions won't be deleted.
+ * @return this builder instance
+ */
+ public Builder disallowDelete(boolean disallowDelete) {
+ this.disallowDelete = disallowDelete;
+ return this;
+ }
+
+ /**
+ * Sets the keep-alive interval for the server transport.
+ * @param keepAliveInterval The interval for sending keep-alive messages. If null,
+ * no keep-alive will be scheduled.
+ * @return this builder instance
+ */
+ public Builder keepAliveInterval(Duration keepAliveInterval) {
+ this.keepAliveInterval = keepAliveInterval;
+ return this;
+ }
+
+ /**
+ * Builds a new instance of {@link DefaultMcpStreamableServerTransportProvider} with
+ * the configured settings.
+ * @return A new DefaultMcpStreamableServerTransportProvider instance
+ * @throws IllegalStateException if required parameters are not set
+ */
+ public DefaultMcpStreamableServerTransportProvider build() {
+ Assert.notNull(objectMapper, "ObjectMapper must be set");
+ Assert.notNull(mcpEndpoint, "Message endpoint must be set");
+
+ return new DefaultMcpStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor,
+ disallowDelete, keepAliveInterval);
+ }
+
+ }
+
+}
diff --git a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/SdkMcpServer.java b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/SdkMcpServer.java
new file mode 100644
index 000000000..1417ff81c
--- /dev/null
+++ b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/SdkMcpServer.java
@@ -0,0 +1,30 @@
+package modelengine.fel.tool.mcp.server;
+
+
+import io.modelcontextprotocol.server.McpServer;
+import io.modelcontextprotocol.spec.McpSchema;
+import modelengine.fel.tool.service.ToolExecuteService;
+import modelengine.fitframework.annotation.Component;
+import io.modelcontextprotocol.server.McpSyncServer;
+
+import java.time.Duration;
+
+@Component
+public class SdkMcpServer {
+ private final McpSyncServer mcpSyncServer;
+
+ public SdkMcpServer(DefaultMcpStreamableServerTransportProvider transportProvider) {
+ this.mcpSyncServer = McpServer.sync(transportProvider)
+ .serverInfo("hkx-server", "1.0.0")
+ .capabilities(McpSchema.ServerCapabilities.builder()
+ .resources(false, true) // Enable resource support
+ .tools(true) // Enable tool support
+ .prompts(true) // Enable prompt support
+ .logging() // Enable logging support
+ .completions() // Enable completions support
+ .build())
+ .requestTimeout(Duration.ofSeconds(10))
+ .build();
+ }
+
+}
From e4f4006738971fcd4d0a30a44de53a9661eb28db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=8F=AF=E6=AC=A3?= <2218887102@qq.com>
Date: Mon, 29 Sep 2025 17:12:43 +0800
Subject: [PATCH 02/37] =?UTF-8?q?=E6=8E=A5=E5=85=A5Fit=20Http=E9=80=BB?=
=?UTF-8?q?=E8=BE=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../fel/java/plugins/tool-mcp-server/pom.xml | 2 +-
...tMcpStreamableServerTransportProvider.java | 661 +++++++++++-------
...xMcpStreamableServerTransportProvider.java | 447 ++++++++++++
.../fel/tool/mcp/server/SdkMcpServer.java | 3 +-
.../fel/tool/mcp/server/ServerSentEvent.java | 264 +++++++
.../mcp/server/support/SseServerResponse.java | 261 +++++++
6 files changed, 1390 insertions(+), 248 deletions(-)
create mode 100644 framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/FluxMcpStreamableServerTransportProvider.java
create mode 100644 framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/ServerSentEvent.java
create mode 100644 framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/support/SseServerResponse.java
diff --git a/framework/fel/java/plugins/tool-mcp-server/pom.xml b/framework/fel/java/plugins/tool-mcp-server/pom.xml
index 42c088576..306ff8c3b 100644
--- a/framework/fel/java/plugins/tool-mcp-server/pom.xml
+++ b/framework/fel/java/plugins/tool-mcp-server/pom.xml
@@ -44,7 +44,7 @@
io.modelcontextprotocol.sdk
mcp
- 0.13.0
+ 0.12.0
diff --git a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/DefaultMcpStreamableServerTransportProvider.java b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/DefaultMcpStreamableServerTransportProvider.java
index 7fb9f636a..f5271ec09 100644
--- a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/DefaultMcpStreamableServerTransportProvider.java
+++ b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/DefaultMcpStreamableServerTransportProvider.java
@@ -7,61 +7,85 @@
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.KeepAliveScheduler;
+import modelengine.fit.http.annotation.*;
+import modelengine.fit.http.entity.Entity;
+import modelengine.fit.http.protocol.HttpResponseStatus;
+import modelengine.fit.http.protocol.MessageHeaderNames;
+import modelengine.fit.http.protocol.MimeType;
import modelengine.fit.http.server.HttpClassicServerRequest;
import modelengine.fit.http.server.HttpClassicServerResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import reactor.core.Disposable;
-import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
-import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
+import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.ReentrantLock;
+@RequestMapping("/mcp/streamable")
public class DefaultMcpStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
private static final Logger logger = LoggerFactory.getLogger(DefaultMcpStreamableServerTransportProvider.class);
+ /**
+ * Event type for JSON-RPC messages sent through the SSE connection.
+ */
public static final String MESSAGE_EVENT_TYPE = "message";
- private final ObjectMapper objectMapper;
+ /**
+ * Event type for sending the message endpoint URI to clients.
+ */
+ public static final String ENDPOINT_EVENT_TYPE = "endpoint";
- private final String mcpEndpoint;
+ /**
+ * Default base URL for the message endpoint.
+ */
+ public static final String DEFAULT_BASE_URL = "";
+ /**
+ * Flag indicating whether DELETE requests are disallowed on the endpoint.
+ */
private final boolean disallowDelete;
- private final RouterFunction> routerFunction;
+ private final ObjectMapper objectMapper;
private McpStreamableServerSession.Factory sessionFactory;
+ /**
+ * Map of active client sessions, keyed by mcp-session-id.
+ */
private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
private McpTransportContextExtractor contextExtractor;
+ /**
+ * Flag indicating if the transport is shutting down.
+ */
private volatile boolean isClosing = false;
private KeepAliveScheduler keepAliveScheduler;
- private DefaultMcpStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
- McpTransportContextExtractor contextExtractor, boolean disallowDelete,
- Duration keepAliveInterval) {
+ /**
+ * Constructs a new DefaultMcpStreamableServerTransportProvider instance.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ * of messages.
+ * @param disallowDelete Whether to disallow DELETE requests on the endpoint.
+ * @throws IllegalArgumentException if any parameter is null
+ */
+ private DefaultMcpStreamableServerTransportProvider(ObjectMapper objectMapper,
+ boolean disallowDelete, McpTransportContextExtractor contextExtractor,
+ Duration keepAliveInterval) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
- Assert.notNull(mcpEndpoint, "Message endpoint must not be null");
- Assert.notNull(contextExtractor, "Context extractor must not be null");
+ Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null");
this.objectMapper = objectMapper;
- this.mcpEndpoint = mcpEndpoint;
- this.contextExtractor = contextExtractor;
this.disallowDelete = disallowDelete;
- this.routerFunction = RouterFunctions.route()
- .GET(this.mcpEndpoint, this::handleGet)
- .POST(this.mcpEndpoint, this::handlePost)
- .DELETE(this.mcpEndpoint, this::handleDelete)
- .build();
+ this.contextExtractor = contextExtractor;
if (keepAliveInterval != null) {
this.keepAliveScheduler = KeepAliveScheduler
@@ -84,282 +108,451 @@ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory)
this.sessionFactory = sessionFactory;
}
+ /**
+ * Broadcasts a notification to all connected clients through their SSE connections.
+ * If any errors occur during sending to a particular client, they are logged but
+ * don't prevent sending to other clients.
+ * @param method The method name for the notification
+ * @param params The parameters for the notification
+ * @return A Mono that completes when the broadcast attempt is finished
+ */
@Override
public Mono notifyClients(String method, Object params) {
- if (sessions.isEmpty()) {
+ if (this.sessions.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
}
- logger.debug("Attempting to broadcast message to {} active sessions", sessions.size());
+ logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
- return Flux.fromIterable(sessions.values())
- .flatMap(session -> session.sendNotification(method, params)
- .doOnError(
- e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
- .onErrorComplete())
- .then();
+ return Mono.fromRunnable(() -> {
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.sendNotification(method, params).block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+ });
}
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when all cleanup operations are finished
+ */
@Override
public Mono closeGracefully() {
- return Mono.defer(() -> {
+ return Mono.fromRunnable(() -> {
this.isClosing = true;
- return Flux.fromIterable(sessions.values())
- .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()))
- .flatMap(McpStreamableServerSession::closeGracefully)
- .then();
+ logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
+
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.closeGracefully().block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to close session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+
+ this.sessions.clear();
+ logger.debug("Graceful shutdown completed");
}).then().doOnSuccess(v -> {
- sessions.clear();
if (this.keepAliveScheduler != null) {
this.keepAliveScheduler.shutdown();
}
});
}
- /**
- * Returns the WebFlux router function that defines the transport's HTTP endpoints.
- * This router function should be integrated into the application's web configuration.
- *
- *
- * The router function defines one endpoint with three methods:
- *
- * - GET {messageEndpoint} - For the client listening SSE stream
- * - POST {messageEndpoint} - For receiving client messages
- * - DELETE {messageEndpoint} - For removing sessions
- *
- * @return The configured {@link RouterFunction} for handling HTTP requests
- */
- public RouterFunction> getRouterFunction() {
- return this.routerFunction;
- }
/**
- * Opens the listening SSE streams for clients.
+ * Setup the listening SSE connections and message replay.
* @param request The incoming server request
- * @return A Mono which emits a response with the SSE event stream
*/
- private Mono handleGet(HttpClassicServerRequest request) {
- if (isClosing) {
- return HttpClassicServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
+ @GetMapping
+ private void handleGet(HttpClassicServerRequest request, HttpClassicServerResponse response) {
+ if (this.isClosing) {
+ response.statusCode(HttpResponseStatus.SERVICE_UNAVAILABLE.statusCode());
+ response.entity(Entity.createText(response, "Server is shutting down"));
+ return;
+ }
+
+ List acceptHeaders = request.headers().all(MessageHeaderNames.ACCEPT);
+ if (!acceptHeaders.contains(MimeType.TEXT_EVENT_STREAM.value())) {
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createText(response, "Invalid Accept header. Expected TEXT_EVENT_STREAM"));
+ return;
}
McpTransportContext transportContext = this.contextExtractor.extract(request);
- return Mono.defer(() -> {
- List acceptHeaders = request.headers().asHttpHeaders().getAccept();
- if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
- return HttpClassicServerResponse.badRequest().build();
- }
+ if (!request.headers().contains(HttpHeaders.MCP_SESSION_ID)) {
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createText(response, "Session ID required in mcp-session-id header"));
+ return;
+ }
- if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
- return HttpClassicServerResponse.badRequest().build(); // TODO: say we need a session
- // id
- }
+ String sessionId = request.headers().first(HttpHeaders.MCP_SESSION_ID).orElse("");
+ McpStreamableServerSession session = this.sessions.get(sessionId);
- String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
+ if (session == null) {
+ response.statusCode(HttpResponseStatus.NOT_FOUND.statusCode());
+ return;
+ }
- McpStreamableServerSession session = this.sessions.get(sessionId);
+ logger.debug("Handling GET request for session: {}", sessionId);
- if (session == null) {
- return HttpClassicServerResponse.notFound().build();
- }
+ try {
- if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
- String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
- return HttpClassicServerResponse.ok()
- .contentType(MediaType.TEXT_EVENT_STREAM)
- .body(session.replay(lastId)
- .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
- ServerSentEvent.class);
- }
+ return HttpClassicServerResponse.sse(sseBuilder -> {
+ sseBuilder.onTimeout(() -> {
+ logger.debug("SSE connection timed out for session: {}", sessionId);
+ });
+
+ DefaultStreamableMcpSessionTransport sessionTransport = new DefaultStreamableMcpSessionTransport(
+ sessionId, sseBuilder);
+
+ // Check if this is a replay request
+ if (request.headers().contains(HttpHeaders.LAST_EVENT_ID)) {
+ String lastId = request.headers().first(HttpHeaders.LAST_EVENT_ID).orElse("");
- return HttpClassicServerResponse.ok()
- .contentType(MediaType.TEXT_EVENT_STREAM)
- .body(Flux.>create(sink -> {
- WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport(
- sink);
- McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
- .listeningStream(sessionTransport);
- sink.onDispose(listeningStream::close);
- // TODO Clarify why the outer context is not present in the
- // Flux.create sink?
- }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class);
-
- }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
+ try {
+ session.replay(lastId)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .toIterable()
+ .forEach(message -> {
+ try {
+ sessionTransport.sendMessage(message)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay message: {}", e.getMessage());
+ sseBuilder.error(e);
+ }
+ });
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay messages: {}", e.getMessage());
+ sseBuilder.error(e);
+ }
+ }
+ else {
+ // Establish new listening stream
+ McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
+ .listeningStream(sessionTransport);
+
+ sseBuilder.onComplete(() -> {
+ logger.debug("SSE connection completed for session: {}", sessionId);
+ listeningStream.close();
+ });
+ }
+ }, Duration.ZERO);
+ }
+ catch (Exception e) {
+ logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage());
+ response.statusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.statusCode());
+ }
}
/**
- * Handles incoming JSON-RPC messages from clients.
+ * Handles POST requests for incoming JSON-RPC messages from clients.
* @param request The incoming server request containing the JSON-RPC message
- * @return A Mono with the response appropriate to a particular Streamable HTTP flow.
*/
- private Mono handlePost(HttpClassicServerRequest request) {
- if (isClosing) {
- return HttpClassicServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
+ @PostMapping
+ private void handlePost(@RequestBody String body,HttpClassicServerRequest request, HttpClassicServerResponse response) {
+ if (this.isClosing) {
+ response.statusCode(HttpResponseStatus.SERVICE_UNAVAILABLE.statusCode());
+ response.entity(Entity.createText(response, "Server is shutting down"));
+ return;
+ }
+
+ List acceptHeaders = request.headers().all(MessageHeaderNames.ACCEPT);
+ if (!acceptHeaders.contains(MimeType.TEXT_EVENT_STREAM.value())
+ || !acceptHeaders.contains(MimeType.APPLICATION_JSON.value())) {
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createObject(response, new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")));
+ return;
}
McpTransportContext transportContext = this.contextExtractor.extract(request);
- List acceptHeaders = request.headers().asHttpHeaders().getAccept();
- if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
- && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) {
- return HttpClassicServerResponse.badRequest().build();
- }
+ try {
+
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
+
+ // Handle initialization request
+ if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
+ && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
+ McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
+ new TypeReference() {
+ });
+ McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
+ .startSession(initializeRequest);
+ this.sessions.put(init.session().getId(), init.session());
+
+ try {
+ McpSchema.InitializeResult initResult = init.initResult().block();
+ response.statusCode(HttpResponseStatus.OK.statusCode());
+ response.headers().set("Content-Type", MimeType.APPLICATION_JSON.value());
+ response.headers().set(HttpHeaders.MCP_SESSION_ID, init.session().getId());
+ response.entity(Entity.createObject(response,
+ new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)));
+ return;
+ }
+ catch (Exception e) {
+ logger.error("Failed to initialize session: {}", e.getMessage());
+ response.statusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.statusCode());
+ response.entity(Entity.createObject(response, new McpError(e.getMessage())));
+ return;
+ }
+ }
+
+ // Handle other messages that require a session
+ if (!request.headers().contains(HttpHeaders.MCP_SESSION_ID)) {
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createObject(response, new McpError("Session ID missing")));
+ return;
+ }
+
+ String sessionId = request.headers().first(HttpHeaders.MCP_SESSION_ID).orElse("");
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ response.statusCode(HttpResponseStatus.NOT_FOUND.statusCode());
+ response.entity(Entity.createObject(response, new McpError("Session not found: " + sessionId)));
+ return;
+ }
+
+ if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) {
+ session.accept(jsonrpcResponse)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ response.statusCode(HttpResponseStatus.ACCEPTED.statusCode());
+ }
+ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
+ session.accept(jsonrpcNotification)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ response.statusCode(HttpResponseStatus.ACCEPTED.statusCode());
+ }
+ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
+ // For streaming responses, we need to return SSE
+ return HttpClassicServerResponse.sse(sseBuilder -> {
+ sseBuilder.onComplete(() -> {
+ logger.debug("Request response stream completed for session: {}", sessionId);
+ });
+ sseBuilder.onTimeout(() -> {
+ logger.debug("Request response stream timed out for session: {}", sessionId);
+ });
+
+ DefaultStreamableMcpSessionTransport sessionTransport = new DefaultStreamableMcpSessionTransport(
+ sessionId, sseBuilder);
- return request.bodyToMono(String.class).flatMap(body -> {
try {
- McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
- if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
- && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
- McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
- new TypeReference() {
- });
- McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
- .startSession(initializeRequest);
- sessions.put(init.session().getId(), init.session());
- return init.initResult().map(initializeResult -> {
- McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse(
- McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null);
- try {
- return this.objectMapper.writeValueAsString(jsonrpcResponse);
- }
- catch (IOException e) {
- logger.warn("Failed to serialize initResponse", e);
- throw Exceptions.propagate(e);
- }
- })
- .flatMap(initResult -> HttpClassicServerResponse.ok()
- .contentType(MediaType.APPLICATION_JSON)
- .header(HttpHeaders.MCP_SESSION_ID, init.session().getId())
- .bodyValue(initResult));
- }
-
- if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
- return HttpClassicServerResponse.badRequest().bodyValue(new McpError("Session ID missing"));
- }
-
- String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
- McpStreamableServerSession session = sessions.get(sessionId);
-
- if (session == null) {
- return HttpClassicServerResponse.status(HttpStatus.NOT_FOUND)
- .bodyValue(new McpError("Session not found: " + sessionId));
- }
-
- if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) {
- return session.accept(jsonrpcResponse).then(HttpClassicServerResponse.accepted().build());
- }
- else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
- return session.accept(jsonrpcNotification).then(HttpClassicServerResponse.accepted().build());
- }
- else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
- return HttpClassicServerResponse.ok()
- .contentType(MediaType.TEXT_EVENT_STREAM)
- .body(Flux.>create(sink -> {
- WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink);
- Mono stream = session.responseStream(jsonrpcRequest, st);
- Disposable streamSubscription = stream.onErrorComplete(err -> {
- sink.error(err);
- return true;
- }).contextWrite(sink.contextView()).subscribe();
- sink.onCancel(streamSubscription);
- // TODO Clarify why the outer context is not present in the
- // Flux.create sink?
- }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)),
- ServerSentEvent.class);
- }
- else {
- return HttpClassicServerResponse.badRequest().bodyValue(new McpError("Unknown message type"));
- }
+ session.responseStream(jsonrpcRequest, sessionTransport)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
}
- catch (IllegalArgumentException | IOException e) {
- logger.error("Failed to deserialize message: {}", e.getMessage());
- return HttpClassicServerResponse.badRequest().bodyValue(new McpError("Invalid message format"));
+ catch (Exception e) {
+ logger.error("Failed to handle request stream: {}", e.getMessage());
+ sseBuilder.error(e);
}
- })
- .switchIfEmpty(HttpClassicServerResponse.badRequest().build())
- .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
+ }, Duration.ZERO);
+ }
+ else {
+ response.statusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.statusCode());
+ response.entity(Entity.createObject(response, new McpError("Unknown message type")));
+ }
+ }
+ catch (IllegalArgumentException | IOException e) {
+ logger.error("Failed to deserialize message: {}", e.getMessage());
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createObject(response, new McpError("Invalid message format")));
+ }
+ catch (Exception e) {
+ logger.error("Error handling message: {}", e.getMessage());
+ response.statusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.statusCode());
+ response.entity(Entity.createObject(response, new McpError(e.getMessage())));
+ }
}
- private Mono handleDelete(HttpClassicServerRequest request) {
- if (isClosing) {
- return HttpClassicServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
+ /**
+ * Handles DELETE requests for session deletion.
+ * @param request The incoming server request
+ */
+ @DeleteMapping
+ private void handleDelete(HttpClassicServerRequest request, HttpClassicServerResponse response) {
+ if (this.isClosing) {
+ response.statusCode(HttpResponseStatus.SERVICE_UNAVAILABLE.statusCode());
+ response.entity(Entity.createText(response, "Server is shutting down"));
+ return;
+ }
+
+ if (this.disallowDelete) {
+ response.statusCode(HttpResponseStatus.METHOD_NOT_ALLOWED.statusCode());
+ return;
}
McpTransportContext transportContext = this.contextExtractor.extract(request);
- return Mono.defer(() -> {
- if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
- return HttpClassicServerResponse.badRequest().build(); // TODO: say we need a session
- // id
- }
-
- if (this.disallowDelete) {
- return HttpClassicServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
- }
-
- String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
+ if (!request.headers().contains(HttpHeaders.MCP_SESSION_ID)) {
+ response.statusCode(HttpResponseStatus.BAD_REQUEST.statusCode());
+ response.entity(Entity.createText(response, "Session ID required in mcp-session-id header"));
+ return;
+ }
- McpStreamableServerSession session = this.sessions.get(sessionId);
+ String sessionId = request.headers().first(HttpHeaders.MCP_SESSION_ID).orElse("");
+ McpStreamableServerSession session = this.sessions.get(sessionId);
- if (session == null) {
- return HttpClassicServerResponse.notFound().build();
- }
+ if (session == null) {
+ response.statusCode(HttpResponseStatus.NOT_FOUND.statusCode());
+ return;
+ }
- return session.delete().then(HttpClassicServerResponse.ok().build());
- }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext));
+ try {
+ session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
+ this.sessions.remove(sessionId);
+ response.statusCode(HttpResponseStatus.OK.statusCode());
+ }
+ catch (Exception e) {
+ logger.error("Failed to delete session {}: {}", sessionId, e.getMessage());
+ response.statusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.statusCode());
+ response.entity(Entity.createObject(response, new McpError(e.getMessage())));
+ }
}
+ /**
+ * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class
+ * handles the transport-level communication for a specific client session.
+ *
+ *
+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the
+ * underlying SSE builder to prevent race conditions when multiple threads attempt to
+ * send messages concurrently.
+ */
private class DefaultStreamableMcpSessionTransport implements McpStreamableServerTransport {
- private final FluxSink> sink;
+ private final String sessionId;
+
+ private final SseBuilder sseBuilder;
+
+ private final ReentrantLock lock = new ReentrantLock();
- public DefaultStreamableMcpSessionTransport(FluxSink> sink) {
- this.sink = sink;
+ private volatile boolean closed = false;
+
+ /**
+ * Creates a new session transport with the specified ID and SSE builder.
+ * @param sessionId The unique identifier for this session
+ * @param sseBuilder The SSE builder for sending server events to the client
+ */
+ DefaultStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) {
+ this.sessionId = sessionId;
+ this.sseBuilder = sseBuilder;
+ logger.debug("Streamable session transport {} initialized with SSE builder", sessionId);
}
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection.
+ * @param message The JSON-RPC message to send
+ * @return A Mono that completes when the message has been sent
+ */
@Override
public Mono sendMessage(McpSchema.JSONRPCMessage message) {
- return this.sendMessage(message, null);
+ return sendMessage(message, null);
}
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection with a
+ * specific message ID.
+ * @param message The JSON-RPC message to send
+ * @param messageId The message ID for SSE event identification
+ * @return A Mono that completes when the message has been sent
+ */
@Override
public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) {
- return Mono.fromSupplier(() -> {
+ return Mono.fromRunnable(() -> {
+ if (this.closed) {
+ logger.debug("Attempted to send message to closed session: {}", this.sessionId);
+ return;
+ }
+
+ this.lock.lock();
try {
- return objectMapper.writeValueAsString(message);
+ if (this.closed) {
+ logger.debug("Session {} was closed during message send attempt", this.sessionId);
+ return;
+ }
+
+ String jsonText = objectMapper.writeValueAsString(message);
+ this.sseBuilder.id(messageId != null ? messageId : this.sessionId)
+ .event(MESSAGE_EVENT_TYPE)
+ .data(jsonText);
+ logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
+ try {
+ this.sseBuilder.error(e);
+ }
+ catch (Exception errorException) {
+ logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId,
+ errorException.getMessage());
+ }
}
- catch (IOException e) {
- throw Exceptions.propagate(e);
+ finally {
+ this.lock.unlock();
}
- }).doOnNext(jsonText -> {
- ServerSentEvent