Skip to content

Commit 6d52180

Browse files
authored
fix: reuse sessions correctly in streamable HTTP transport (#615)
Fixes session management in the streamable HTTP server to properly reuse registered sessions for POST requests instead of always creating ephemeral sessions. This enables SendNotificationToSpecificClient and session-aware features to work correctly with POST-based interactions. Changes: - Check s.server.sessions for existing sessions before creating ephemeral ones - Register sessions after successful initialization from POST requests - Store sessions in both s.server.sessions and s.activeSessions for consistency - Add comprehensive tests for session reuse and notification delivery Fixes #614
1 parent 74a600b commit 6d52180

File tree

2 files changed

+197
-4
lines changed

2 files changed

+197
-4
lines changed

server/streamable_http.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,23 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
309309
}
310310
}
311311

312+
// For non-initialize requests, try to reuse existing registered session
313+
var session *streamableHttpSession
314+
if !isInitializeRequest {
315+
if sessionValue, ok := s.server.sessions.Load(sessionID); ok {
316+
if existingSession, ok := sessionValue.(*streamableHttpSession); ok {
317+
session = existingSession
318+
}
319+
}
320+
}
321+
312322
// Check if a persistent session exists (for sampling support), otherwise create ephemeral session
313323
// Persistent sessions are created by GET (continuous listening) connections
314-
var session *streamableHttpSession
315-
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
316-
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
317-
session = persistentSession
324+
if session == nil {
325+
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
326+
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
327+
session = persistentSession
328+
}
318329
}
319330
}
320331

@@ -417,6 +428,21 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
417428
s.logger.Errorf("Failed to write response: %v", err)
418429
}
419430
}
431+
432+
// Register session after successful initialization
433+
// Only register if not already registered (e.g., by a GET connection)
434+
if isInitializeRequest && sessionID != "" {
435+
if _, exists := s.server.sessions.Load(sessionID); !exists {
436+
// Store in activeSessions to prevent duplicate registration from GET
437+
s.activeSessions.Store(sessionID, session)
438+
// Register the session with the MCPServer for notification support
439+
if err := s.server.RegisterSession(ctx, session); err != nil {
440+
s.logger.Errorf("Failed to register POST session: %v", err)
441+
s.activeSessions.Delete(sessionID)
442+
// Don't fail the request, just log the error
443+
}
444+
}
445+
}
420446
}
421447

422448
func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {

server/streamable_http_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,170 @@ func TestInsecureStatefulSessionIdManager(t *testing.T) {
13141314
}
13151315
})
13161316
}
1317+
1318+
func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
1319+
t.Run("POST session registration enables SendNotificationToSpecificClient", func(t *testing.T) {
1320+
hooks := &Hooks{}
1321+
var registeredSessionID string
1322+
var mu sync.Mutex
1323+
var sessionRegistered sync.WaitGroup
1324+
sessionRegistered.Add(1)
1325+
1326+
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
1327+
mu.Lock()
1328+
registeredSessionID = session.SessionID()
1329+
mu.Unlock()
1330+
sessionRegistered.Done()
1331+
})
1332+
1333+
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
1334+
testServer := NewTestStreamableHTTPServer(mcpServer)
1335+
defer testServer.Close()
1336+
1337+
// Send initialize request to register session
1338+
resp, err := postJSON(testServer.URL, initRequest)
1339+
if err != nil {
1340+
t.Fatalf("Failed to send initialize request: %v", err)
1341+
}
1342+
defer resp.Body.Close()
1343+
1344+
if resp.StatusCode != http.StatusOK {
1345+
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
1346+
}
1347+
1348+
// Get session ID from response header
1349+
sessionID := resp.Header.Get(HeaderKeySessionID)
1350+
if sessionID == "" {
1351+
t.Fatal("Expected session ID in response header")
1352+
}
1353+
1354+
// Wait for session registration
1355+
done := make(chan struct{})
1356+
go func() {
1357+
sessionRegistered.Wait()
1358+
close(done)
1359+
}()
1360+
1361+
select {
1362+
case <-done:
1363+
// Session registered successfully
1364+
case <-time.After(2 * time.Second):
1365+
t.Fatal("Timeout waiting for session registration")
1366+
}
1367+
1368+
mu.Lock()
1369+
if registeredSessionID != sessionID {
1370+
t.Errorf("Expected registered session ID %s, got %s", sessionID, registeredSessionID)
1371+
}
1372+
mu.Unlock()
1373+
1374+
// Now test SendNotificationToSpecificClient
1375+
err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]any{
1376+
"message": "test notification",
1377+
})
1378+
if err != nil {
1379+
t.Errorf("SendNotificationToSpecificClient failed: %v", err)
1380+
}
1381+
})
1382+
1383+
t.Run("Session reuse for non-initialize requests", func(t *testing.T) {
1384+
mcpServer := NewMCPServer("test", "1.0.0")
1385+
1386+
// Add a tool that sends a notification
1387+
mcpServer.AddTool(mcp.NewTool("notify_tool"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1388+
session := ClientSessionFromContext(ctx)
1389+
if session == nil {
1390+
return mcp.NewToolResultError("no session in context"), nil
1391+
}
1392+
1393+
// Try to send notification to specific client
1394+
server := ServerFromContext(ctx)
1395+
err := server.SendNotificationToSpecificClient(session.SessionID(), "tool/notification", map[string]any{
1396+
"from": "tool",
1397+
})
1398+
if err != nil {
1399+
return mcp.NewToolResultError(fmt.Sprintf("notification failed: %v", err)), nil
1400+
}
1401+
1402+
return mcp.NewToolResultText("notification sent"), nil
1403+
})
1404+
1405+
testServer := NewTestStreamableHTTPServer(mcpServer)
1406+
defer testServer.Close()
1407+
1408+
// Initialize session
1409+
resp, err := postJSON(testServer.URL, initRequest)
1410+
if err != nil {
1411+
t.Fatalf("Failed to send initialize request: %v", err)
1412+
}
1413+
sessionID := resp.Header.Get(HeaderKeySessionID)
1414+
resp.Body.Close()
1415+
1416+
if sessionID == "" {
1417+
t.Fatal("Expected session ID in response header")
1418+
}
1419+
1420+
// Give time for registration to complete
1421+
time.Sleep(100 * time.Millisecond)
1422+
1423+
// Call tool with the session ID
1424+
toolCallRequest := map[string]any{
1425+
"jsonrpc": "2.0",
1426+
"id": 2,
1427+
"method": "tools/call",
1428+
"params": map[string]any{
1429+
"name": "notify_tool",
1430+
},
1431+
}
1432+
1433+
jsonBody, _ := json.Marshal(toolCallRequest)
1434+
req, _ := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(jsonBody))
1435+
req.Header.Set("Content-Type", "application/json")
1436+
req.Header.Set(HeaderKeySessionID, sessionID)
1437+
1438+
resp, err = http.DefaultClient.Do(req)
1439+
if err != nil {
1440+
t.Fatalf("Failed to call tool: %v", err)
1441+
}
1442+
defer resp.Body.Close()
1443+
1444+
bodyBytes, _ := io.ReadAll(resp.Body)
1445+
bodyStr := string(bodyBytes)
1446+
1447+
// Response might be SSE format if notification was sent
1448+
var toolResponse jsonRPCResponse
1449+
if strings.HasPrefix(bodyStr, "event: message") {
1450+
// Parse SSE format
1451+
lines := strings.Split(bodyStr, "\n")
1452+
for _, line := range lines {
1453+
if strings.HasPrefix(line, "data: ") {
1454+
jsonData := strings.TrimPrefix(line, "data: ")
1455+
if err := json.Unmarshal([]byte(jsonData), &toolResponse); err == nil {
1456+
break
1457+
}
1458+
}
1459+
}
1460+
} else {
1461+
if err := json.Unmarshal(bodyBytes, &toolResponse); err != nil {
1462+
t.Fatalf("Failed to unmarshal response: %v. Body: %s", err, bodyStr)
1463+
}
1464+
}
1465+
1466+
if toolResponse.Error != nil {
1467+
t.Errorf("Tool call failed: %v", toolResponse.Error)
1468+
}
1469+
1470+
// Verify the tool result indicates success
1471+
if result, ok := toolResponse.Result["content"].([]any); ok {
1472+
if len(result) > 0 {
1473+
if content, ok := result[0].(map[string]any); ok {
1474+
if text, ok := content["text"].(string); ok {
1475+
if text != "notification sent" {
1476+
t.Errorf("Expected 'notification sent', got %s", text)
1477+
}
1478+
}
1479+
}
1480+
}
1481+
}
1482+
})
1483+
}

0 commit comments

Comments
 (0)