From 0f61827f91f338096cb0a35cda425070c79bf38f Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Thu, 25 Sep 2025 10:56:30 +0800 Subject: [PATCH 1/2] fix resource url --- internal/auth/oauth/endpoints.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go index a3ea8c0..8c2b160 100644 --- a/internal/auth/oauth/endpoints.go +++ b/internal/auth/oauth/endpoints.go @@ -439,9 +439,19 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { host = r.URL.Host } - // Build the resource URL - resourceURL := fmt.Sprintf("%s://%s", scheme, host) - logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL) + // Build the resource URL with correct MCP endpoint path based on transport + var mcpPath string + switch em.cfg.Transport { + case "streamable-http": + mcpPath = "/mcp" + case "sse": + mcpPath = "/sse" + default: + mcpPath = "" + } + + resourceURL := fmt.Sprintf("%s://%s%s", scheme, host, mcpPath) + logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s (transport: %s)", resourceURL, em.cfg.Transport) provider := em.provider From 01c55a868882f1971503a42f0a4562bcc42eaa0d Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Mon, 29 Sep 2025 10:15:19 +0800 Subject: [PATCH 2/2] add unit test --- internal/auth/oauth/endpoints_test.go | 136 ++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go index f457d5c..a7d6309 100644 --- a/internal/auth/oauth/endpoints_test.go +++ b/internal/auth/oauth/endpoints_test.go @@ -1,7 +1,9 @@ package oauth import ( + "crypto/tls" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -509,6 +511,140 @@ func TestCORSHeaders(t *testing.T) { } } +func TestProtectedResourceMetadataEndpointTransportPaths(t *testing.T) { + tests := []struct { + name string + transport string + expectedPath string + scheme string + host string + }{ + { + name: "streamable-http transport", + transport: "streamable-http", + expectedPath: "/mcp", + scheme: "http", + host: "localhost:8000", + }, + { + name: "sse transport", + transport: "sse", + expectedPath: "/sse", + scheme: "https", + host: "localhost:8000", + }, + { + name: "stdio transport (no path)", + transport: "stdio", + expectedPath: "", + scheme: "http", + host: "localhost:8000", + }, + { + name: "empty transport (no path)", + transport: "", + expectedPath: "", + scheme: "https", + host: "example.com:9000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := createTestConfig() + cfg.Transport = tt.transport + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) + req.Host = tt.host + if tt.scheme == "https" { + req.TLS = &tls.ConnectionState{} + } + + w := httptest.NewRecorder() + handler := manager.protectedResourceMetadataHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var metadata ProtectedResourceMetadata + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Verify authorization server URL reflects the correct scheme and host + expectedAuthServerURL := fmt.Sprintf("%s://%s", tt.scheme, tt.host) + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServerURL { + t.Errorf("Expected auth server %s, got %v", expectedAuthServerURL, metadata.AuthorizationServers) + } + + // Verify that the resource URL includes the transport-specific path + expectedResourceURL := fmt.Sprintf("%s://%s%s", tt.scheme, tt.host, tt.expectedPath) + if metadata.Resource != expectedResourceURL { + t.Errorf("Expected resource URL %s, got %s", expectedResourceURL, metadata.Resource) + } + }) + } +} + +func TestProtectedResourceMetadataEndpointHostHeaders(t *testing.T) { + tests := []struct { + name string + hostHeader string + urlHost string + expectedURL string + }{ + { + name: "use Host header when present", + hostHeader: "api.example.com:8080", + urlHost: "fallback.com:9000", + expectedURL: "http://api.example.com:8080", + }, + { + name: "fallback to URL host when Host header empty", + hostHeader: "", + urlHost: "fallback.com:9000", + expectedURL: "http://fallback.com:9000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := createTestConfig() + cfg.Transport = "" // No additional path + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) + req.Host = tt.hostHeader + req.URL.Host = tt.urlHost + + w := httptest.NewRecorder() + handler := manager.protectedResourceMetadataHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var metadata ProtectedResourceMetadata + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Verify the handler executed successfully + if len(metadata.AuthorizationServers) == 0 { + t.Error("Expected authorization servers in response") + } + }) + } +} + func TestAuthorizationProxyRedirectURIValidation(t *testing.T) { cfg := createTestConfig() provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)