diff --git a/server/streamable_http.go b/server/streamable_http.go index 9ad37fea1..8c31d1762 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -1015,29 +1015,47 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo return false, nil } -// InsecureStatefulSessionIdManager generate id with uuid -// It won't validate the id indeed, so it could be fake. +// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions. +// It validates both format and existence of session IDs. // For more secure session id, use a more complex generator, like a JWT. -type InsecureStatefulSessionIdManager struct{} +type InsecureStatefulSessionIdManager struct { + sessions sync.Map + terminated sync.Map +} const idPrefix = "mcp-session-" func (s *InsecureStatefulSessionIdManager) Generate() string { - return idPrefix + uuid.New().String() + sessionID := idPrefix + uuid.New().String() + s.sessions.Store(sessionID, true) + return sessionID } func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { - // validate the session id is a valid uuid if !strings.HasPrefix(sessionID, idPrefix) { return false, fmt.Errorf("invalid session id: %s", sessionID) } if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { return false, fmt.Errorf("invalid session id: %s", sessionID) } + if _, exists := s.terminated.Load(sessionID); exists { + return true, nil + } + if _, exists := s.sessions.Load(sessionID); !exists { + return false, fmt.Errorf("session not found: %s", sessionID) + } return false, nil } func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + if _, exists := s.terminated.Load(sessionID); exists { + return false, nil + } + if _, exists := s.sessions.Load(sessionID); !exists { + return false, nil + } + s.terminated.Store(sessionID, true) + s.sessions.Delete(sessionID) return false, nil } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 50be27fa7..38d2e96cf 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -45,7 +45,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") mcpServer.EnableSampling() - httpServer := NewStreamableHTTPServer(mcpServer) + httpServer := NewStreamableHTTPServer(mcpServer, WithStateLess(true)) testServer := httptest.NewServer(httpServer) defer testServer.Close() @@ -76,7 +76,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { }, { name: "invalid request ID", - sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + sessionID: "any-session-id", body: map[string]any{ "jsonrpc": "2.0", "id": "invalid-id", @@ -92,13 +92,13 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { }, { name: "malformed result", - sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + sessionID: "any-session-id", body: map[string]any{ "jsonrpc": "2.0", "id": 1, "result": "invalid-result", }, - expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session + expectedStatus: http.StatusInternalServerError, }, } diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index cee2fc031..c2647f8a1 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -1015,3 +1015,302 @@ func postJSON(url string, bodyObject any) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") return http.DefaultClient.Do(req) } + +func TestStreamableHTTP_SessionValidation(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.AddTool(mcp.NewTool("time", + mcp.WithDescription("Get the current time")), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("2024-01-01T00:00:00Z"), nil + }) + + server := NewTestStreamableHTTPServer(mcpServer) + defer server.Close() + + t.Run("Reject tool call with fake session ID", func(t *testing.T) { + toolCallRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "time", + }, + } + + jsonBody, _ := json.Marshal(toolCallRequest) + req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "Invalid session ID") { + t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body)) + } + }) + + t.Run("Reject tool call with malformed session ID", func(t *testing.T) { + toolCallRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "time", + }, + } + + jsonBody, _ := json.Marshal(toolCallRequest) + req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, "invalid-session-id") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "Invalid session ID") { + t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body)) + } + }) + + t.Run("Accept tool call with valid session ID from initialize", func(t *testing.T) { + jsonBody, _ := json.Marshal(initRequest) + req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + defer resp.Body.Close() + + sessionID := resp.Header.Get(HeaderKeySessionID) + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + + toolCallRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "time", + }, + } + + jsonBody, _ = json.Marshal(toolCallRequest) + req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to call tool: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Expected status 200, got %d. Body: %s", resp.StatusCode, string(body)) + } + }) + + t.Run("Reject tool call with terminated session ID", func(t *testing.T) { + jsonBody, _ := json.Marshal(initRequest) + req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + resp.Body.Close() + + sessionID := resp.Header.Get(HeaderKeySessionID) + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + + req, _ = http.NewRequest(http.MethodDelete, server.URL, nil) + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to terminate session: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for termination, got %d", resp.StatusCode) + } + + toolCallRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "time", + }, + } + + jsonBody, _ = json.Marshal(toolCallRequest) + req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, err = server.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Expected status 404, got %d. Body: %s", resp.StatusCode, string(body)) + } + }) +} + +func TestInsecureStatefulSessionIdManager(t *testing.T) { + t.Run("Generate creates valid session ID", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + sessionID := manager.Generate() + + if !strings.HasPrefix(sessionID, idPrefix) { + t.Errorf("Expected session ID to start with %s, got %s", idPrefix, sessionID) + } + + isTerminated, err := manager.Validate(sessionID) + if err != nil { + t.Errorf("Expected valid session ID, got error: %v", err) + } + if isTerminated { + t.Error("Expected session to not be terminated") + } + }) + + t.Run("Validate rejects non-existent session ID", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff" + + isTerminated, err := manager.Validate(fakeSessionID) + if err == nil { + t.Error("Expected error for non-existent session ID") + } + if isTerminated { + t.Error("Expected isTerminated to be false for invalid session") + } + if !strings.Contains(err.Error(), "session not found") { + t.Errorf("Expected 'session not found' error, got: %v", err) + } + }) + + t.Run("Validate rejects malformed session ID", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + invalidSessionID := "invalid-session-id" + + _, err := manager.Validate(invalidSessionID) + if err == nil { + t.Error("Expected error for malformed session ID") + } + if !strings.Contains(err.Error(), "invalid session id") { + t.Errorf("Expected 'invalid session id' error, got: %v", err) + } + }) + + t.Run("Terminate marks session as terminated", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + sessionID := manager.Generate() + + isNotAllowed, err := manager.Terminate(sessionID) + if err != nil { + t.Errorf("Expected no error on termination, got: %v", err) + } + if isNotAllowed { + t.Error("Expected termination to be allowed") + } + + isTerminated, err := manager.Validate(sessionID) + if !isTerminated { + t.Error("Expected session to be marked as terminated") + } + if err != nil { + t.Errorf("Expected no error for terminated session, got: %v", err) + } + }) + + t.Run("Terminate is idempotent for non-existent session ID", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff" + + isNotAllowed, err := manager.Terminate(fakeSessionID) + if err != nil { + t.Errorf("Expected no error when terminating non-existent session, got: %v", err) + } + if isNotAllowed { + t.Error("Expected isNotAllowed to be false") + } + }) + + t.Run("Terminate is idempotent for already-terminated session", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + sessionID := manager.Generate() + + isNotAllowed, err := manager.Terminate(sessionID) + if err != nil { + t.Errorf("Expected no error on first termination, got: %v", err) + } + if isNotAllowed { + t.Error("Expected termination to be allowed") + } + + isNotAllowed, err = manager.Terminate(sessionID) + if err != nil { + t.Errorf("Expected no error on second termination (idempotent), got: %v", err) + } + if isNotAllowed { + t.Error("Expected termination to be allowed on retry") + } + }) + + t.Run("Concurrent generate and validate", func(t *testing.T) { + manager := &InsecureStatefulSessionIdManager{} + var wg sync.WaitGroup + sessionIDs := make([]string, 100) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + sessionIDs[index] = manager.Generate() + }(i) + } + + wg.Wait() + + for _, sessionID := range sessionIDs { + isTerminated, err := manager.Validate(sessionID) + if err != nil { + t.Errorf("Expected valid session ID %s, got error: %v", sessionID, err) + } + if isTerminated { + t.Errorf("Expected session %s to not be terminated", sessionID) + } + } + }) +}