Skip to content

Commit 7fd6e1b

Browse files
authored
fix: correct resource url (#217)
* fix resource url * add unit test
1 parent 807b221 commit 7fd6e1b

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

internal/auth/oauth/endpoints.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,19 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc {
439439
host = r.URL.Host
440440
}
441441

442-
// Build the resource URL
443-
resourceURL := fmt.Sprintf("%s://%s", scheme, host)
444-
logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL)
442+
// Build the resource URL with correct MCP endpoint path based on transport
443+
var mcpPath string
444+
switch em.cfg.Transport {
445+
case "streamable-http":
446+
mcpPath = "/mcp"
447+
case "sse":
448+
mcpPath = "/sse"
449+
default:
450+
mcpPath = ""
451+
}
452+
453+
resourceURL := fmt.Sprintf("%s://%s%s", scheme, host, mcpPath)
454+
logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s (transport: %s)", resourceURL, em.cfg.Transport)
445455

446456
provider := em.provider
447457

internal/auth/oauth/endpoints_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package oauth
22

33
import (
4+
"crypto/tls"
45
"encoding/json"
6+
"fmt"
57
"net/http"
68
"net/http/httptest"
79
"strings"
@@ -509,6 +511,140 @@ func TestCORSHeaders(t *testing.T) {
509511
}
510512
}
511513

514+
func TestProtectedResourceMetadataEndpointTransportPaths(t *testing.T) {
515+
tests := []struct {
516+
name string
517+
transport string
518+
expectedPath string
519+
scheme string
520+
host string
521+
}{
522+
{
523+
name: "streamable-http transport",
524+
transport: "streamable-http",
525+
expectedPath: "/mcp",
526+
scheme: "http",
527+
host: "localhost:8000",
528+
},
529+
{
530+
name: "sse transport",
531+
transport: "sse",
532+
expectedPath: "/sse",
533+
scheme: "https",
534+
host: "localhost:8000",
535+
},
536+
{
537+
name: "stdio transport (no path)",
538+
transport: "stdio",
539+
expectedPath: "",
540+
scheme: "http",
541+
host: "localhost:8000",
542+
},
543+
{
544+
name: "empty transport (no path)",
545+
transport: "",
546+
expectedPath: "",
547+
scheme: "https",
548+
host: "example.com:9000",
549+
},
550+
}
551+
552+
for _, tt := range tests {
553+
t.Run(tt.name, func(t *testing.T) {
554+
cfg := createTestConfig()
555+
cfg.Transport = tt.transport
556+
557+
provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
558+
manager := NewEndpointManager(provider, cfg)
559+
560+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
561+
req.Host = tt.host
562+
if tt.scheme == "https" {
563+
req.TLS = &tls.ConnectionState{}
564+
}
565+
566+
w := httptest.NewRecorder()
567+
handler := manager.protectedResourceMetadataHandler()
568+
handler(w, req)
569+
570+
if w.Code != http.StatusOK {
571+
t.Errorf("Expected status 200, got %d", w.Code)
572+
}
573+
574+
var metadata ProtectedResourceMetadata
575+
if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil {
576+
t.Fatalf("Failed to parse response: %v", err)
577+
}
578+
579+
// Verify authorization server URL reflects the correct scheme and host
580+
expectedAuthServerURL := fmt.Sprintf("%s://%s", tt.scheme, tt.host)
581+
if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServerURL {
582+
t.Errorf("Expected auth server %s, got %v", expectedAuthServerURL, metadata.AuthorizationServers)
583+
}
584+
585+
// Verify that the resource URL includes the transport-specific path
586+
expectedResourceURL := fmt.Sprintf("%s://%s%s", tt.scheme, tt.host, tt.expectedPath)
587+
if metadata.Resource != expectedResourceURL {
588+
t.Errorf("Expected resource URL %s, got %s", expectedResourceURL, metadata.Resource)
589+
}
590+
})
591+
}
592+
}
593+
594+
func TestProtectedResourceMetadataEndpointHostHeaders(t *testing.T) {
595+
tests := []struct {
596+
name string
597+
hostHeader string
598+
urlHost string
599+
expectedURL string
600+
}{
601+
{
602+
name: "use Host header when present",
603+
hostHeader: "api.example.com:8080",
604+
urlHost: "fallback.com:9000",
605+
expectedURL: "http://api.example.com:8080",
606+
},
607+
{
608+
name: "fallback to URL host when Host header empty",
609+
hostHeader: "",
610+
urlHost: "fallback.com:9000",
611+
expectedURL: "http://fallback.com:9000",
612+
},
613+
}
614+
615+
for _, tt := range tests {
616+
t.Run(tt.name, func(t *testing.T) {
617+
cfg := createTestConfig()
618+
cfg.Transport = "" // No additional path
619+
620+
provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
621+
manager := NewEndpointManager(provider, cfg)
622+
623+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
624+
req.Host = tt.hostHeader
625+
req.URL.Host = tt.urlHost
626+
627+
w := httptest.NewRecorder()
628+
handler := manager.protectedResourceMetadataHandler()
629+
handler(w, req)
630+
631+
if w.Code != http.StatusOK {
632+
t.Errorf("Expected status 200, got %d", w.Code)
633+
}
634+
635+
var metadata ProtectedResourceMetadata
636+
if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil {
637+
t.Fatalf("Failed to parse response: %v", err)
638+
}
639+
640+
// Verify the handler executed successfully
641+
if len(metadata.AuthorizationServers) == 0 {
642+
t.Error("Expected authorization servers in response")
643+
}
644+
})
645+
}
646+
}
647+
512648
func TestAuthorizationProxyRedirectURIValidation(t *testing.T) {
513649
cfg := createTestConfig()
514650
provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)

0 commit comments

Comments
 (0)