Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,9 @@ func TestSSE(t *testing.T) {
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
// Test that SSE events with only data field (no event field) are processed correctly
// This tests the fix for issue #369

var messageReceived chan struct{}

// Create a custom mock server that sends SSE events without event field
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down Expand Up @@ -449,7 +449,7 @@ func TestSSE(t *testing.T) {
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)

// Signal that message was received
close(messageReceived)
})
Expand Down
175 changes: 160 additions & 15 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
}
}

// WithResumption enables the client to attempt resuming broken connections.
// This can help to reduce network congestion as the server does not need
// to redeliver messages that have already been sent on the previous broken
// connection.
//
// As the retry itself might fail, the retry count can be set and it must be a value >= 1.
// If the value is < 1, it will be set to 1.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery
// NOTICE: Even enabled, the server may not support this feature.
func WithResumption(maxRetryCount int) StreamableHTTPCOption {
if maxRetryCount < 1 {
maxRetryCount = 1
}
return func(sc *StreamableHTTP) {
sc.resumptionEnabled = true
sc.maxRetryCount = maxRetryCount
}
}

// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
Expand All @@ -66,14 +85,14 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
// - batching
// - continuously listening for server notifications when no request is in flight
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
resumptionEnabled bool
maxRetryCount int

sessionID atomic.Value // string

Expand Down Expand Up @@ -159,7 +178,8 @@ func (c *StreamableHTTP) Close() error {
}

const (
headerKeySessionID = "Mcp-Session-Id"
headerKeySessionID = "Mcp-Session-Id"
headerKeyLastEventID = "Last-Event-Id"
)

// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required
Expand Down Expand Up @@ -300,16 +320,129 @@ func (c *StreamableHTTP) SendRequest(

case "text/event-stream":
// Server is using SSE for streaming responses
return c.handleSSEResponse(ctx, resp.Body)

if !c.resumptionEnabled {
return c.handleSSEResponse(ctx, resp.Body, nil)
}
var lastEventId string
resumptionCallback := func(id string) {
lastEventId = id
}
resp, err := c.handleSSEResponse(ctx, resp.Body, resumptionCallback)
if err == nil || lastEventId == "" {
return resp, err
}
for range c.maxRetryCount {
resp, err, canRetry := c.performSSEResumption(ctx, lastEventId, resumptionCallback)
if err == nil || lastEventId == "" || !canRetry {
return resp, err
}
}
return nil, fmt.Errorf("failed to retrieve response after attempting resumption %d times", c.maxRetryCount)
default:
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
}
}

// performSSEResumption sends a request to the server including both the
// session id and the last event id, expecting the server to return an
// SSE streaming response sending the events with ids after the last event id.
// It returns the final result for the request once received, or an error.
func (c *StreamableHTTP) performSSEResumption(
ctx context.Context,
lastEventId string,
resumptionCallback func(string),
) (*JSONRPCResponse, error, bool) {

ctx, cancel := context.WithCancel(ctx)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.serverURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err), false
}

req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
sessionID := c.sessionID.Load()
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
if lastEventId == "" {
return nil, fmt.Errorf("sse resumption request requires a last event id"), false
}
req.Header.Set(headerKeyLastEventID, lastEventId)

for k, v := range c.headers {
req.Header.Set(k, v)
}

// Add OAuth authorization if configured
if c.oauthHandler != nil {
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
if err != nil {
// If we get an authorization error, return a specific error that can be handled by the client
if err.Error() == "no valid token available, authorization required" {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}, false
}
return nil, fmt.Errorf("failed to get authorization header: %w", err), false
}
req.Header.Set("Authorization", authHeader)
}

if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to connect to SSE stream: %w", err), true
}

// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
// handle session closed
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, fmt.Errorf("session terminated (404). need to re-initialize"), false
}

// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}, false
}

// handle error response
var errResponse JSONRPCResponse
body, _ := io.ReadAll(resp.Body)
if err := json.Unmarshal(body, &errResponse); err == nil {
return &errResponse, nil, false
}
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body), false
}

mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
switch mediaType {
case "text/event-stream":
resp, err := c.handleSSEResponse(ctx, resp.Body, resumptionCallback)
return resp, err, true
default:
return nil, fmt.Errorf("unexpected content type for sse resumption response: %s", resp.Header.Get("Content-Type")), false
}
}

// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
func (c *StreamableHTTP) handleSSEResponse(
ctx context.Context,
reader io.ReadCloser,
resumptionCallback func(string),
) (*JSONRPCResponse, error) {

// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)
Expand All @@ -322,7 +455,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
// only close responseChan after readingSSE()
defer close(responseChan)

c.readSSE(ctx, reader, func(event, data string) {
c.readSSE(ctx, reader, func(event, data, id string) {

// (unsupported: batching)

Expand All @@ -332,6 +465,10 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
return
}

if id != "" && resumptionCallback != nil {
resumptionCallback(id)
}

// Handle notification
if message.ID.IsNil() {
var notification mcp.JSONRPCNotification
Expand Down Expand Up @@ -365,11 +502,11 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl

// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
// It will end when the reader is closed (or the context is done).
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data, id string)) {
defer reader.Close()

br := bufio.NewReader(reader)
var event, data string
var event, data, id string

for {
select {
Expand All @@ -385,7 +522,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
if event == "" {
event = "message"
}
handler(event, data)
handler(event, data, id)
}
return
}
Expand All @@ -407,9 +544,10 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
if event == "" {
event = "message"
}
handler(event, data)
handler(event, data, id)
event = ""
data = ""
id = ""
}
continue
}
Expand All @@ -418,6 +556,13 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
} else if strings.HasPrefix(line, "id:") {
eventId := strings.TrimSpace(strings.TrimPrefix(line, "id:"))
if strings.Contains(eventId, "\x00") {
// will be sent back in HTTP header, a null byte in header breaks HTTP standard
continue
}
id = eventId
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions client/transport/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func TestStreamableHTTP(t *testing.T) {
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
// Test that SSE events with only data field (no event field) are processed correctly
// This tests the fix for issue #369

// Create a custom mock server that sends SSE events without event field
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Expand All @@ -437,7 +437,7 @@ func TestStreamableHTTP(t *testing.T) {
// This should be processed as a "message" event according to SSE spec
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)

response := map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
Expand Down