Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions internal/auth/oauth/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,19 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc {
host = r.URL.Host
}

// Build the resource URL
resourceURL := fmt.Sprintf("%s://%s", scheme, host)
logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL)
// Build the resource URL with correct MCP endpoint path based on transport
var mcpPath string
switch em.cfg.Transport {
case "streamable-http":
mcpPath = "/mcp"
case "sse":
mcpPath = "/sse"
default:
mcpPath = ""
}

resourceURL := fmt.Sprintf("%s://%s%s", scheme, host, mcpPath)
logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s (transport: %s)", resourceURL, em.cfg.Transport)

provider := em.provider

Expand Down
136 changes: 136 additions & 0 deletions internal/auth/oauth/endpoints_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package oauth

import (
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -509,6 +511,140 @@ func TestCORSHeaders(t *testing.T) {
}
}

func TestProtectedResourceMetadataEndpointTransportPaths(t *testing.T) {
tests := []struct {
name string
transport string
expectedPath string
scheme string
host string
}{
{
name: "streamable-http transport",
transport: "streamable-http",
expectedPath: "/mcp",
scheme: "http",
host: "localhost:8000",
},
{
name: "sse transport",
transport: "sse",
expectedPath: "/sse",
scheme: "https",
host: "localhost:8000",
},
{
name: "stdio transport (no path)",
transport: "stdio",
expectedPath: "",
scheme: "http",
host: "localhost:8000",
},
{
name: "empty transport (no path)",
transport: "",
expectedPath: "",
scheme: "https",
host: "example.com:9000",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := createTestConfig()
cfg.Transport = tt.transport

provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
manager := NewEndpointManager(provider, cfg)

req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
req.Host = tt.host
if tt.scheme == "https" {
req.TLS = &tls.ConnectionState{}
}

w := httptest.NewRecorder()
handler := manager.protectedResourceMetadataHandler()
handler(w, req)

if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}

var metadata ProtectedResourceMetadata
if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}

// Verify authorization server URL reflects the correct scheme and host
expectedAuthServerURL := fmt.Sprintf("%s://%s", tt.scheme, tt.host)
if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServerURL {
t.Errorf("Expected auth server %s, got %v", expectedAuthServerURL, metadata.AuthorizationServers)
}

// Verify that the resource URL includes the transport-specific path
expectedResourceURL := fmt.Sprintf("%s://%s%s", tt.scheme, tt.host, tt.expectedPath)
if metadata.Resource != expectedResourceURL {
t.Errorf("Expected resource URL %s, got %s", expectedResourceURL, metadata.Resource)
}
})
}
}

func TestProtectedResourceMetadataEndpointHostHeaders(t *testing.T) {
tests := []struct {
name string
hostHeader string
urlHost string
expectedURL string
}{
{
name: "use Host header when present",
hostHeader: "api.example.com:8080",
urlHost: "fallback.com:9000",
expectedURL: "http://api.example.com:8080",
},
{
name: "fallback to URL host when Host header empty",
hostHeader: "",
urlHost: "fallback.com:9000",
expectedURL: "http://fallback.com:9000",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := createTestConfig()
cfg.Transport = "" // No additional path

provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
manager := NewEndpointManager(provider, cfg)

req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
req.Host = tt.hostHeader
req.URL.Host = tt.urlHost

w := httptest.NewRecorder()
handler := manager.protectedResourceMetadataHandler()
handler(w, req)

if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}

var metadata ProtectedResourceMetadata
if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}

// Verify the handler executed successfully
if len(metadata.AuthorizationServers) == 0 {
t.Error("Expected authorization servers in response")
}
})
}
}

func TestAuthorizationProxyRedirectURIValidation(t *testing.T) {
cfg := createTestConfig()
provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
Expand Down
Loading