Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,29 +1015,44 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo
return false, nil
}

// InsecureStatefulSessionIdManager generate id with uuid
// It won't validate the id indeed, so it could be fake.
// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions.
// It validates both format and existence of session IDs.
// For more secure session id, use a more complex generator, like a JWT.
type InsecureStatefulSessionIdManager struct{}
type InsecureStatefulSessionIdManager struct {
sessions sync.Map
terminated sync.Map
}

const idPrefix = "mcp-session-"

func (s *InsecureStatefulSessionIdManager) Generate() string {
return idPrefix + uuid.New().String()
sessionID := idPrefix + uuid.New().String()
s.sessions.Store(sessionID, true)
return sessionID
}

func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
// validate the session id is a valid uuid
if !strings.HasPrefix(sessionID, idPrefix) {
return false, fmt.Errorf("invalid session id: %s", sessionID)
}
if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
return false, fmt.Errorf("invalid session id: %s", sessionID)
}
if _, exists := s.terminated.Load(sessionID); exists {
return true, nil
}
if _, exists := s.sessions.Load(sessionID); !exists {
return false, fmt.Errorf("session not found: %s", sessionID)
}
return false, nil
}

func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
if _, exists := s.sessions.Load(sessionID); !exists {
return false, fmt.Errorf("session not found: %s", sessionID)
}
s.terminated.Store(sessionID, true)
s.sessions.Delete(sessionID)
return false, nil
}

Expand Down
8 changes: 4 additions & 4 deletions server/streamable_http_sampling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")
mcpServer.EnableSampling()

httpServer := NewStreamableHTTPServer(mcpServer)
httpServer := NewStreamableHTTPServer(mcpServer, WithStateLess(true))
testServer := httptest.NewServer(httpServer)
defer testServer.Close()

Expand Down Expand Up @@ -76,7 +76,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
},
{
name: "invalid request ID",
sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000",
sessionID: "any-session-id",
body: map[string]any{
"jsonrpc": "2.0",
"id": "invalid-id",
Expand All @@ -92,13 +92,13 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) {
},
{
name: "malformed result",
sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000",
sessionID: "any-session-id",
body: map[string]any{
"jsonrpc": "2.0",
"id": 1,
"result": "invalid-result",
},
expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session
expectedStatus: http.StatusInternalServerError,
},
}

Expand Down
278 changes: 278 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,281 @@ func postJSON(url string, bodyObject any) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
return http.DefaultClient.Do(req)
}

func TestStreamableHTTP_SessionValidation(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")
mcpServer.AddTool(mcp.NewTool("time",
mcp.WithDescription("Get the current time")), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("2024-01-01T00:00:00Z"), nil
})

server := NewTestStreamableHTTPServer(mcpServer)
defer server.Close()

t.Run("Reject tool call with fake session ID", func(t *testing.T) {
toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]any{
"name": "time",
},
}

jsonBody, _ := json.Marshal(toolCallRequest)
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(HeaderKeySessionID, "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp.StatusCode)
}

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Invalid session ID") {
t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body))
}
})

t.Run("Reject tool call with malformed session ID", func(t *testing.T) {
toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": map[string]any{
"name": "time",
},
}

jsonBody, _ := json.Marshal(toolCallRequest)
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(HeaderKeySessionID, "invalid-session-id")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp.StatusCode)
}

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Invalid session ID") {
t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body))
}
})

t.Run("Accept tool call with valid session ID from initialize", func(t *testing.T) {
jsonBody, _ := json.Marshal(initRequest)
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to initialize: %v", err)
}
defer resp.Body.Close()

sessionID := resp.Header.Get(HeaderKeySessionID)
if sessionID == "" {
t.Fatal("Expected session ID in response header")
}

toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": map[string]any{
"name": "time",
},
}

jsonBody, _ = json.Marshal(toolCallRequest)
req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(HeaderKeySessionID, sessionID)

resp, err = server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to call tool: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status 200, got %d. Body: %s", resp.StatusCode, string(body))
}
})

t.Run("Reject tool call with terminated session ID", func(t *testing.T) {
jsonBody, _ := json.Marshal(initRequest)
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to initialize: %v", err)
}
resp.Body.Close()

sessionID := resp.Header.Get(HeaderKeySessionID)
if sessionID == "" {
t.Fatal("Expected session ID in response header")
}

req, _ = http.NewRequest(http.MethodDelete, server.URL, nil)
req.Header.Set(HeaderKeySessionID, sessionID)

resp, err = server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to terminate session: %v", err)
}
resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for termination, got %d", resp.StatusCode)
}

toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": map[string]any{
"name": "time",
},
}

jsonBody, _ = json.Marshal(toolCallRequest)
req, _ = http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(HeaderKeySessionID, sessionID)

resp, err = server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusNotFound {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status 404, got %d. Body: %s", resp.StatusCode, string(body))
}
})
}

func TestInsecureStatefulSessionIdManager(t *testing.T) {
t.Run("Generate creates valid session ID", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
sessionID := manager.Generate()

if !strings.HasPrefix(sessionID, idPrefix) {
t.Errorf("Expected session ID to start with %s, got %s", idPrefix, sessionID)
}

isTerminated, err := manager.Validate(sessionID)
if err != nil {
t.Errorf("Expected valid session ID, got error: %v", err)
}
if isTerminated {
t.Error("Expected session to not be terminated")
}
})

t.Run("Validate rejects non-existent session ID", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"

isTerminated, err := manager.Validate(fakeSessionID)
if err == nil {
t.Error("Expected error for non-existent session ID")
}
if isTerminated {
t.Error("Expected isTerminated to be false for invalid session")
}
if !strings.Contains(err.Error(), "session not found") {
t.Errorf("Expected 'session not found' error, got: %v", err)
}
})

t.Run("Validate rejects malformed session ID", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
invalidSessionID := "invalid-session-id"

_, err := manager.Validate(invalidSessionID)
if err == nil {
t.Error("Expected error for malformed session ID")
}
if !strings.Contains(err.Error(), "invalid session id") {
t.Errorf("Expected 'invalid session id' error, got: %v", err)
}
})

t.Run("Terminate marks session as terminated", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
sessionID := manager.Generate()

isNotAllowed, err := manager.Terminate(sessionID)
if err != nil {
t.Errorf("Expected no error on termination, got: %v", err)
}
if isNotAllowed {
t.Error("Expected termination to be allowed")
}

isTerminated, err := manager.Validate(sessionID)
if !isTerminated {
t.Error("Expected session to be marked as terminated")
}
if err != nil {
t.Errorf("Expected no error for terminated session, got: %v", err)
}
})

t.Run("Terminate rejects non-existent session ID", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
fakeSessionID := "mcp-session-ffffffff-ffff-ffff-ffff-ffffffffffff"

_, err := manager.Terminate(fakeSessionID)
if err == nil {
t.Error("Expected error when terminating non-existent session")
}
if !strings.Contains(err.Error(), "session not found") {
t.Errorf("Expected 'session not found' error, got: %v", err)
}
})

t.Run("Concurrent generate and validate", func(t *testing.T) {
manager := &InsecureStatefulSessionIdManager{}
var wg sync.WaitGroup
sessionIDs := make([]string, 100)

for i := 0; i < 100; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
sessionIDs[index] = manager.Generate()
}(i)
}

wg.Wait()

for _, sessionID := range sessionIDs {
isTerminated, err := manager.Validate(sessionID)
if err != nil {
t.Errorf("Expected valid session ID %s, got error: %v", sessionID, err)
}
if isTerminated {
t.Errorf("Expected session %s to not be terminated", sessionID)
}
}
})
}