|
1 | 1 | package oauth |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "crypto/tls" |
4 | 5 | "encoding/json" |
| 6 | + "fmt" |
5 | 7 | "net/http" |
6 | 8 | "net/http/httptest" |
7 | 9 | "strings" |
@@ -509,6 +511,140 @@ func TestCORSHeaders(t *testing.T) { |
509 | 511 | } |
510 | 512 | } |
511 | 513 |
|
| 514 | +func TestProtectedResourceMetadataEndpointTransportPaths(t *testing.T) { |
| 515 | + tests := []struct { |
| 516 | + name string |
| 517 | + transport string |
| 518 | + expectedPath string |
| 519 | + scheme string |
| 520 | + host string |
| 521 | + }{ |
| 522 | + { |
| 523 | + name: "streamable-http transport", |
| 524 | + transport: "streamable-http", |
| 525 | + expectedPath: "/mcp", |
| 526 | + scheme: "http", |
| 527 | + host: "localhost:8000", |
| 528 | + }, |
| 529 | + { |
| 530 | + name: "sse transport", |
| 531 | + transport: "sse", |
| 532 | + expectedPath: "/sse", |
| 533 | + scheme: "https", |
| 534 | + host: "localhost:8000", |
| 535 | + }, |
| 536 | + { |
| 537 | + name: "stdio transport (no path)", |
| 538 | + transport: "stdio", |
| 539 | + expectedPath: "", |
| 540 | + scheme: "http", |
| 541 | + host: "localhost:8000", |
| 542 | + }, |
| 543 | + { |
| 544 | + name: "empty transport (no path)", |
| 545 | + transport: "", |
| 546 | + expectedPath: "", |
| 547 | + scheme: "https", |
| 548 | + host: "example.com:9000", |
| 549 | + }, |
| 550 | + } |
| 551 | + |
| 552 | + for _, tt := range tests { |
| 553 | + t.Run(tt.name, func(t *testing.T) { |
| 554 | + cfg := createTestConfig() |
| 555 | + cfg.Transport = tt.transport |
| 556 | + |
| 557 | + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) |
| 558 | + manager := NewEndpointManager(provider, cfg) |
| 559 | + |
| 560 | + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) |
| 561 | + req.Host = tt.host |
| 562 | + if tt.scheme == "https" { |
| 563 | + req.TLS = &tls.ConnectionState{} |
| 564 | + } |
| 565 | + |
| 566 | + w := httptest.NewRecorder() |
| 567 | + handler := manager.protectedResourceMetadataHandler() |
| 568 | + handler(w, req) |
| 569 | + |
| 570 | + if w.Code != http.StatusOK { |
| 571 | + t.Errorf("Expected status 200, got %d", w.Code) |
| 572 | + } |
| 573 | + |
| 574 | + var metadata ProtectedResourceMetadata |
| 575 | + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { |
| 576 | + t.Fatalf("Failed to parse response: %v", err) |
| 577 | + } |
| 578 | + |
| 579 | + // Verify authorization server URL reflects the correct scheme and host |
| 580 | + expectedAuthServerURL := fmt.Sprintf("%s://%s", tt.scheme, tt.host) |
| 581 | + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServerURL { |
| 582 | + t.Errorf("Expected auth server %s, got %v", expectedAuthServerURL, metadata.AuthorizationServers) |
| 583 | + } |
| 584 | + |
| 585 | + // Verify that the resource URL includes the transport-specific path |
| 586 | + expectedResourceURL := fmt.Sprintf("%s://%s%s", tt.scheme, tt.host, tt.expectedPath) |
| 587 | + if metadata.Resource != expectedResourceURL { |
| 588 | + t.Errorf("Expected resource URL %s, got %s", expectedResourceURL, metadata.Resource) |
| 589 | + } |
| 590 | + }) |
| 591 | + } |
| 592 | +} |
| 593 | + |
| 594 | +func TestProtectedResourceMetadataEndpointHostHeaders(t *testing.T) { |
| 595 | + tests := []struct { |
| 596 | + name string |
| 597 | + hostHeader string |
| 598 | + urlHost string |
| 599 | + expectedURL string |
| 600 | + }{ |
| 601 | + { |
| 602 | + name: "use Host header when present", |
| 603 | + hostHeader: "api.example.com:8080", |
| 604 | + urlHost: "fallback.com:9000", |
| 605 | + expectedURL: "http://api.example.com:8080", |
| 606 | + }, |
| 607 | + { |
| 608 | + name: "fallback to URL host when Host header empty", |
| 609 | + hostHeader: "", |
| 610 | + urlHost: "fallback.com:9000", |
| 611 | + expectedURL: "http://fallback.com:9000", |
| 612 | + }, |
| 613 | + } |
| 614 | + |
| 615 | + for _, tt := range tests { |
| 616 | + t.Run(tt.name, func(t *testing.T) { |
| 617 | + cfg := createTestConfig() |
| 618 | + cfg.Transport = "" // No additional path |
| 619 | + |
| 620 | + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) |
| 621 | + manager := NewEndpointManager(provider, cfg) |
| 622 | + |
| 623 | + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) |
| 624 | + req.Host = tt.hostHeader |
| 625 | + req.URL.Host = tt.urlHost |
| 626 | + |
| 627 | + w := httptest.NewRecorder() |
| 628 | + handler := manager.protectedResourceMetadataHandler() |
| 629 | + handler(w, req) |
| 630 | + |
| 631 | + if w.Code != http.StatusOK { |
| 632 | + t.Errorf("Expected status 200, got %d", w.Code) |
| 633 | + } |
| 634 | + |
| 635 | + var metadata ProtectedResourceMetadata |
| 636 | + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { |
| 637 | + t.Fatalf("Failed to parse response: %v", err) |
| 638 | + } |
| 639 | + |
| 640 | + // Verify the handler executed successfully |
| 641 | + if len(metadata.AuthorizationServers) == 0 { |
| 642 | + t.Error("Expected authorization servers in response") |
| 643 | + } |
| 644 | + }) |
| 645 | + } |
| 646 | +} |
| 647 | + |
512 | 648 | func TestAuthorizationProxyRedirectURIValidation(t *testing.T) { |
513 | 649 | cfg := createTestConfig() |
514 | 650 | provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) |
|
0 commit comments