diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md
new file mode 100644
index 0000000..344279f
--- /dev/null
+++ b/docs/oauth-authentication.md
@@ -0,0 +1,474 @@
+# OAuth Authentication for AKS-MCP
+
+This document describes how to configure and use OAuth authentication with AKS-MCP.
+
+## Overview
+
+AKS-MCP now supports OAuth 2.1 authentication using Azure Active Directory as the authorization server. When enabled, OAuth authentication provides secure access control for MCP endpoints using Bearer tokens.
+
+## Features
+
+- **Azure AD Integration**: Uses Azure Active Directory as the OAuth authorization server
+- **JWT Token Validation**: Validates JWT tokens with Azure AD signing keys
+- **OAuth 2.0 Metadata Endpoints**: Provides standard OAuth metadata discovery endpoints
+- **Dynamic Client Registration**: Supports RFC 7591 dynamic client registration
+- **Token Introspection**: Implements RFC 7662 token introspection
+- **Transport Support**: Works with both SSE and HTTP Streamable transports
+- **Flexible Configuration**: Supports environment variables and command-line configuration
+
+## Environment Setup and Azure AD Configuration
+
+### Prerequisites
+
+Before setting up OAuth authentication, ensure you have:
+
+- Azure CLI installed and configured (`az login`)
+- An Azure subscription with appropriate permissions to create applications
+- Azure Active Directory tenant access
+
+### Important: Environment Variables Shared Between OAuth and Azure CLI
+
+AKS-MCP uses the same environment variables (`AZURE_TENANT_ID`, `AZURE_CLIENT_ID`) for both OAuth authentication and Azure CLI operations. This design provides configuration simplicity but requires careful permission setup:
+
+**When `AZURE_CLIENT_ID` is set:**
+- OAuth: Used for validating user tokens accessing the MCP server
+- Azure CLI: Used for managed identity/workload identity authentication to access Azure resources
+
+**Permission Requirements:**
+- The Azure AD application must have both **OAuth permissions** (for user authentication) AND **Azure resource permissions** (for az CLI operations)
+- Missing either set of permissions will cause authentication failures
+
+### Step 1: Create Azure AD Application
+
+#### Using Azure Portal (Recommended)
+
+1. **Navigate to Azure Portal**
+ - Go to https://portal.azure.com
+ - Sign in with your Azure account
+
+2. **Create App Registration**
+ ```
+ Navigation: Azure Active Directory → App registrations → New registration
+ ```
+
+ Configure the following:
+ - **Name**: `AKS-MCP-OAuth` (or your preferred name)
+ - **Supported account types**: "Accounts in this organizational directory only"
+ - **Redirect URI Platform Options**:
+
+#### Supported Platform Types
+
+**✅ Mobile and desktop applications (Recommended)**
+- **Platform**: "Mobile and desktop applications"
+- **Redirect URIs**:
+ - `http://localhost:8000/oauth/callback`
+- **Benefits**:
+ - Native support for PKCE (required by OAuth 2.1)
+ - No client secret required (public client)
+ - Better security for localhost redirects
+- **Status**: ✅ **Confirmed working**
+
+**❌ Single-page application (SPA) - Not Recommended**
+- **Platform**: "Single-page application (SPA)"
+- **Redirect URIs**: Same as above
+- **Benefits**:
+ - Designed for PKCE flow
+ - No client secret required
+- **Critical Limitations**:
+ - **Token exchange restriction**: Azure AD error AADSTS9002327 - "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests"
+ - **Architecture mismatch**: SPA platform expects frontend JavaScript to handle token exchange, but AKS-MCP performs backend token exchange
+ - **CORS requirements**: Requires complex frontend-backend coordination for OAuth flow
+- **Status**: ❌ **Not compatible with AKS-MCP's backend OAuth implementation**
+
+**❌ Web application**
+- **Platform**: "Web"
+- **Why not supported**:
+ - Requires client secret (confidential client)
+ - AKS-MCP implements public client flow without secrets
+ - PKCE handling may differ
+
+**Choose Platform Recommendation:**
+1. **Primary**: Use "Mobile and desktop applications" (✅ confirmed working)
+2. **Avoid**: "Single-page application" - incompatible with backend OAuth implementation (AADSTS9002327 error)
+3. **Avoid**: "Web" platform due to client secret requirements
+
+3. **Record Essential Information**
+ From the "Overview" page, note:
+ - **Application (client) ID** - This is your `CLIENT_ID`
+ - **Directory (tenant) ID** - This is your `TENANT_ID`
+
+#### Using Azure CLI (Alternative)
+
+**For Mobile and desktop applications platform:**
+```bash
+# Create Azure AD application with public client platform
+az ad app create --display-name "AKS-MCP-OAuth" \
+ --public-client-redirect-uris "http://localhost:8000/oauth/callback"
+
+# Get application details
+az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}"
+
+# Get your tenant ID
+az account show --query "tenantId" -o tsv
+```
+
+### Step 2: Configure API Permissions
+
+**Critical: Both OAuth and Azure CLI require proper permissions**
+
+1. **Add Required API Permissions**
+ ```
+ Navigation: Azure Active Directory → App registrations → [Your App] → API permissions
+ ```
+
+2. **Add Azure Service Management Permission (Required for OAuth)**
+ - Click "Add a permission"
+ - Select "Microsoft APIs" → "Azure Service Management"
+ - Choose "Delegated permissions"
+ - Select `user_impersonation`
+ - Click "Add permissions"
+
+3. **Add Azure Resource Management Permissions (Required for Azure CLI)**
+
+ When `AZURE_CLIENT_ID` is set, Azure CLI will use this application for authentication. Add these permissions based on your AKS-MCP access level:
+
+ **For readonly access:**
+ - Microsoft Graph → Application permissions → `Directory.Read.All`
+ - Azure Service Management → Delegated permissions → `user_impersonation`
+
+ **For readwrite/admin access:**
+ - Microsoft Graph → Application permissions → `Directory.Read.All`
+ - Azure Service Management → Delegated permissions → `user_impersonation`
+ - Consider adding specific Azure resource permissions based on your needs
+
+4. **Grant Admin Consent (Required)**
+ - Click "Grant admin consent for [Your Organization]"
+ - Confirm the consent
+
+**⚠️ Important Notes:**
+- Without proper Azure CLI permissions, you'll see "Insufficient privileges" errors when AKS-MCP tries to access Azure resources
+- The same application serves both OAuth authentication (user access to MCP) and Azure CLI authentication (MCP access to Azure)
+- Test both OAuth flow AND Azure resource access after permission changes
+
+### Step 3: Environment Configuration
+
+Set the required environment variables:
+
+```bash
+# Replace with your actual values from Step 1
+export AZURE_TENANT_ID="your-tenant-id"
+export AZURE_CLIENT_ID="your-client-id"
+export AZURE_SUBSCRIPTION_ID="your-subscription-id" # Optional, for AKS operations
+```
+
+**⚠️ Important: Dual Authentication Impact**
+
+When you set `AZURE_CLIENT_ID`, it affects both OAuth and Azure CLI authentication:
+
+1. **OAuth Authentication**: Validates user tokens for MCP server access
+2. **Azure CLI Authentication**: AKS-MCP uses this client ID for managed identity authentication when accessing Azure resources
+
+**Common Issues:**
+- If you only configured OAuth permissions, Azure CLI operations will fail with "Insufficient privileges"
+- If you only configured Azure resource permissions, OAuth token validation may fail
+- Solution: Ensure your Azure AD application has BOTH sets of permissions (see Step 2)
+
+**Testing Both Authentication Paths:**
+```bash
+# Test OAuth (should work after proper setup)
+curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp
+
+# Test Azure CLI access (should work after proper permissions)
+# This happens automatically when AKS-MCP tries to access Azure resources
+./aks-mcp --oauth-enabled --access-level=readonly
+```
+
+### Step 4: Start AKS-MCP with OAuth
+
+```bash
+# Using HTTP Streamable transport with OAuth (recommended)
+./aks-mcp \
+ --transport=streamable-http \
+ --port=8000 \
+ --oauth-enabled \
+ --oauth-tenant-id="$AZURE_TENANT_ID" \
+ --oauth-client-id="$AZURE_CLIENT_ID" \
+ --oauth-redirects="http://localhost:8000/oauth/callback" \
+ --access-level=readonly
+
+# Using SSE transport with OAuth (alternative)
+./aks-mcp \
+ --transport=sse \
+ --port=8000 \
+ --oauth-enabled \
+ --oauth-tenant-id="$AZURE_TENANT_ID" \
+ --oauth-client-id="$AZURE_CLIENT_ID" \
+ --oauth-redirects="http://localhost:8000/oauth/callback" \
+ --access-level=readonly
+
+# Environment variables are automatically used if set
+# You can also just use:
+./aks-mcp --transport=streamable-http --port=8000 --oauth-enabled --access-level=readonly
+```
+
+## Configuration Options
+
+### Command Line Flags
+
+- `--oauth-enabled`: Enable OAuth authentication (default: false)
+- `--oauth-tenant-id`: Azure AD tenant ID (or use AZURE_TENANT_ID env var)
+- `--oauth-client-id`: Azure AD client ID (or use AZURE_CLIENT_ID env var)
+- `--oauth-redirects`: Comma-separated list of allowed redirect URIs (required when OAuth enabled)
+- `--oauth-cors-origins`: Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274 for MCP Inspector). If empty, no cross-origin requests are allowed for security
+
+**Note**: OAuth scopes are automatically configured to use `https://management.azure.com/.default` for optimal Azure AD compatibility. Custom scopes are not currently configurable via command line.
+
+### Example with Command Line Flags
+
+```bash
+./aks-mcp --transport=sse --oauth-enabled=true \
+ --oauth-tenant-id="12345678-1234-1234-1234-123456789012" \
+ --oauth-client-id="87654321-4321-4321-4321-210987654321"
+```
+
+**Note**: Scopes are automatically set to `https://management.azure.com/.default` and cannot be customized via command line.
+
+## OAuth Endpoints
+
+When OAuth is enabled, the following endpoints are available:
+
+### Metadata Endpoints (Unauthenticated)
+
+- `GET /.well-known/oauth-protected-resource` - OAuth 2.0 Protected Resource Metadata (RFC 9728)
+- `GET /.well-known/oauth-authorization-server` - OAuth 2.0 Authorization Server Metadata (RFC 8414)
+- `GET /.well-known/openid-configuration` - OpenID Connect Discovery (alias for authorization server metadata)
+- `GET /health` - Health check endpoint
+
+### OAuth Flow Endpoints (Unauthenticated)
+
+- `GET /oauth2/v2.0/authorize` - Authorization endpoint proxy to Azure AD
+- `POST /oauth2/v2.0/token` - Token exchange endpoint proxy to Azure AD
+- `GET /oauth/callback` - Authorization Code flow callback handler
+- `POST /oauth/register` - Dynamic Client Registration (RFC 7591)
+
+### Token Management (Unauthenticated for simplicity)
+
+- `POST /oauth/introspect` - Token Introspection (RFC 7662)
+
+### Authenticated MCP Endpoints
+
+When OAuth is enabled, these endpoints require Bearer token authentication:
+
+- **SSE Transport**: `GET /sse`, `POST /message`
+- **HTTP Streamable Transport**: `POST /mcp`
+
+## Client Integration
+
+### Obtaining an Access Token
+
+Use the Azure AD OAuth flow to obtain an access token:
+
+```bash
+# Example using Azure CLI (for testing)
+az account get-access-token --resource https://management.azure.com/ --query accessToken -o tsv
+```
+
+### Making Authenticated Requests
+
+Include the Bearer token in the Authorization header:
+
+```bash
+# Example authenticated request to SSE endpoint
+curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
+ -H "Accept: text/event-stream" \
+ http://localhost:8000/sse
+
+# Example authenticated request to HTTP Streamable endpoint
+curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \
+ -H "Content-Type: application/json" \
+ -X POST http://localhost:8000/mcp \
+ -d '{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}'
+```
+
+## Testing OAuth Integration
+
+### 1. Test OAuth Metadata
+
+```bash
+# Get protected resource metadata
+curl http://localhost:8000/.well-known/oauth-protected-resource
+
+# Get authorization server metadata
+curl http://localhost:8000/.well-known/oauth-authorization-server
+```
+
+### 2. Test Dynamic Client Registration
+
+```bash
+curl -X POST http://localhost:8000/oauth/register \
+ -H "Content-Type: application/json" \
+ -d '{
+ "redirect_uris": ["http://localhost:3000/oauth/callback"],
+ "client_name": "Test MCP Client",
+ "grant_types": ["authorization_code"]
+ }'
+```
+
+### 3. Test Token Introspection
+
+```bash
+curl -X POST http://localhost:8000/oauth/introspect \
+ -H "Content-Type: application/x-www-form-urlencoded" \
+ -d "token=YOUR_ACCESS_TOKEN"
+```
+
+## Security Considerations
+
+1. **HTTPS in Production**: Always use HTTPS in production environments
+2. **Token Validation**: JWT tokens are validated against Azure AD signing keys
+3. **Scope Validation**: Tokens must include required scopes
+4. **Audience Validation**: Tokens must have the correct audience claim
+5. **Redirect URI Validation**: Only configured redirect URIs are allowed
+
+## Troubleshooting
+
+### Common Issues
+
+#### Authentication and Token Issues
+1. **Invalid Token**: Ensure the token is valid and not expired
+2. **Wrong Audience**: Verify the token audience matches `https://management.azure.com`
+3. **Missing Scopes**: Ensure the token includes `https://management.azure.com/.default` scope
+4. **JWT Signature Validation Failed**:
+ - Check that Azure AD application platform is set correctly
+ - Verify tenant ID matches the issuer in the token
+ - Ensure token is using v2.0 format (from Azure Management API scope)
+
+#### Azure AD Application Configuration Issues
+5. **Client ID Not Found**: Verify the Application (client) ID is correct
+6. **Redirect URI Mismatch**: Ensure redirect URIs match exactly in Azure AD app registration
+7. **Wrong Platform Type**: Use "Mobile and desktop applications", NOT "Web" or "Single-page application"
+8. **Insufficient Permissions**: Verify both OAuth and Azure resource permissions are configured
+9. **SPA Platform Incompatibility (AADSTS9002327)**:
+ - Error: "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests"
+ - Solution: Change Azure AD app platform to "Mobile and desktop applications"
+ - Cause: SPA platform requires frontend token exchange, incompatible with AKS-MCP's backend implementation
+
+#### Network and Endpoint Issues
+10. **CORS Errors**: Check that redirect URIs are properly configured for localhost
+11. **Network Issues**: Check connectivity to Azure AD endpoints
+11. **Port Conflicts**: Ensure the configured port (default 8000) is available
+
+#### Scope and Permission Issues
+12. **Scope Mixing Error**:
+ - Error: "scope can't be combined with resource-specific scopes"
+ - Solution: Our implementation automatically handles this by using only Azure Management API scope
+13. **Resource Parameter Issues**:
+ - Azure AD doesn't support RFC 8707 resource parameter
+ - Our implementation works around this limitation automatically
+
+### Debug Logging
+
+Enable verbose logging for OAuth debugging:
+
+```bash
+./aks-mcp --oauth-enabled=true --verbose
+```
+
+### Health Check
+
+Use the health endpoint to verify OAuth configuration:
+
+```bash
+curl http://localhost:8000/health
+```
+
+Expected response with OAuth enabled:
+```json
+{
+ "status": "healthy",
+ "oauth": {
+ "enabled": true
+ }
+}
+```
+
+### Testing OAuth Flow Step by Step
+
+1. **Test Metadata Discovery**:
+```bash
+# Should return authorization server URLs
+curl http://localhost:8000/.well-known/oauth-protected-resource
+
+# Should return PKCE support and endpoints
+curl http://localhost:8000/.well-known/oauth-authorization-server
+```
+
+2. **Test Client Registration**:
+```bash
+curl -X POST http://localhost:8000/oauth/register \
+ -H "Content-Type: application/json" \
+ -d '{
+ "redirect_uris": ["http://localhost:8000/oauth/callback"],
+ "client_name": "Test Client"
+ }'
+```
+
+3. **Test Authorization Flow**:
+ - Open browser to: `http://localhost:8000/oauth2/v2.0/authorize?response_type=code&client_id=YOUR_CLIENT_ID&redirect_uri=http://localhost:8000/oauth/callback&scope=https://management.azure.com/.default&code_challenge=CHALLENGE&code_challenge_method=S256&state=STATE`
+
+4. **Verify Token Validation**:
+```bash
+# Use a valid Azure AD token
+curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp
+```
+
+## Migration from Non-OAuth
+
+To migrate from a non-OAuth AKS-MCP deployment:
+
+1. Update clients to obtain and include Bearer tokens
+2. Enable OAuth on the server with `--oauth-enabled=true`
+3. Configure Azure AD application and credentials
+4. Test with a subset of clients before full migration
+5. Monitor logs for authentication errors
+
+## Integration with MCP Inspector
+
+The MCP Inspector tool can be used to test OAuth-enabled AKS-MCP servers. Configure the Inspector's OAuth settings to match your AKS-MCP OAuth configuration for testing.
+
+### Important: Redirect URI Configuration for MCP Inspector
+
+When using MCP Inspector with OAuth authentication, you need to add the Inspector's proxy redirect URI to your OAuth configuration:
+
+```bash
+# Add Inspector's redirect URI (typically http://localhost:6274/oauth/callback)
+./aks-mcp \
+ --transport=streamable-http \
+ --port=8000 \
+ --oauth-enabled \
+ --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \
+ --access-level=readonly
+```
+
+**Key Points:**
+- MCP Inspector typically runs on port 6274 by default
+- The Inspector creates a proxy redirect URI at `/oauth/callback`
+- You must include both your server's redirect URI AND the Inspector's redirect URI
+- You must also configure CORS origins to allow the Inspector's web interface to make requests
+- Comma-separate multiple redirect URIs in the `--oauth-redirects` parameter
+- Comma-separate multiple CORS origins in the `--oauth-cors-origins` parameter
+- Without the Inspector's redirect URI, OAuth authentication will fail with "redirect_uri not registered" error
+- Without the Inspector's CORS origin, the web interface will be blocked by browser CORS policy
+
+**Example with MCP Inspector configuration:**
+```bash
+./aks-mcp \
+ --transport=streamable-http \
+ --port=8000 \
+ --oauth-enabled \
+ --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \
+ --oauth-cors-origins="http://localhost:6274" \
+ --access-level=readonly
+```
+
+For more information, see the MCP OAuth specification and Azure AD documentation.
\ No newline at end of file
diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go
new file mode 100644
index 0000000..af49899
--- /dev/null
+++ b/internal/auth/oauth/endpoints.go
@@ -0,0 +1,1021 @@
+package oauth
+
+import (
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+ "github.com/Azure/aks-mcp/internal/config"
+)
+
+// validateAzureADURL validates that the URL is a legitimate Azure AD endpoint
+func validateAzureADURL(tokenURL string) error {
+ parsedURL, err := url.Parse(tokenURL)
+ if err != nil {
+ return fmt.Errorf("invalid URL format: %w", err)
+ }
+
+ // Only allow HTTPS for security
+ if parsedURL.Scheme != "https" {
+ return fmt.Errorf("only HTTPS URLs are allowed")
+ }
+
+ // Only allow Azure AD endpoints
+ if parsedURL.Host != "login.microsoftonline.com" {
+ return fmt.Errorf("only Azure AD endpoints are allowed")
+ }
+
+ // Validate path format for token endpoint (should be /{tenantId}/oauth2/v2.0/token)
+ if !strings.Contains(parsedURL.Path, "/oauth2/v2.0/token") {
+ return fmt.Errorf("invalid Azure AD token endpoint path")
+ }
+
+ return nil
+}
+
+// EndpointManager manages OAuth-related HTTP endpoints
+type EndpointManager struct {
+ provider *AzureOAuthProvider
+ cfg *config.ConfigData
+}
+
+// NewEndpointManager creates a new OAuth endpoint manager
+func NewEndpointManager(provider *AzureOAuthProvider, cfg *config.ConfigData) *EndpointManager {
+ return &EndpointManager{
+ provider: provider,
+ cfg: cfg,
+ }
+}
+
+// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting
+func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter, r *http.Request) {
+ requestOrigin := r.Header.Get("Origin")
+
+ // Check if the request origin is in the allowed list
+ var allowedOrigin string
+ for _, allowed := range em.provider.config.AllowedOrigins {
+ if requestOrigin == allowed {
+ allowedOrigin = requestOrigin
+ break
+ }
+ }
+
+ // Only set CORS headers if origin is allowed
+ if allowedOrigin != "" {
+ w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version")
+ w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
+ w.Header().Set("Access-Control-Allow-Credentials", "false")
+ } else if requestOrigin != "" {
+ log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin)
+ }
+}
+
+// setCacheHeaders sets cache control headers based on EnableCache configuration
+func (em *EndpointManager) setCacheHeaders(w http.ResponseWriter) {
+ if config.EnableCache {
+ // Enable caching for 1 hour when cache is enabled
+ w.Header().Set("Cache-Control", "max-age=3600")
+ } else {
+ // Disable all caching when cache is disabled (for debugging)
+ w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
+ w.Header().Set("Pragma", "no-cache")
+ w.Header().Set("Expires", "0")
+ }
+}
+
+// RegisterEndpoints registers OAuth endpoints with the provided HTTP mux
+func (em *EndpointManager) RegisterEndpoints(mux *http.ServeMux) {
+ // OAuth 2.0 Protected Resource Metadata endpoint (RFC 9728)
+ mux.HandleFunc("/.well-known/oauth-protected-resource", em.protectedResourceMetadataHandler())
+
+ // OAuth 2.0 Authorization Server Metadata endpoint (RFC 8414)
+ // Note: This would typically be served by Azure AD, but we provide a proxy for convenience
+ mux.HandleFunc("/.well-known/oauth-authorization-server", em.authServerMetadataProxyHandler())
+
+ // OpenID Connect Discovery endpoint (compatibility with MCP Inspector)
+ mux.HandleFunc("/.well-known/openid-configuration", em.authServerMetadataProxyHandler())
+
+ // Authorization endpoint proxy to handle Azure AD compatibility
+ mux.HandleFunc("/oauth2/v2.0/authorize", em.authorizationProxyHandler())
+
+ // Dynamic Client Registration endpoint (RFC 7591)
+ mux.HandleFunc("/oauth/register", em.clientRegistrationHandler())
+
+ // Token introspection endpoint (RFC 7662) - optional
+ mux.HandleFunc("/oauth/introspect", em.tokenIntrospectionHandler())
+
+ // OAuth 2.0 callback endpoint for Authorization Code flow
+ mux.HandleFunc("/oauth/callback", em.callbackHandler())
+
+ // OAuth 2.0 token endpoint for Authorization Code exchange
+ mux.HandleFunc("/oauth2/v2.0/token", em.tokenHandler())
+
+ // Health check endpoint (unauthenticated)
+ mux.HandleFunc("/health", em.healthHandler())
+}
+
+// authServerMetadataProxyHandler proxies authorization server metadata from Azure AD
+func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ log.Printf("OAuth ERROR: Invalid method %s for metadata endpoint", r.Method)
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Get metadata from Azure AD
+ provider := em.provider
+
+ // Build server URL based on the request
+ scheme := "http"
+ if r.TLS != nil {
+ scheme = "https"
+ }
+
+ // Use the Host header from the request
+ host := r.Host
+ if host == "" {
+ host = r.URL.Host
+ }
+
+ serverURL := fmt.Sprintf("%s://%s", scheme, host)
+
+ metadata, err := provider.GetAuthorizationServerMetadata(serverURL)
+ if err != nil {
+ log.Printf("Failed to fetch authorization server metadata: %v\n", err)
+ http.Error(w, fmt.Sprintf("Failed to fetch authorization server metadata: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ em.setCacheHeaders(w)
+
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ log.Printf("Failed to encode response: %v\n", err)
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// clientRegistrationHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591)
+func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodPost {
+ log.Printf("OAuth ERROR: Invalid method %s for client registration endpoint, only POST allowed", r.Method)
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse client registration request
+ var registrationRequest ClientRegistrationRequest
+
+ if err := json.NewDecoder(r.Body).Decode(®istrationRequest); err != nil {
+ log.Printf("OAuth ERROR: Failed to parse client registration JSON: %v", err)
+ em.writeErrorResponse(w, "invalid_request", "Invalid JSON in request body", http.StatusBadRequest)
+ return
+ }
+
+ log.Printf("OAuth DEBUG: Client registration request parsed - client_name: %s, redirect_uris: %v", registrationRequest.ClientName, registrationRequest.RedirectURIs)
+
+ // Validate registration request
+ if err := em.validateClientRegistration(®istrationRequest); err != nil {
+ log.Printf("OAuth ERROR: Client registration validation failed: %v", err)
+ em.writeErrorResponse(w, "invalid_client_metadata", err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ // Use client-requested grant types if provided and valid, otherwise use defaults
+ grantTypes := registrationRequest.GrantTypes
+ if len(grantTypes) == 0 {
+ grantTypes = []string{"authorization_code", "refresh_token"}
+ }
+
+ // Use client-requested response types if provided and valid, otherwise use defaults
+ responseTypes := registrationRequest.ResponseTypes
+ if len(responseTypes) == 0 {
+ responseTypes = []string{"code"}
+ }
+
+ // For Azure AD compatibility, use the configured client ID
+ // In a full RFC 7591 implementation, each registration would get a unique ID
+ // But since Azure AD requires pre-registered client IDs, we return the configured one
+ clientID := em.cfg.OAuthConfig.ClientID
+
+ clientInfo := map[string]interface{}{
+ "client_id": clientID, // Use configured Azure AD client ID
+ "client_id_issued_at": time.Now().Unix(), // RFC 7591: timestamp of issuance
+ "redirect_uris": registrationRequest.RedirectURIs,
+ "token_endpoint_auth_method": "none", // Public client (PKCE required)
+ "grant_types": grantTypes,
+ "response_types": responseTypes,
+ "client_name": registrationRequest.ClientName,
+ "client_uri": registrationRequest.ClientURI,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+
+ if err := json.NewEncoder(w).Encode(clientInfo); err != nil {
+ log.Printf("OAuth ERROR: Failed to encode client registration response: %v", err)
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// validateClientRegistration validates a client registration request
+func (em *EndpointManager) validateClientRegistration(req *ClientRegistrationRequest) error {
+ // Validate redirect URIs - require at least one
+ if len(req.RedirectURIs) == 0 {
+ return fmt.Errorf("at least one redirect_uri is required")
+ }
+
+ // Basic URL validation for redirect URIs
+ for _, redirectURI := range req.RedirectURIs {
+ if _, err := url.Parse(redirectURI); err != nil {
+ return fmt.Errorf("invalid redirect_uri format: %s", redirectURI)
+ }
+ }
+
+ // Validate grant types
+ validGrantTypes := map[string]bool{
+ "authorization_code": true,
+ "refresh_token": true,
+ }
+
+ for _, grantType := range req.GrantTypes {
+ if !validGrantTypes[grantType] {
+ return fmt.Errorf("unsupported grant_type: %s", grantType)
+ }
+ }
+
+ // Validate response types
+ validResponseTypes := map[string]bool{
+ "code": true,
+ }
+
+ for _, responseType := range req.ResponseTypes {
+ if !validResponseTypes[responseType] {
+ return fmt.Errorf("unsupported response_type: %s", responseType)
+ }
+ }
+
+ return nil
+}
+
+// validateRedirectURI validates that a redirect URI is registered and allowed
+func (em *EndpointManager) validateRedirectURI(redirectURI string) error {
+ if len(em.cfg.OAuthConfig.RedirectURIs) == 0 {
+ return fmt.Errorf("no redirect URIs configured")
+ }
+
+ for _, allowed := range em.cfg.OAuthConfig.RedirectURIs {
+ if redirectURI == allowed {
+ return nil
+ }
+ }
+
+ log.Printf("OAuth SECURITY WARNING: Invalid redirect URI attempted: %s, allowed: %v",
+ redirectURI, em.cfg.OAuthConfig.RedirectURIs)
+ return fmt.Errorf("redirect_uri not registered: %s", redirectURI)
+}
+
+// tokenIntrospectionHandler implements RFC 7662 OAuth 2.0 Token Introspection
+func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // This endpoint should be protected with client authentication
+ // For simplicity, we'll skip client auth in this implementation
+
+ token := r.FormValue("token")
+ if token == "" {
+ em.writeErrorResponse(w, "invalid_request", "Missing token parameter", http.StatusBadRequest)
+ return
+ }
+
+ // Validate the token
+ provider := em.provider
+
+ tokenInfo, err := provider.ValidateToken(r.Context(), token)
+ if err != nil {
+ // Return inactive token response
+ response := map[string]interface{}{
+ "active": false,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ log.Printf("Failed to encode introspection response: %v", err)
+ }
+ return
+ }
+
+ // Return active token response
+ response := map[string]interface{}{
+ "active": true,
+ "client_id": em.cfg.OAuthConfig.ClientID,
+ "scope": strings.Join(tokenInfo.Scope, " "),
+ "sub": tokenInfo.Subject,
+ "aud": tokenInfo.Audience,
+ "iss": tokenInfo.Issuer,
+ "exp": tokenInfo.ExpiresAt.Unix(),
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// healthHandler provides a simple health check endpoint
+func (em *EndpointManager) healthHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ response := map[string]interface{}{
+ "status": "healthy",
+ "oauth": map[string]interface{}{
+ "enabled": em.cfg.OAuthConfig.Enabled,
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// protectedResourceMetadataHandler handles OAuth 2.0 Protected Resource Metadata requests
+func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ log.Printf("OAuth ERROR: Invalid method %s for protected resource metadata endpoint", r.Method)
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Build resource URL based on the request
+ scheme := "http"
+ if r.TLS != nil {
+ scheme = "https"
+ }
+
+ // Use the Host header from the request
+ host := r.Host
+ if host == "" {
+ host = r.URL.Host
+ }
+
+ // Build the resource URL
+ resourceURL := fmt.Sprintf("%s://%s", scheme, host)
+ log.Printf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL)
+
+ provider := em.provider
+
+ metadata, err := provider.GetProtectedResourceMetadata(resourceURL)
+ if err != nil {
+ log.Printf("OAuth ERROR: Failed to get protected resource metadata: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+
+ log.Printf("OAuth DEBUG: Successfully generated protected resource metadata with %d authorization servers", len(metadata.AuthorizationServers))
+
+ w.Header().Set("Content-Type", "application/json")
+ em.setCacheHeaders(w)
+
+ if err := json.NewEncoder(w).Encode(metadata); err != nil {
+ log.Printf("OAuth ERROR: Failed to encode protected resource metadata response: %v", err)
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// writeErrorResponse writes an OAuth error response
+func (em *EndpointManager) writeErrorResponse(w http.ResponseWriter, errorCode, description string, statusCode int) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(statusCode)
+
+ response := map[string]interface{}{
+ "error": errorCode,
+ "error_description": description,
+ }
+
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ log.Printf("Failed to encode error response: %v", err)
+ }
+}
+
+// authorizationProxyHandler proxies authorization requests to Azure AD with resource parameter filtering
+func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ log.Printf("OAuth ERROR: Invalid method %s for authorization endpoint, only GET allowed", r.Method)
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse query parameters
+ query := r.URL.Query()
+
+ // Validate redirect_uri parameter for security and better user experience
+ redirectURI := query.Get("redirect_uri")
+ if redirectURI == "" {
+ log.Printf("OAuth ERROR: Missing redirect_uri parameter in authorization request")
+ log.Printf("OAuth HELP: To fix this error, configure redirect URIs using --oauth-redirects flag")
+ log.Printf("OAuth HELP: For MCP Inspector, use: --oauth-redirects=\"http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback\"")
+ em.writeErrorResponse(w, "invalid_request", "redirect_uri parameter is required", http.StatusBadRequest)
+ return
+ }
+
+ // Validate that the redirect_uri is registered and allowed
+ if err := em.validateRedirectURI(redirectURI); err != nil {
+ log.Printf("OAuth ERROR: redirect_uri %s not registered - requests will be blocked for security", redirectURI)
+ em.writeErrorResponse(w, "invalid_request", fmt.Sprintf("redirect_uri not registered: %s", redirectURI), http.StatusBadRequest)
+ return
+ }
+
+ // Enforce PKCE for OAuth 2.1 compliance (MCP requirement)
+ codeChallenge := query.Get("code_challenge")
+ codeChallengeMethod := query.Get("code_challenge_method")
+
+ if codeChallenge == "" {
+ log.Printf("OAuth ERROR: Missing PKCE code_challenge parameter (required for OAuth 2.1)")
+ em.writeErrorResponse(w, "invalid_request", "PKCE code_challenge is required", http.StatusBadRequest)
+ return
+ }
+
+ if codeChallengeMethod == "" {
+ // Default to S256 if not specified
+ query.Set("code_challenge_method", "S256")
+ log.Printf("OAuth DEBUG: Setting default code_challenge_method to S256")
+ } else if codeChallengeMethod != "S256" {
+ log.Printf("OAuth ERROR: Unsupported code_challenge_method: %s (only S256 supported)", codeChallengeMethod)
+ em.writeErrorResponse(w, "invalid_request", "Only S256 code_challenge_method is supported", http.StatusBadRequest)
+ return
+ }
+
+ // Resource parameter handling for MCP compliance
+ // requestedScopes := strings.Split(query.Get("scope"), " ")
+
+ // Azure AD v2.0 doesn't support RFC 8707 Resource Indicators in authorization requests
+ // Remove the resource parameter if present for Azure AD compatibility
+ resourceParam := query.Get("resource")
+ if resourceParam != "" {
+ log.Printf("OAuth DEBUG: Removing resource parameter for Azure AD compatibility: %s", resourceParam)
+ query.Del("resource")
+ }
+
+ // Use only server-required scopes for Azure AD compatibility
+ // Azure AD .default scopes cannot be mixed with OpenID Connect scopes
+ // We prioritize Azure Management API access over OpenID Connect user info
+ finalScopes := em.cfg.OAuthConfig.RequiredScopes
+
+ finalScopeString := strings.Join(finalScopes, " ")
+ query.Set("scope", finalScopeString)
+ log.Printf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString)
+
+ // Build the Azure AD authorization URL
+ azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.cfg.OAuthConfig.TenantID)
+
+ // Create the redirect URL with filtered parameters
+ redirectURL := fmt.Sprintf("%s?%s", azureAuthURL, query.Encode())
+ log.Printf("OAuth DEBUG: Redirecting to Azure AD authorization endpoint: %s", azureAuthURL)
+
+ // Redirect to Azure AD
+ http.Redirect(w, r, redirectURL, http.StatusFound)
+ }
+}
+
+// callbackHandler handles OAuth 2.0 Authorization Code flow callback
+func (em *EndpointManager) callbackHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ log.Printf("OAuth ERROR: Invalid method %s for callback endpoint, only GET allowed", r.Method)
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse query parameters
+ query := r.URL.Query()
+
+ // Check for error response from authorization server
+ if authError := query.Get("error"); authError != "" {
+ errorDesc := query.Get("error_description")
+ log.Printf("OAuth ERROR: Authorization server returned error: %s - %s", authError, errorDesc)
+ em.writeCallbackErrorResponse(w, fmt.Sprintf("Authorization failed: %s - %s", authError, errorDesc))
+ return
+ }
+
+ // Get authorization code
+ code := query.Get("code")
+ if code == "" {
+ log.Printf("OAuth ERROR: Missing authorization code in callback")
+ em.writeCallbackErrorResponse(w, "Missing authorization code")
+ return
+ }
+
+ // Get state parameter for CSRF protection
+ state := query.Get("state")
+ if state == "" {
+ log.Printf("OAuth ERROR: Missing state parameter in callback")
+ em.writeCallbackErrorResponse(w, "Missing state parameter")
+ return
+ }
+
+ log.Printf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state)
+
+ // Validate redirect URI for security - construct expected URI and validate it
+ expectedRedirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port)
+ if err := em.validateRedirectURI(expectedRedirectURI); err != nil {
+ log.Printf("OAuth ERROR: Redirect URI validation failed: %v", err)
+ em.writeCallbackErrorResponse(w, "Invalid redirect URI")
+ return
+ }
+
+ // Exchange authorization code for access token
+ tokenResponse, err := em.exchangeCodeForToken(code, state)
+ if err != nil {
+ log.Printf("OAuth ERROR: Failed to exchange authorization code for token: %v", err)
+ em.writeCallbackErrorResponse(w, fmt.Sprintf("Failed to exchange code for token: %v", err))
+ return
+ }
+
+ // Skip token validation in callback - validation happens during MCP requests
+ // Create minimal token info for callback success page
+ tokenInfo := &auth.TokenInfo{
+ AccessToken: tokenResponse.AccessToken,
+ TokenType: "Bearer",
+ ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration
+ Scope: em.cfg.OAuthConfig.RequiredScopes, // Use configured scopes
+ Subject: "authenticated_user", // Placeholder
+ Audience: []string{fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID)},
+ Issuer: fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID),
+ Claims: make(map[string]interface{}),
+ }
+
+ // Return success response with token information
+ em.writeCallbackSuccessResponse(w, tokenResponse, tokenInfo)
+ }
+}
+
+// TokenResponse represents the response from token exchange
+type TokenResponse struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int `json:"expires_in"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ Scope string `json:"scope,omitempty"`
+}
+
+// exchangeCodeForToken exchanges authorization code for access token
+func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenResponse, error) {
+ // Prepare token exchange request
+ tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID)
+
+ // Validate URL for security
+ if err := validateAzureADURL(tokenURL); err != nil {
+ return nil, fmt.Errorf("invalid token URL: %w", err)
+ }
+
+ // Use default callback redirect URI for token exchange
+ redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port)
+
+ // Prepare form data
+ data := url.Values{}
+ data.Set("grant_type", "authorization_code")
+ data.Set("client_id", em.cfg.OAuthConfig.ClientID)
+ data.Set("code", code)
+ data.Set("redirect_uri", redirectURI)
+ data.Set("scope", strings.Join(em.cfg.OAuthConfig.RequiredScopes, " "))
+
+ // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests
+ // It uses scope-based resource identification instead
+ // For MCP compliance, we handle resource binding through audience validation
+
+ // Make token exchange request
+ resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above
+ if err != nil {
+ return nil, fmt.Errorf("token exchange request failed: %w", err)
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("Failed to close response body: %v", err)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
+ }
+
+ // Parse token response
+ var tokenResponse TokenResponse
+ if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
+ }
+
+ return &tokenResponse, nil
+}
+
+// writeCallbackErrorResponse writes an error response for callback
+func (em *EndpointManager) writeCallbackErrorResponse(w http.ResponseWriter, message string) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusBadRequest)
+
+ html := fmt.Sprintf(`
+
+
+
+ OAuth Authentication Error
+
+
+
+
+
Authentication Error
+
%s
+
Please try again or contact your administrator.
+
+
+`, message)
+
+ if _, err := w.Write([]byte(html)); err != nil {
+ log.Printf("Failed to write error response: %v", err)
+ }
+}
+
+// writeCallbackSuccessResponse writes a success response for callback
+func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, tokenResponse *TokenResponse, tokenInfo *auth.TokenInfo) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+
+ // Generate a secure session token for the client to use
+ _, err := em.generateSessionToken()
+ if err != nil {
+ em.writeCallbackErrorResponse(w, "Failed to generate session token")
+ return
+ }
+
+ html := fmt.Sprintf(`
+
+
+
+ OAuth Authentication Success
+
+
+
+
+
Authentication Successful
+
You have been successfully authenticated with Azure AD.
+
+
+
Access Token (use as Bearer token):
+
%s
+
+
+
+
+
Token Information:
+
+ - Subject: %s
+ - Audience: %s
+ - Scope: %s
+ - Expires: %s
+
+
+
+
+
For MCP Client Usage:
+
Use this token in the Authorization header:
+
Authorization: Bearer %s
+
+
+
+
+
+
+`,
+ tokenResponse.AccessToken,
+ tokenInfo.Subject,
+ strings.Join(tokenInfo.Audience, ", "),
+ strings.Join(tokenInfo.Scope, ", "),
+ tokenInfo.ExpiresAt.Format("2006-01-02 15:04:05 UTC"),
+ tokenResponse.AccessToken,
+ tokenResponse.AccessToken)
+
+ if _, err := w.Write([]byte(html)); err != nil {
+ log.Printf("Failed to write success response: %v", err)
+ }
+}
+
+// isValidClientID validates if a client ID is acceptable
+func (em *EndpointManager) isValidClientID(clientID string) bool {
+ // Accept configured client ID (primary method for Azure AD)
+ if clientID == em.cfg.OAuthConfig.ClientID {
+ return true
+ }
+
+ // For future extensibility, could accept other registered client IDs
+ // But for Azure AD integration, we primarily use the configured client ID
+
+ return false
+}
+
+// generateSessionToken generates a secure random session token
+func (em *EndpointManager) generateSessionToken() (string, error) {
+ bytes := make([]byte, 32)
+ if _, err := rand.Read(bytes); err != nil {
+ return "", err
+ }
+ return base64.URLEncoding.EncodeToString(bytes), nil
+}
+
+// tokenHandler handles OAuth 2.0 token endpoint requests (Authorization Code exchange)
+func (em *EndpointManager) tokenHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path)
+
+ // Set CORS headers for all requests
+ em.setCORSHeaders(w, r)
+
+ // Handle preflight OPTIONS request
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodPost {
+ log.Printf("OAuth ERROR: Invalid method %s for token endpoint, only POST allowed", r.Method)
+ em.writeErrorResponse(w, "invalid_request", "Only POST method is allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse form data
+ if err := r.ParseForm(); err != nil {
+ log.Printf("OAuth ERROR: Failed to parse form data: %v", err)
+ em.writeErrorResponse(w, "invalid_request", "Failed to parse form data", http.StatusBadRequest)
+ return
+ }
+
+ // Validate grant type
+ grantType := r.FormValue("grant_type")
+ if grantType != "authorization_code" {
+ log.Printf("OAuth ERROR: Unsupported grant type: %s", grantType)
+ em.writeErrorResponse(w, "unsupported_grant_type", fmt.Sprintf("Unsupported grant type: %s", grantType), http.StatusBadRequest)
+ return
+ }
+
+ // Extract required parameters
+ code := r.FormValue("code")
+ clientID := r.FormValue("client_id")
+ redirectURI := r.FormValue("redirect_uri")
+ codeVerifier := r.FormValue("code_verifier") // PKCE parameter
+
+ if code == "" {
+ log.Printf("OAuth ERROR: Missing authorization code in token request")
+ em.writeErrorResponse(w, "invalid_request", "Missing authorization code", http.StatusBadRequest)
+ return
+ }
+
+ if clientID == "" {
+ log.Printf("OAuth ERROR: Missing client_id in token request")
+ em.writeErrorResponse(w, "invalid_request", "Missing client_id", http.StatusBadRequest)
+ return
+ }
+
+ if redirectURI == "" {
+ log.Printf("OAuth ERROR: Missing redirect_uri in token request")
+ em.writeErrorResponse(w, "invalid_request", "Missing redirect_uri", http.StatusBadRequest)
+ return
+ }
+
+ // Enforce PKCE code_verifier for OAuth 2.1 compliance
+ if codeVerifier == "" {
+ log.Printf("OAuth ERROR: Missing PKCE code_verifier (required for OAuth 2.1)")
+ em.writeErrorResponse(w, "invalid_request", "PKCE code_verifier is required", http.StatusBadRequest)
+ return
+ }
+
+ // Validate client ID (accept both configured and dynamically registered clients)
+ if !em.isValidClientID(clientID) {
+ log.Printf("OAuth ERROR: Invalid client_id: %s", clientID)
+ em.writeErrorResponse(w, "invalid_client", "Invalid client_id", http.StatusBadRequest)
+ return
+ }
+
+ // Validate redirect URI for security
+ if err := em.validateRedirectURI(redirectURI); err != nil {
+ log.Printf("OAuth ERROR: Redirect URI validation failed in token endpoint: %v", err)
+ em.writeErrorResponse(w, "invalid_request", "Invalid redirect_uri", http.StatusBadRequest)
+ return
+ }
+
+ // Extract scope from the token request (MCP client should send the same scope)
+ requestedScope := r.FormValue("scope")
+ if requestedScope == "" {
+ // Fallback to server required scopes if not provided
+ requestedScope = strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ")
+ }
+
+ log.Printf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope)
+
+ // Exchange authorization code for access token with Azure AD
+ tokenResponse, err := em.exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, requestedScope)
+ if err != nil {
+ log.Printf("OAuth ERROR: Token exchange with Azure AD failed: %v", err)
+ em.writeErrorResponse(w, "invalid_grant", fmt.Sprintf("Authorization code exchange failed: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Return token response
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "no-store")
+ w.Header().Set("Pragma", "no-cache")
+
+ if err := json.NewEncoder(w).Encode(tokenResponse); err != nil {
+ log.Printf("OAuth ERROR: Failed to encode token response: %v", err)
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ return
+ }
+ }
+}
+
+// exchangeCodeForTokenDirect exchanges authorization code for access token directly with Azure AD
+func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, scope string) (*TokenResponse, error) {
+ // Prepare token exchange request to Azure AD
+ tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID)
+
+ // Validate URL for security
+ if err := validateAzureADURL(tokenURL); err != nil {
+ return nil, fmt.Errorf("invalid token URL: %w", err)
+ }
+
+ // Prepare form data
+ data := url.Values{}
+ data.Set("grant_type", "authorization_code")
+ data.Set("client_id", em.cfg.OAuthConfig.ClientID)
+ data.Set("code", code)
+ data.Set("redirect_uri", redirectURI)
+ data.Set("scope", scope) // Use the scope provided by the client
+
+ // Add PKCE code_verifier if present
+ if codeVerifier != "" {
+ data.Set("code_verifier", codeVerifier)
+ log.Printf("Including PKCE code_verifier in Azure AD token request")
+ } else {
+ log.Printf("No PKCE code_verifier provided - this may cause PKCE verification to fail")
+ }
+
+ // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests
+ // It uses scope-based resource identification instead
+ // For MCP compliance, we handle resource binding through audience validation
+ log.Printf("Azure AD token request with scope: %s", scope)
+
+ // Make token exchange request to Azure AD
+ resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above
+ if err != nil {
+ return nil, fmt.Errorf("token exchange request failed: %w", err)
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("Failed to close response body: %v", err)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
+ }
+
+ // Parse token response
+ var tokenResponse TokenResponse
+ if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
+ }
+
+ log.Printf("Token exchange successful: access_token received (length: %d)", len(tokenResponse.AccessToken))
+
+ return &tokenResponse, nil
+}
diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go
new file mode 100644
index 0000000..f457d5c
--- /dev/null
+++ b/internal/auth/oauth/endpoints_test.go
@@ -0,0 +1,601 @@
+package oauth
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+ "github.com/Azure/aks-mcp/internal/config"
+)
+
+// createTestConfig creates a test ConfigData with OAuth configuration
+func createTestConfig() *config.ConfigData {
+ cfg := config.NewConfig()
+ cfg.Host = "127.0.0.1"
+ cfg.Port = 8000
+ cfg.OAuthConfig = &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ RedirectURIs: []string{"http://127.0.0.1:8000/oauth/callback", "http://localhost:8000/oauth/callback"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: false,
+ ValidateAudience: false,
+ ExpectedAudience: "https://management.azure.com/",
+ },
+ }
+ return cfg
+}
+
+func TestEndpointManager_RegisterEndpoints(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ mux := http.NewServeMux()
+ manager.RegisterEndpoints(mux)
+
+ // Test that endpoints are registered by making requests
+ testCases := []struct {
+ method string
+ path string
+ status int
+ }{
+ {"GET", "/.well-known/oauth-protected-resource", http.StatusOK},
+ {"GET", "/.well-known/oauth-authorization-server", http.StatusInternalServerError}, // Will fail without real Azure AD
+ {"POST", "/oauth/register", http.StatusBadRequest}, // Missing required data
+ {"POST", "/oauth/introspect", http.StatusBadRequest}, // Missing token param
+ {"GET", "/oauth/callback", http.StatusBadRequest}, // Missing required params
+ {"GET", "/health", http.StatusOK},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.method+" "+tc.path, func(t *testing.T) {
+ req := httptest.NewRequest(tc.method, tc.path, nil)
+ w := httptest.NewRecorder()
+
+ mux.ServeHTTP(w, req)
+
+ if w.Code != tc.status {
+ t.Errorf("Expected status %d for %s %s, got %d", tc.status, tc.method, tc.path, w.Code)
+ }
+ })
+ }
+}
+
+func TestProtectedResourceMetadataEndpoint(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.protectedResourceMetadataHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+
+ var metadata ProtectedResourceMetadata
+ if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ expectedAuthServer := "http://example.com"
+ if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer {
+ t.Errorf("Expected auth server %s, got %v", expectedAuthServer, metadata.AuthorizationServers)
+ }
+
+ if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" {
+ t.Errorf("Expected scopes %v, got %v", cfg.OAuthConfig.RequiredScopes, metadata.ScopesSupported)
+ }
+}
+
+func TestClientRegistrationEndpoint(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test valid registration request
+ registrationRequest := map[string]interface{}{
+ "redirect_uris": []string{"http://localhost:3000/callback"},
+ "token_endpoint_auth_method": "none",
+ "grant_types": []string{"authorization_code"},
+ "response_types": []string{"code"},
+ "scope": "https://management.azure.com/.default",
+ "client_name": "Test Client",
+ }
+
+ reqBody, _ := json.Marshal(registrationRequest)
+ req := httptest.NewRequest("POST", "/oauth/register", strings.NewReader(string(reqBody)))
+ req.Header.Set("Content-Type", "application/json")
+
+ w := httptest.NewRecorder()
+ handler := manager.clientRegistrationHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusCreated {
+ t.Errorf("Expected status 201, got %d", w.Code)
+ }
+
+ var response map[string]interface{}
+ if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ if response["client_id"] == "" {
+ t.Error("Expected client_id in response")
+ }
+
+ redirectURIs, ok := response["redirect_uris"].([]interface{})
+ if !ok || len(redirectURIs) != 1 {
+ t.Errorf("Expected redirect URIs in response")
+ }
+}
+
+func TestTokenIntrospectionEndpoint(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test with valid token (since JWT validation is disabled, any token works)
+ // Note: Must use a token that looks like a JWT (has dots) to pass initial format checks
+ req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("token=header.payload.signature"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ w := httptest.NewRecorder()
+ handler := manager.tokenIntrospectionHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+
+ var response map[string]interface{}
+ if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ if active, ok := response["active"].(bool); !ok || !active {
+ t.Error("Expected active token")
+ }
+}
+
+func TestTokenIntrospectionEndpointMissingToken(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test without token parameter
+ req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader(""))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ w := httptest.NewRecorder()
+ handler := manager.tokenIntrospectionHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing token, got %d", w.Code)
+ }
+}
+
+func TestHealthEndpoint(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ req := httptest.NewRequest("GET", "/health", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.healthHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+
+ var response map[string]interface{}
+ if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ if response["status"] != "healthy" {
+ t.Errorf("Expected status healthy, got %v", response["status"])
+ }
+
+ oauth, ok := response["oauth"].(map[string]interface{})
+ if !ok {
+ t.Error("Expected oauth object in response")
+ }
+
+ if oauth["enabled"] != true {
+ t.Errorf("Expected oauth enabled true, got %v", oauth["enabled"])
+ }
+}
+
+func TestValidateClientRegistration(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ tests := []struct {
+ name string
+ request map[string]interface{}
+ wantErr bool
+ }{
+ {
+ name: "valid request",
+ request: map[string]interface{}{
+ "redirect_uris": []string{"http://localhost:3000/callback"},
+ "grant_types": []string{"authorization_code"},
+ "response_types": []string{"code"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "missing redirect URIs",
+ request: map[string]interface{}{
+ "grant_types": []string{"authorization_code"},
+ "response_types": []string{"code"},
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid grant type",
+ request: map[string]interface{}{
+ "redirect_uris": []string{"http://localhost:3000/callback"},
+ "grant_types": []string{"client_credentials"},
+ "response_types": []string{"code"},
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid response type",
+ request: map[string]interface{}{
+ "redirect_uris": []string{"http://localhost:3000/callback"},
+ "grant_types": []string{"authorization_code"},
+ "response_types": []string{"token"},
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Convert test request to the expected struct format
+ req := &ClientRegistrationRequest{}
+
+ if redirectURIs, ok := tt.request["redirect_uris"].([]string); ok {
+ req.RedirectURIs = redirectURIs
+ }
+ if grantTypes, ok := tt.request["grant_types"].([]string); ok {
+ req.GrantTypes = grantTypes
+ }
+ if responseTypes, ok := tt.request["response_types"].([]string); ok {
+ req.ResponseTypes = responseTypes
+ }
+
+ err := manager.validateClientRegistration(req)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateClientRegistration() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestCallbackEndpointMissingCode(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test callback without authorization code
+ req := httptest.NewRequest("GET", "/oauth/callback?state=test-state", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.callbackHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing code, got %d", w.Code)
+ }
+
+ // Check that response contains HTML error page
+ contentType := w.Header().Get("Content-Type")
+ if !strings.Contains(contentType, "text/html") {
+ t.Errorf("Expected HTML content type, got %s", contentType)
+ }
+
+ body := w.Body.String()
+ if !strings.Contains(body, "Missing authorization code") {
+ t.Error("Expected error message about missing authorization code")
+ }
+}
+
+func TestCallbackEndpointMissingState(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test callback without state parameter
+ req := httptest.NewRequest("GET", "/oauth/callback?code=test-code", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.callbackHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing state, got %d", w.Code)
+ }
+
+ body := w.Body.String()
+ if !strings.Contains(body, "Missing state parameter") {
+ t.Error("Expected error message about missing state parameter")
+ }
+}
+
+func TestCallbackEndpointAuthError(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test callback with authorization error
+ req := httptest.NewRequest("GET", "/oauth/callback?error=access_denied&error_description=User%20denied%20access", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.callbackHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for auth error, got %d", w.Code)
+ }
+
+ body := w.Body.String()
+ if !strings.Contains(body, "Authorization failed") {
+ t.Error("Expected error message about authorization failure")
+ }
+ if !strings.Contains(body, "access_denied") {
+ t.Error("Expected specific error code in response")
+ }
+}
+
+func TestCallbackEndpointMethodNotAllowed(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ // Test callback with POST method (should only accept GET)
+ req := httptest.NewRequest("POST", "/oauth/callback", nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.callbackHandler()
+ handler(w, req)
+
+ if w.Code != http.StatusMethodNotAllowed {
+ t.Errorf("Expected status 405 for POST method, got %d", w.Code)
+ }
+}
+
+func TestValidateRedirectURI(t *testing.T) {
+ cfg := createTestConfig()
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ tests := []struct {
+ name string
+ redirectURI string
+ wantErr bool
+ }{
+ {
+ name: "valid redirect URI - 127.0.0.1",
+ redirectURI: "http://127.0.0.1:8000/oauth/callback",
+ wantErr: false,
+ },
+ {
+ name: "valid redirect URI - localhost",
+ redirectURI: "http://localhost:8000/oauth/callback",
+ wantErr: false,
+ },
+ {
+ name: "invalid redirect URI - wrong port",
+ redirectURI: "http://127.0.0.1:9000/oauth/callback",
+ wantErr: true,
+ },
+ {
+ name: "invalid redirect URI - wrong path",
+ redirectURI: "http://127.0.0.1:8000/oauth/malicious",
+ wantErr: true,
+ },
+ {
+ name: "invalid redirect URI - external domain",
+ redirectURI: "http://malicious.com:8000/oauth/callback",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := manager.validateRedirectURI(tt.redirectURI)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateRedirectURI() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+
+ // Test with empty redirect URIs configuration
+ cfgEmpty := createTestConfig()
+ cfgEmpty.OAuthConfig.RedirectURIs = []string{}
+ managerEmpty := NewEndpointManager(provider, cfgEmpty)
+
+ err := managerEmpty.validateRedirectURI("http://127.0.0.1:8000/oauth/callback")
+ if err == nil {
+ t.Error("Expected error when no redirect URIs are configured")
+ }
+}
+
+// TestAuthorizationProxyRedirectURIValidation tests the authorization endpoint redirect URI validation
+func TestCORSHeaders(t *testing.T) {
+ cfg := createTestConfig()
+ cfg.OAuthConfig.AllowedOrigins = []string{"http://localhost:6274"}
+
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ tests := []struct {
+ name string
+ origin string
+ expectCORSSet bool
+ expectOrigin string
+ }{
+ {
+ name: "allowed origin",
+ origin: "http://localhost:6274",
+ expectCORSSet: true,
+ expectOrigin: "http://localhost:6274",
+ },
+ {
+ name: "disallowed origin",
+ origin: "http://malicious.com",
+ expectCORSSet: false,
+ expectOrigin: "",
+ },
+ {
+ name: "no origin header",
+ origin: "",
+ expectCORSSet: false,
+ expectOrigin: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest("GET", "/health", nil)
+ if tt.origin != "" {
+ req.Header.Set("Origin", tt.origin)
+ }
+ w := httptest.NewRecorder()
+
+ handler := manager.healthHandler()
+ handler(w, req)
+
+ corsOrigin := w.Header().Get("Access-Control-Allow-Origin")
+ if tt.expectCORSSet {
+ if corsOrigin != tt.expectOrigin {
+ t.Errorf("Expected CORS origin %s, got %s", tt.expectOrigin, corsOrigin)
+ }
+ } else {
+ if corsOrigin != "" {
+ t.Errorf("Expected no CORS headers, but got Access-Control-Allow-Origin: %s", corsOrigin)
+ }
+ }
+ })
+ }
+}
+
+func TestAuthorizationProxyRedirectURIValidation(t *testing.T) {
+ cfg := createTestConfig()
+ provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig)
+ manager := NewEndpointManager(provider, cfg)
+
+ tests := []struct {
+ name string
+ redirectURI string
+ expectError bool
+ expectCode int
+ }{
+ {
+ name: "missing redirect_uri",
+ redirectURI: "",
+ expectError: true,
+ expectCode: http.StatusBadRequest,
+ },
+ {
+ name: "valid redirect_uri - 127.0.0.1",
+ redirectURI: "http://127.0.0.1:8000/oauth/callback",
+ expectError: false,
+ expectCode: http.StatusFound, // Should redirect to Azure AD
+ },
+ {
+ name: "valid redirect_uri - localhost",
+ redirectURI: "http://localhost:8000/oauth/callback",
+ expectError: false,
+ expectCode: http.StatusFound, // Should redirect to Azure AD
+ },
+ {
+ name: "invalid redirect_uri - wrong port",
+ redirectURI: "http://127.0.0.1:9000/oauth/callback",
+ expectError: true,
+ expectCode: http.StatusBadRequest,
+ },
+ {
+ name: "invalid redirect_uri - wrong path",
+ redirectURI: "http://127.0.0.1:8000/oauth/malicious",
+ expectError: true,
+ expectCode: http.StatusBadRequest,
+ },
+ {
+ name: "invalid redirect_uri - external domain",
+ redirectURI: "http://malicious.com:8000/oauth/callback",
+ expectError: true,
+ expectCode: http.StatusBadRequest,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create request URL with redirect_uri parameter if provided
+ requestURL := "/oauth2/v2.0/authorize?response_type=code&client_id=test-client&code_challenge=test&code_challenge_method=S256&state=test"
+ if tt.redirectURI != "" {
+ requestURL += "&redirect_uri=" + tt.redirectURI
+ }
+
+ req := httptest.NewRequest("GET", requestURL, nil)
+ w := httptest.NewRecorder()
+
+ handler := manager.authorizationProxyHandler()
+ handler(w, req)
+
+ if tt.expectError {
+ if w.Code != tt.expectCode {
+ t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code)
+ }
+
+ // Check that error response contains helpful information
+ body := w.Body.String()
+ if !strings.Contains(body, "redirect_uri") {
+ t.Errorf("Error response should mention redirect_uri, got: %s", body)
+ }
+ } else {
+ if w.Code != tt.expectCode {
+ t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code)
+ }
+
+ // For successful cases, check redirect location contains expected parameters
+ location := w.Header().Get("Location")
+ if location == "" {
+ t.Errorf("Expected redirect location header, got empty")
+ }
+ if !strings.Contains(location, "login.microsoftonline.com") {
+ t.Errorf("Expected redirect to Azure AD, got: %s", location)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/auth/oauth/middleware.go b/internal/auth/oauth/middleware.go
new file mode 100644
index 0000000..5170385
--- /dev/null
+++ b/internal/auth/oauth/middleware.go
@@ -0,0 +1,299 @@
+package oauth
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "strings"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+)
+
+// contextKey is a custom type for context keys to avoid collisions
+type contextKey string
+
+const tokenInfoKey contextKey = "token_info"
+
+// AuthMiddleware handles OAuth authentication for HTTP requests
+type AuthMiddleware struct {
+ provider *AzureOAuthProvider
+ serverURL string
+}
+
+// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting
+func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request) {
+ requestOrigin := r.Header.Get("Origin")
+
+ // Check if the request origin is in the allowed list
+ var allowedOrigin string
+ for _, allowed := range m.provider.config.AllowedOrigins {
+ if requestOrigin == allowed {
+ allowedOrigin = requestOrigin
+ break
+ }
+ }
+
+ // Only set CORS headers if origin is allowed
+ if allowedOrigin != "" {
+ w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version")
+ w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
+ w.Header().Set("Access-Control-Allow-Credentials", "false")
+ } else if requestOrigin != "" {
+ log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin)
+ }
+}
+
+// NewAuthMiddleware creates a new authentication middleware
+func NewAuthMiddleware(provider *AzureOAuthProvider, serverURL string) *AuthMiddleware {
+ return &AuthMiddleware{
+ provider: provider,
+ serverURL: serverURL,
+ }
+}
+
+// Middleware returns an HTTP middleware function for OAuth authentication
+func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
+ // Skip authentication for specific endpoints
+ if m.shouldSkipAuth(r) {
+ log.Printf("Skipping auth for path: %s\n", r.URL.Path)
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // Perform authentication
+ authResult := m.authenticateRequest(r)
+
+ if !authResult.Authenticated {
+ log.Printf("Authentication FAILED - handling error\n")
+ m.handleAuthError(w, r, authResult)
+ return
+ }
+
+ // Add token info to request context
+ ctx := context.WithValue(r.Context(), tokenInfoKey, authResult.TokenInfo)
+ r = r.WithContext(ctx)
+
+ next.ServeHTTP(w, r)
+ })
+}
+
+// shouldSkipAuth determines if authentication should be skipped for this request
+func (m *AuthMiddleware) shouldSkipAuth(r *http.Request) bool {
+ // Skip auth for OAuth metadata endpoints
+ path := r.URL.Path
+
+ skipPaths := []string{
+ "/.well-known/oauth-protected-resource",
+ "/.well-known/oauth-authorization-server",
+ "/.well-known/openid-configuration",
+ "/oauth2/v2.0/authorize",
+ "/oauth/register",
+ "/oauth/callback",
+ "/oauth2/v2.0/token",
+ "/oauth/introspect",
+ "/health",
+ "/ping",
+ }
+
+ for _, skipPath := range skipPaths {
+ if path == skipPath {
+ return true
+ }
+ }
+
+ return false
+}
+
+// authenticateRequest performs OAuth authentication on the request
+func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
+ // Extract Bearer token from Authorization header
+ authHeader := r.Header.Get("Authorization")
+
+ if authHeader == "" {
+ log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path)
+ log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header)
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: "missing authorization header",
+ StatusCode: http.StatusUnauthorized,
+ }
+ }
+
+ // Check for Bearer token format
+ const bearerPrefix = "Bearer "
+ if !strings.HasPrefix(authHeader, bearerPrefix) {
+ log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n")
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: "invalid authorization header format",
+ StatusCode: http.StatusUnauthorized,
+ }
+ }
+
+ token := strings.TrimPrefix(authHeader, bearerPrefix)
+ if token == "" {
+ log.Printf("FAILED - Empty bearer token\n")
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: "empty bearer token",
+ StatusCode: http.StatusUnauthorized,
+ }
+ }
+
+ // Basic JWT structure validation
+ tokenParts := strings.Split(token, ".")
+ if len(tokenParts) != 3 {
+ log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts))
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: "invalid JWT structure",
+ StatusCode: http.StatusUnauthorized,
+ }
+ }
+
+ // Validate the token
+ tokenInfo, err := m.provider.ValidateToken(r.Context(), token)
+ if err != nil {
+ log.Printf("FAILED - Provider token validation failed: %v\n", err)
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: fmt.Sprintf("token validation failed: %v", err),
+ StatusCode: http.StatusUnauthorized,
+ }
+ }
+
+ // Validate required scopes - strict enforcement for security
+ if !m.validateScopes(tokenInfo.Scope) {
+ log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes)
+ return &auth.AuthResult{
+ Authenticated: false,
+ Error: "insufficient scope",
+ StatusCode: http.StatusForbidden,
+ }
+ }
+
+ return &auth.AuthResult{
+ Authenticated: true,
+ TokenInfo: tokenInfo,
+ StatusCode: http.StatusOK,
+ }
+}
+
+// validateScopes checks if the token has required scopes
+func (m *AuthMiddleware) validateScopes(tokenScopes []string) bool {
+ requiredScopes := m.provider.config.RequiredScopes
+ if len(requiredScopes) == 0 {
+ return true // No scopes required
+ }
+
+ // Check if token has at least one required scope
+ for _, required := range requiredScopes {
+ if m.hasScopePermission(required, tokenScopes) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// hasScopePermission checks if the token scopes satisfy the required scope
+func (m *AuthMiddleware) hasScopePermission(requiredScope string, tokenScopes []string) bool {
+ // Direct scope match
+ for _, tokenScope := range tokenScopes {
+ if tokenScope == requiredScope {
+ return true
+ }
+ }
+
+ // Azure resource scope mapping
+ azureResourceMappings := map[string][]string{
+ "https://management.azure.com/.default": {
+ "user_impersonation",
+ "https://management.azure.com/user_impersonation",
+ "https://management.azure.com/.default",
+ "https://management.core.windows.net/",
+ "https://management.azure.com/",
+ },
+ "https://graph.microsoft.com/.default": {
+ "User.Read",
+ "https://graph.microsoft.com/User.Read",
+ },
+ }
+
+ if allowedScopes, exists := azureResourceMappings[requiredScope]; exists {
+ for _, allowedScope := range allowedScopes {
+ for _, tokenScope := range tokenScopes {
+ if tokenScope == allowedScope {
+ return true
+ }
+ }
+ }
+ }
+
+ return false
+}
+
+// handleAuthError handles authentication errors
+func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, authResult *auth.AuthResult) {
+ // Set CORS headers
+ m.setCORSHeaders(w, r)
+ w.Header().Set("Content-Type", "application/json")
+
+ // Add WWW-Authenticate header for 401 responses (RFC 9728 Section 5.1)
+ if authResult.StatusCode == http.StatusUnauthorized {
+ // Build the resource metadata URL
+ scheme := "http"
+ if r.TLS != nil {
+ scheme = "https"
+ }
+ host := r.Host
+ if host == "" {
+ host = r.URL.Host
+ }
+ serverURL := fmt.Sprintf("%s://%s", scheme, host)
+ resourceMetadataURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", serverURL)
+
+ // RFC 9728 compliant WWW-Authenticate header
+ wwwAuth := fmt.Sprintf(`Bearer realm="%s", resource_metadata="%s"`, serverURL, resourceMetadataURL)
+
+ // Add error information if available
+ if authResult.Error != "" {
+ wwwAuth += fmt.Sprintf(`, error="invalid_token", error_description="%s"`, authResult.Error)
+ }
+
+ w.Header().Set("WWW-Authenticate", wwwAuth)
+ }
+
+ w.WriteHeader(authResult.StatusCode)
+
+ errorResponse := map[string]interface{}{
+ "error": getOAuthErrorCode(authResult.StatusCode),
+ "error_description": authResult.Error,
+ }
+
+ if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
+ log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err)
+ } else {
+ log.Printf("MIDDLEWARE ERROR: Error response sent\n")
+ }
+}
+
+// getOAuthErrorCode returns appropriate OAuth error code for HTTP status
+func getOAuthErrorCode(statusCode int) string {
+ switch statusCode {
+ case http.StatusUnauthorized:
+ return "invalid_token"
+ case http.StatusForbidden:
+ return "insufficient_scope"
+ case http.StatusBadRequest:
+ return "invalid_request"
+ default:
+ return "server_error"
+ }
+}
diff --git a/internal/auth/oauth/middleware_test.go b/internal/auth/oauth/middleware_test.go
new file mode 100644
index 0000000..358cb1e
--- /dev/null
+++ b/internal/auth/oauth/middleware_test.go
@@ -0,0 +1,253 @@
+package oauth
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+)
+
+// GetTokenInfo extracts token information from request context (test helper)
+func GetTokenInfo(r *http.Request) (*auth.TokenInfo, bool) {
+ tokenInfo, ok := r.Context().Value(tokenInfoKey).(*auth.TokenInfo)
+ return tokenInfo, ok
+}
+
+func TestAuthMiddleware(t *testing.T) {
+ // Create test config with minimal required scopes for testing
+ // Note: We cannot test with empty RequiredScopes because the OAuth configuration
+ // validation now requires at least one scope to be specified. This is intentional
+ // to prevent security misconfigurations in production environments.
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: false,
+ ValidateAudience: false,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+ middleware := NewAuthMiddleware(provider, "http://localhost:8000")
+
+ // Create a test handler
+ testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ if _, err := w.Write([]byte("success")); err != nil {
+ t.Errorf("Failed to write test response: %v", err)
+ }
+ })
+
+ wrappedHandler := middleware.Middleware(testHandler)
+
+ tests := []struct {
+ name string
+ authHeader string
+ expectedStatus int
+ path string
+ }{
+ {
+ name: "valid bearer token",
+ authHeader: "Bearer header.payload.signature",
+ expectedStatus: http.StatusOK,
+ path: "/test",
+ },
+ {
+ name: "missing authorization header",
+ authHeader: "",
+ expectedStatus: http.StatusUnauthorized,
+ path: "/test",
+ },
+ {
+ name: "invalid token format",
+ authHeader: "InvalidFormat",
+ expectedStatus: http.StatusUnauthorized,
+ path: "/test",
+ },
+ {
+ name: "non-bearer token",
+ authHeader: "Basic dXNlcjpwYXNz",
+ expectedStatus: http.StatusUnauthorized,
+ path: "/test",
+ },
+ {
+ name: "skip auth for metadata endpoint",
+ authHeader: "",
+ expectedStatus: http.StatusOK,
+ path: "/.well-known/oauth-protected-resource",
+ },
+ {
+ name: "skip auth for health endpoint",
+ authHeader: "",
+ expectedStatus: http.StatusOK,
+ path: "/health",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest("GET", tt.path, nil)
+ if tt.authHeader != "" {
+ req.Header.Set("Authorization", tt.authHeader)
+ }
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ if w.Code != tt.expectedStatus {
+ t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
+ }
+
+ // Check WWW-Authenticate header for 401 responses
+ if w.Code == http.StatusUnauthorized {
+ wwwAuth := w.Header().Get("WWW-Authenticate")
+ if wwwAuth == "" {
+ t.Error("Expected WWW-Authenticate header for 401 response")
+ }
+ }
+ })
+ }
+}
+
+func TestAuthMiddlewareContextPropagation(t *testing.T) {
+ // Note: We cannot test with empty RequiredScopes because the OAuth configuration
+ // validation now requires at least one scope to be specified.
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: false,
+ ValidateAudience: false,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+ middleware := NewAuthMiddleware(provider, "http://localhost:8000")
+
+ testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Check if token info is available in context
+ tokenInfo, ok := GetTokenInfo(r)
+ if !ok {
+ t.Error("Token info not found in context")
+ return
+ }
+
+ if tokenInfo.AccessToken != "header.payload.signature" {
+ t.Errorf("Expected token header.payload.signature, got %s", tokenInfo.AccessToken)
+ }
+
+ w.WriteHeader(http.StatusOK)
+ })
+
+ wrappedHandler := middleware.Middleware(testHandler)
+
+ req := httptest.NewRequest("GET", "/test", nil)
+ req.Header.Set("Authorization", "Bearer header.payload.signature")
+
+ w := httptest.NewRecorder()
+ wrappedHandler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", w.Code)
+ }
+}
+
+func TestShouldSkipAuth(t *testing.T) {
+ // Note: We cannot test with empty RequiredScopes because the OAuth configuration
+ // validation now requires at least one scope to be specified.
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"}, // Minimal scope for testing
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: false,
+ ValidateAudience: false,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+ middleware := NewAuthMiddleware(provider, "http://localhost:8000")
+
+ tests := []struct {
+ path string
+ expected bool
+ }{
+ {"/.well-known/oauth-protected-resource", true},
+ {"/.well-known/oauth-authorization-server", true},
+ {"/.well-known/openid-configuration", true},
+ {"/oauth2/v2.0/authorize", true},
+ {"/oauth/register", true},
+ {"/oauth/callback", true},
+ {"/oauth2/v2.0/token", true},
+ {"/oauth/introspect", true},
+ {"/health", true},
+ {"/ping", true},
+ {"/test", false},
+ {"/mcp", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.path, func(t *testing.T) {
+ req := httptest.NewRequest("GET", tt.path, nil)
+ result := middleware.shouldSkipAuth(req)
+ if result != tt.expected {
+ t.Errorf("Expected %v for path %s, got %v", tt.expected, tt.path, result)
+ }
+ })
+ }
+}
+
+func TestGetTokenInfo(t *testing.T) {
+ // Test with valid token info
+ tokenInfo := &auth.TokenInfo{
+ AccessToken: "test-token",
+ TokenType: "Bearer",
+ Subject: "user123",
+ }
+
+ ctx := context.WithValue(context.Background(), tokenInfoKey, tokenInfo)
+ req := httptest.NewRequest("GET", "/test", nil)
+ req = req.WithContext(ctx)
+
+ retrievedTokenInfo, ok := GetTokenInfo(req)
+ if !ok {
+ t.Error("Expected to find token info in context")
+ }
+
+ if retrievedTokenInfo.AccessToken != "test-token" {
+ t.Errorf("Expected access token test-token, got %s", retrievedTokenInfo.AccessToken)
+ }
+
+ // Test without token info
+ req = httptest.NewRequest("GET", "/test", nil)
+ _, ok = GetTokenInfo(req)
+ if ok {
+ t.Error("Expected not to find token info in empty context")
+ }
+}
diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go
new file mode 100644
index 0000000..d6ec9bc
--- /dev/null
+++ b/internal/auth/oauth/provider.go
@@ -0,0 +1,523 @@
+package oauth
+
+import (
+ "context"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "math/big"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+ internalConfig "github.com/Azure/aks-mcp/internal/config"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+// AzureOAuthProvider implements OAuth authentication for Azure AD
+type AzureOAuthProvider struct {
+ config *auth.OAuthConfig
+ httpClient *http.Client
+ keyCache *keyCache
+ enableCache bool
+}
+
+// keyCache caches Azure AD signing keys
+type keyCache struct {
+ keys map[string]*rsa.PublicKey
+ expiresAt time.Time
+ mu sync.RWMutex
+}
+
+// AzureADMetadata represents Azure AD OAuth metadata
+type AzureADMetadata struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
+ JWKSUri string `json:"jwks_uri"`
+ ScopesSupported []string `json:"scopes_supported"`
+ ResponseTypesSupported []string `json:"response_types_supported"`
+ GrantTypesSupported []string `json:"grant_types_supported"`
+ SubjectTypesSupported []string `json:"subject_types_supported"`
+ TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
+ CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
+}
+
+// ProtectedResourceMetadata represents MCP protected resource metadata (RFC 9728 compliant)
+type ProtectedResourceMetadata struct {
+ AuthorizationServers []string `json:"authorization_servers"`
+ Resource string `json:"resource"`
+ ScopesSupported []string `json:"scopes_supported"`
+}
+
+// ClientRegistrationRequest represents OAuth 2.0 Dynamic Client Registration request (RFC 7591)
+type ClientRegistrationRequest struct {
+ RedirectURIs []string `json:"redirect_uris"`
+ TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
+ GrantTypes []string `json:"grant_types"`
+ ResponseTypes []string `json:"response_types"`
+ ClientName string `json:"client_name"`
+ ClientURI string `json:"client_uri"`
+ Scope string `json:"scope"`
+}
+
+// NewAzureOAuthProvider creates a new Azure OAuth provider
+func NewAzureOAuthProvider(config *auth.OAuthConfig) (*AzureOAuthProvider, error) {
+ if err := config.Validate(); err != nil {
+ return nil, fmt.Errorf("invalid OAuth config: %w", err)
+ }
+
+ return &AzureOAuthProvider{
+ config: config,
+ enableCache: internalConfig.EnableCache, // Use config constant for cache control
+ httpClient: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ keyCache: &keyCache{
+ keys: make(map[string]*rsa.PublicKey),
+ },
+ }, nil
+}
+
+// GetProtectedResourceMetadata returns OAuth 2.0 Protected Resource Metadata (RFC 9728)
+func (p *AzureOAuthProvider) GetProtectedResourceMetadata(serverURL string) (*ProtectedResourceMetadata, error) {
+ // For MCP compliance, point to our local authorization server proxy
+ // which properly advertises PKCE support
+ parsedURL, err := url.Parse(serverURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid server URL: %v", err)
+ }
+
+ // Use the same scheme and host as the server URL
+ authServerURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
+
+ // RFC 9728 requires the resource field to identify this MCP server
+ return &ProtectedResourceMetadata{
+ AuthorizationServers: []string{authServerURL},
+ Resource: serverURL, // Required by MCP spec
+ ScopesSupported: p.config.RequiredScopes,
+ }, nil
+}
+
+// GetAuthorizationServerMetadata returns OAuth 2.0 Authorization Server Metadata (RFC 8414)
+func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*AzureADMetadata, error) {
+ metadataURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", p.config.TenantID)
+ log.Printf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL)
+
+ resp, err := p.httpClient.Get(metadataURL)
+ if err != nil {
+ log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err)
+ return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err)
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("Failed to close response body: %v", err)
+ }
+ }()
+
+ if resp.StatusCode == http.StatusNotFound {
+ log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID)
+ return nil, fmt.Errorf("tenant ID '%s' not found (HTTP 404). Please verify your Azure AD tenant ID is correct", p.config.TenantID)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ log.Printf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body))
+ return nil, fmt.Errorf("metadata endpoint returned status %d: %s", resp.StatusCode, string(body))
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Printf("OAuth ERROR: Failed to read metadata response: %v", err)
+ return nil, fmt.Errorf("failed to read metadata response: %w", err)
+ }
+
+ var metadata AzureADMetadata
+ if err := json.Unmarshal(body, &metadata); err != nil {
+ log.Printf("OAuth ERROR: Failed to parse metadata JSON: %v", err)
+ return nil, fmt.Errorf("failed to parse metadata: %w", err)
+ }
+
+ log.Printf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported)
+
+ // Ensure grant_types_supported is populated for MCP Inspector compatibility
+ if len(metadata.GrantTypesSupported) == 0 {
+ log.Printf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)")
+ metadata.GrantTypesSupported = []string{"authorization_code", "refresh_token"}
+ }
+
+ // Ensure response_types_supported is populated for MCP Inspector compatibility
+ if len(metadata.ResponseTypesSupported) == 0 {
+ log.Printf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)")
+ metadata.ResponseTypesSupported = []string{"code"}
+ }
+
+ // Ensure subject_types_supported is populated for MCP Inspector compatibility
+ if len(metadata.SubjectTypesSupported) == 0 {
+ log.Printf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)")
+ metadata.SubjectTypesSupported = []string{"public"}
+ }
+
+ // Ensure token_endpoint_auth_methods_supported is populated for MCP Inspector compatibility
+ if len(metadata.TokenEndpointAuthMethodsSupported) == 0 {
+ log.Printf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)")
+ metadata.TokenEndpointAuthMethodsSupported = []string{"none"}
+ }
+
+ // Add S256 code challenge method support (Azure AD supports this but may not advertise it)
+ // MCP specification requires S256 support, so we always ensure it's present
+ log.Printf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)")
+ metadata.CodeChallengeMethodsSupported = []string{"S256"}
+
+ // Azure AD v2.0 has limited support for RFC 8707 Resource Indicators
+ // - Authorization endpoint: doesn't support resource parameter
+ // - Token endpoint: doesn't support resource parameter
+ // - Uses scope-based resource identification instead
+ // Our proxy handles MCP resource parameter translation
+ parsedURL, err := url.Parse(serverURL)
+ if err == nil {
+ // If the server URL includes /mcp path, include it in the proxy endpoint
+ proxyPath := "/oauth2/v2.0/authorize"
+ tokenPath := "/oauth2/v2.0/token" // #nosec G101 -- This is an OAuth endpoint path, not credentials
+ registrationPath := "/oauth/register"
+ proxyAuthURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, proxyPath)
+ tokenURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, tokenPath)
+ registrationURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, registrationPath)
+
+ metadata.AuthorizationEndpoint = proxyAuthURL
+ metadata.TokenEndpoint = tokenURL
+ // Add dynamic client registration endpoint
+ metadata.RegistrationEndpoint = registrationURL
+ }
+
+ log.Printf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v",
+ metadata.GrantTypesSupported, metadata.ResponseTypesSupported, metadata.CodeChallengeMethodsSupported)
+
+ return &metadata, nil
+}
+
+// ValidateToken validates an OAuth access token
+func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString string) (*auth.TokenInfo, error) {
+ // JWTs have three parts (header.payload.signature) separated by two dots.
+ const jwtExpectedDotCount = 2
+
+ dotCount := strings.Count(tokenString, ".")
+ if dotCount != jwtExpectedDotCount {
+ return nil, fmt.Errorf("invalid JWT token format: expected 3 parts separated by dots, got %d dots", dotCount)
+ }
+
+ // SECURITY WARNING: JWT validation bypass - for development and testing ONLY
+ // ValidateJWT should ALWAYS be true in production environments
+ // This bypass creates a significant security vulnerability if enabled in production
+ if !p.config.TokenValidation.ValidateJWT {
+ log.Printf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing")
+ return &auth.TokenInfo{
+ AccessToken: tokenString,
+ TokenType: "Bearer",
+ ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration
+ Scope: p.config.RequiredScopes, // Use configured scopes
+ Subject: "unknown", // Cannot extract without parsing
+ Audience: []string{p.config.TokenValidation.ExpectedAudience},
+ Issuer: fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID),
+ Claims: make(map[string]interface{}),
+ }, nil
+ }
+
+ // Parse and validate JWT token
+
+ // Parse token structure and check expiration
+ parserUnsafe := jwt.NewParser(jwt.WithoutClaimsValidation())
+ tokenUnsafe, _, err := parserUnsafe.ParseUnverified(tokenString, jwt.MapClaims{})
+ if err != nil {
+ return nil, fmt.Errorf("invalid token structure: %w", err)
+ }
+
+ // Check claims and expiration
+ if claims, ok := tokenUnsafe.Claims.(jwt.MapClaims); ok {
+ if exp, ok := claims["exp"].(float64); ok {
+ expTime := time.Unix(int64(exp), 0)
+ if time.Now().After(expTime) {
+ return nil, fmt.Errorf("token expired at %v", expTime)
+ }
+ }
+ }
+
+ // JWT signature validation
+ parser := jwt.NewParser(jwt.WithoutClaimsValidation())
+ token, err := parser.ParseWithClaims(tokenString, jwt.MapClaims{}, p.getKeyFunc)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse token: %w", err)
+ }
+
+ if !token.Valid {
+ return nil, fmt.Errorf("invalid token")
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return nil, fmt.Errorf("invalid token claims")
+ }
+
+ // Validate issuer - with Azure Management API scope, we should get v2.0 format
+ issuer, ok := claims["iss"].(string)
+ if !ok {
+ return nil, fmt.Errorf("missing issuer claim")
+ }
+
+ expectedIssuerV2 := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID)
+ expectedIssuerV1 := fmt.Sprintf("https://sts.windows.net/%s/", p.config.TenantID)
+
+ if issuer != expectedIssuerV2 && issuer != expectedIssuerV1 {
+ return nil, fmt.Errorf("invalid issuer: expected %s (preferred) or %s (fallback), got %s", expectedIssuerV2, expectedIssuerV1, issuer)
+ }
+
+ // Azure AD may return v1.0 or v2.0 issuer format depending on token scope
+
+ // Validate audience and resource binding
+ if p.config.TokenValidation.ValidateAudience {
+ if err := p.validateAudience(claims); err != nil {
+ return nil, err
+ }
+ }
+
+ // Extract token information
+ tokenInfo := &auth.TokenInfo{
+ AccessToken: tokenString,
+ TokenType: "Bearer",
+ Claims: claims,
+ }
+
+ // Extract subject
+ if sub, ok := claims["sub"].(string); ok {
+ tokenInfo.Subject = sub
+ }
+
+ // Extract audience
+ if aud, ok := claims["aud"].(string); ok {
+ tokenInfo.Audience = []string{aud}
+ } else if audSlice, ok := claims["aud"].([]interface{}); ok {
+ for _, a := range audSlice {
+ if audStr, ok := a.(string); ok {
+ tokenInfo.Audience = append(tokenInfo.Audience, audStr)
+ }
+ }
+ }
+
+ // Extract scope from Azure AD token
+ // Check for 'scp' claim (Azure AD v2.0)
+ if scp, ok := claims["scp"].(string); ok {
+ tokenInfo.Scope = strings.Split(scp, " ")
+ } else if scope, ok := claims["scope"].(string); ok {
+ // Check for 'scope' claim (alternative)
+ tokenInfo.Scope = strings.Split(scope, " ")
+ }
+
+ // Check for 'roles' claim (Azure AD app roles)
+ if roles, ok := claims["roles"].([]interface{}); ok {
+ for _, role := range roles {
+ if roleStr, ok := role.(string); ok {
+ tokenInfo.Scope = append(tokenInfo.Scope, roleStr)
+ }
+ }
+ }
+
+ // Extract expiration
+ if exp, ok := claims["exp"].(float64); ok {
+ tokenInfo.ExpiresAt = time.Unix(int64(exp), 0)
+ }
+
+ // Set issuer
+ tokenInfo.Issuer = issuer
+
+ return tokenInfo, nil
+}
+
+// validateAudience validates the audience claim and resource binding (RFC 8707)
+func (p *AzureOAuthProvider) validateAudience(claims jwt.MapClaims) error {
+ expectedAudience := p.config.TokenValidation.ExpectedAudience
+
+ // Normalize expected audience - remove trailing slash for comparison
+ normalizedExpected := strings.TrimSuffix(expectedAudience, "/")
+
+ // Check single audience
+ if aud, ok := claims["aud"].(string); ok {
+ normalizedAud := strings.TrimSuffix(aud, "/")
+ if normalizedAud == normalizedExpected || aud == p.config.ClientID {
+ return nil
+ }
+ return fmt.Errorf("invalid audience: expected %s or %s, got %s", expectedAudience, p.config.ClientID, aud)
+ }
+
+ // Check audience array
+ if audSlice, ok := claims["aud"].([]interface{}); ok {
+ for _, a := range audSlice {
+ if audStr, ok := a.(string); ok {
+ normalizedAud := strings.TrimSuffix(audStr, "/")
+ if normalizedAud == normalizedExpected || audStr == p.config.ClientID {
+ return nil
+ }
+ }
+ }
+ return fmt.Errorf("invalid audience: expected %s or %s in audience list", expectedAudience, p.config.ClientID)
+ }
+
+ return fmt.Errorf("missing audience claim")
+}
+
+// getKeyFunc returns a function to retrieve JWT signing keys
+func (p *AzureOAuthProvider) getKeyFunc(token *jwt.Token) (interface{}, error) {
+ // Validate signing method
+ if token.Method.Alg() != "RS256" {
+ return nil, fmt.Errorf("unexpected signing method: expected RS256, got %v", token.Method.Alg())
+ }
+
+ // Also verify it's an RSA method
+ if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
+ return nil, fmt.Errorf("signing method is not RSA: %T", token.Method)
+ }
+
+ // Get key ID from token header
+ kid, ok := token.Header["kid"].(string)
+ if !ok {
+ return nil, fmt.Errorf("missing key ID in token header")
+ }
+
+ // Extract issuer from token to determine the correct JWKS endpoint
+ var issuer string
+ if claims, ok := token.Claims.(jwt.MapClaims); ok {
+ if iss, ok := claims["iss"].(string); ok {
+ issuer = iss
+ }
+ }
+
+ // Get the public key for this key ID using the appropriate issuer
+ key, err := p.getPublicKey(kid, issuer)
+ if err != nil {
+ log.Printf("PUBLIC KEY RETRIEVAL FAILED: %s\n", err)
+ return nil, fmt.Errorf("failed to get public key: %w", err)
+ }
+
+ return key, nil
+}
+
+// getPublicKey retrieves and caches Azure AD public keys
+func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.PublicKey, error) {
+ // Generate cache key based on both kid and issuer to avoid conflicts between v1.0 and v2.0 keys
+ cacheKey := fmt.Sprintf("%s_%s", kid, issuer)
+
+ // Check cache first if caching is enabled
+ if p.enableCache {
+ p.keyCache.mu.RLock()
+ if key, exists := p.keyCache.keys[cacheKey]; exists && time.Now().Before(p.keyCache.expiresAt) {
+ p.keyCache.mu.RUnlock()
+ return key, nil
+ }
+ p.keyCache.mu.RUnlock()
+ }
+
+ // With Azure Management API scope, we should always get v2.0 format tokens
+ // Force using v2.0 JWKS endpoint for consistency
+ jwksURL := fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", p.config.TenantID)
+
+ resp, err := p.httpClient.Get(jwksURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch JWKS from %s: %w", jwksURL, err)
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ log.Printf("Failed to close response body: %v", err)
+ }
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read JWKS response: %w", err)
+ }
+
+ var jwks struct {
+ Keys []struct {
+ Kid string `json:"kid"`
+ N string `json:"n"`
+ E string `json:"e"`
+ Kty string `json:"kty"`
+ } `json:"keys"`
+ }
+
+ if err := json.Unmarshal(body, &jwks); err != nil {
+ return nil, fmt.Errorf("failed to parse JWKS: %w", err)
+ }
+
+ log.Printf("JWKS Contains %d keys, searching for kid=%s\n", len(jwks.Keys), kid)
+
+ // Parse keys and find the target key
+ var targetKey *rsa.PublicKey
+ var foundKeyIds []string
+
+ for _, key := range jwks.Keys {
+ foundKeyIds = append(foundKeyIds, key.Kid)
+
+ if key.Kty == "RSA" && key.Kid == kid {
+ pubKey, err := parseRSAPublicKey(key.N, key.E)
+ if err != nil {
+ log.Printf("JWKS Failed to parse RSA key %s: %v\n", key.Kid, err)
+ continue
+ }
+ targetKey = pubKey
+ break
+ }
+ }
+
+ // Cache the retrieved key and return it (only if caching is enabled)
+ if targetKey != nil {
+ if p.enableCache {
+ p.keyCache.mu.Lock()
+ if p.keyCache.keys == nil {
+ p.keyCache.keys = make(map[string]*rsa.PublicKey)
+ }
+ p.keyCache.keys[cacheKey] = targetKey
+ p.keyCache.expiresAt = time.Now().Add(24 * time.Hour) // Cache for 24 hours
+ p.keyCache.mu.Unlock()
+ }
+ return targetKey, nil
+ }
+
+ return nil, fmt.Errorf("key with ID %s not found in JWKS (available: %v)", kid, foundKeyIds)
+}
+
+// parseRSAPublicKey parses RSA public key from JWK format
+func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
+ // Decode base64url-encoded modulus
+ nBytes, err := base64.RawURLEncoding.DecodeString(nStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode modulus: %w", err)
+ }
+
+ // Decode base64url-encoded exponent
+ eBytes, err := base64.RawURLEncoding.DecodeString(eStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode exponent: %w", err)
+ }
+
+ // Convert bytes to big integers
+ n := new(big.Int).SetBytes(nBytes)
+ e := new(big.Int).SetBytes(eBytes)
+
+ // Create RSA public key
+ pubKey := &rsa.PublicKey{
+ N: n,
+ E: int(e.Int64()),
+ }
+
+ return pubKey, nil
+}
diff --git a/internal/auth/oauth/provider_test.go b/internal/auth/oauth/provider_test.go
new file mode 100644
index 0000000..657f2ba
--- /dev/null
+++ b/internal/auth/oauth/provider_test.go
@@ -0,0 +1,388 @@
+package oauth
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Azure/aks-mcp/internal/auth"
+)
+
+func TestNewAzureOAuthProvider(t *testing.T) {
+ tests := []struct {
+ name string
+ config *auth.OAuthConfig
+ wantErr bool
+ }{
+ {
+ name: "valid config should create provider",
+ config: &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "invalid config should fail",
+ config: &auth.OAuthConfig{
+ Enabled: true,
+ // Missing required fields
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider, err := NewAzureOAuthProvider(tt.config)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("NewAzureOAuthProvider() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !tt.wantErr && provider == nil {
+ t.Error("NewAzureOAuthProvider() returned nil provider")
+ }
+ })
+ }
+}
+
+func TestGetProtectedResourceMetadata(t *testing.T) {
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant-id",
+ ClientID: "test-client-id",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ serverURL := "http://localhost:8000"
+ metadata, err := provider.GetProtectedResourceMetadata(serverURL)
+ if err != nil {
+ t.Fatalf("GetProtectedResourceMetadata() error = %v", err)
+ }
+
+ expectedAuthServer := "http://localhost:8000"
+ if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer {
+ t.Errorf("Expected authorization server %s, got %v", expectedAuthServer, metadata.AuthorizationServers)
+ }
+
+ // Note: AzureADProtectedResourceMetadata doesn't include a Resource field.
+ // The resource URL is implied by the context of the request endpoint.
+
+ if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" {
+ t.Errorf("Expected scopes %v, got %v", config.RequiredScopes, metadata.ScopesSupported)
+ }
+}
+
+func TestGetAuthorizationServerMetadataWithDefaults(t *testing.T) {
+ // Create a mock Azure AD metadata endpoint that's missing some fields
+ // This simulates the case where Azure AD doesn't provide all required fields
+ mockMetadata := AzureADMetadata{
+ Issuer: "https://login.microsoftonline.com/test-tenant/v2.0",
+ AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize",
+ TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token",
+ JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys",
+ ScopesSupported: []string{"openid", "profile", "email"},
+ // Intentionally omit GrantTypesSupported, ResponseTypesSupported, etc.
+ // to test our default value logic
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(mockMetadata); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ }
+ }))
+ defer server.Close()
+
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ // Override the HTTP client to use our test server
+ provider.httpClient = &http.Client{
+ Transport: &roundTripperFunc{
+ fn: func(req *http.Request) (*http.Response, error) {
+ // Redirect all requests to our test server
+ req.URL.Scheme = "http"
+ req.URL.Host = server.URL[7:] // Remove "http://"
+ req.URL.Path = "/"
+ return http.DefaultTransport.RoundTrip(req)
+ },
+ },
+ }
+
+ metadata, err := provider.GetAuthorizationServerMetadata(server.URL)
+ if err != nil {
+ t.Fatalf("GetAuthorizationServerMetadata() error = %v", err)
+ }
+
+ // Verify that default values were populated for missing fields
+ expectedGrantTypes := []string{"authorization_code", "refresh_token"}
+ if len(metadata.GrantTypesSupported) != len(expectedGrantTypes) {
+ t.Errorf("Expected %d grant types, got %d", len(expectedGrantTypes), len(metadata.GrantTypesSupported))
+ }
+ for i, expected := range expectedGrantTypes {
+ if i >= len(metadata.GrantTypesSupported) || metadata.GrantTypesSupported[i] != expected {
+ t.Errorf("Expected grant type %s at index %d, got %v", expected, i, metadata.GrantTypesSupported)
+ }
+ }
+
+ expectedResponseTypes := []string{"code"}
+ if len(metadata.ResponseTypesSupported) != len(expectedResponseTypes) {
+ t.Errorf("Expected %d response types, got %d", len(expectedResponseTypes), len(metadata.ResponseTypesSupported))
+ }
+ if len(metadata.ResponseTypesSupported) > 0 && metadata.ResponseTypesSupported[0] != "code" {
+ t.Errorf("Expected response type 'code', got %s", metadata.ResponseTypesSupported[0])
+ }
+
+ expectedSubjectTypes := []string{"public"}
+ if len(metadata.SubjectTypesSupported) != len(expectedSubjectTypes) {
+ t.Errorf("Expected %d subject types, got %d", len(expectedSubjectTypes), len(metadata.SubjectTypesSupported))
+ }
+ if len(metadata.SubjectTypesSupported) > 0 && metadata.SubjectTypesSupported[0] != "public" {
+ t.Errorf("Expected subject type 'public', got %s", metadata.SubjectTypesSupported[0])
+ }
+
+ expectedTokenEndpointAuthMethods := []string{"none"}
+ if len(metadata.TokenEndpointAuthMethodsSupported) != len(expectedTokenEndpointAuthMethods) {
+ t.Errorf("Expected %d auth methods, got %d", len(expectedTokenEndpointAuthMethods), len(metadata.TokenEndpointAuthMethodsSupported))
+ }
+ if len(metadata.TokenEndpointAuthMethodsSupported) > 0 && metadata.TokenEndpointAuthMethodsSupported[0] != "none" {
+ t.Errorf("Expected auth method 'none', got %s", metadata.TokenEndpointAuthMethodsSupported[0])
+ }
+
+ // Verify that PKCE is properly configured
+ expectedCodeChallengeMethods := []string{"S256"}
+ if len(metadata.CodeChallengeMethodsSupported) != len(expectedCodeChallengeMethods) {
+ t.Errorf("Expected %d code challenge methods, got %d", len(expectedCodeChallengeMethods), len(metadata.CodeChallengeMethodsSupported))
+ }
+ if len(metadata.CodeChallengeMethodsSupported) > 0 && metadata.CodeChallengeMethodsSupported[0] != "S256" {
+ t.Errorf("Expected code challenge method 'S256', got %s", metadata.CodeChallengeMethodsSupported[0])
+ }
+}
+
+func TestGetAuthorizationServerMetadata(t *testing.T) {
+ // Create a mock Azure AD metadata endpoint
+ mockMetadata := AzureADMetadata{
+ Issuer: "https://login.microsoftonline.com/test-tenant/v2.0",
+ AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize",
+ TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token",
+ JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys",
+ ScopesSupported: []string{"openid", "profile", "email"},
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(mockMetadata); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ }
+ }))
+ defer server.Close()
+
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ // Override the HTTP client to use our test server
+ provider.httpClient = &http.Client{
+ Transport: &roundTripperFunc{
+ fn: func(req *http.Request) (*http.Response, error) {
+ // Redirect all requests to our test server
+ req.URL.Scheme = "http"
+ req.URL.Host = server.URL[7:] // Remove "http://"
+ req.URL.Path = "/"
+ return http.DefaultTransport.RoundTrip(req)
+ },
+ },
+ }
+
+ metadata, err := provider.GetAuthorizationServerMetadata(server.URL)
+ if err != nil {
+ t.Fatalf("GetAuthorizationServerMetadata() error = %v", err)
+ }
+
+ if metadata.Issuer != mockMetadata.Issuer {
+ t.Errorf("Expected issuer %s, got %s", mockMetadata.Issuer, metadata.Issuer)
+ }
+
+ expectedAuthEndpoint := fmt.Sprintf("%s/oauth2/v2.0/authorize", server.URL)
+ if metadata.AuthorizationEndpoint != expectedAuthEndpoint {
+ t.Errorf("Expected auth endpoint %s, got %s", expectedAuthEndpoint, metadata.AuthorizationEndpoint)
+ }
+}
+
+func TestValidateTokenWithoutJWT(t *testing.T) {
+ // SECURITY WARNING: This test verifies the JWT validation bypass functionality
+ // ValidateJWT=false should ONLY be used in development/testing environments
+ // This functionality should NEVER be enabled in production
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: false, // Disable JWT validation
+ ValidateAudience: false,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ ctx := context.Background()
+ // Use a token that looks like a JWT to pass initial format checks
+ testToken := "header.payload.signature"
+ tokenInfo, err := provider.ValidateToken(ctx, testToken)
+ if err != nil {
+ t.Fatalf("ValidateToken() error = %v", err)
+ }
+
+ if tokenInfo.AccessToken != testToken {
+ t.Errorf("Expected access token %s, got %s", testToken, tokenInfo.AccessToken)
+ }
+
+ if tokenInfo.TokenType != "Bearer" {
+ t.Errorf("Expected token type Bearer, got %s", tokenInfo.TokenType)
+ }
+}
+
+func TestValidateAudience(t *testing.T) {
+ config := &auth.OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant",
+ ClientID: "test-client-id",
+ RequiredScopes: []string{"https://management.azure.com/.default"},
+ TokenValidation: auth.TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: 5 * time.Minute,
+ ClockSkew: 1 * time.Minute,
+ },
+ }
+
+ provider, err := NewAzureOAuthProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ claims map[string]interface{}
+ wantErr bool
+ }{
+ {
+ name: "valid audience string",
+ claims: map[string]interface{}{
+ "aud": "https://management.azure.com/",
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid client ID audience",
+ claims: map[string]interface{}{
+ "aud": "test-client-id",
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid audience array",
+ claims: map[string]interface{}{
+ "aud": []interface{}{"https://management.azure.com/", "other-aud"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "invalid audience",
+ claims: map[string]interface{}{
+ "aud": "invalid-audience",
+ },
+ wantErr: true,
+ },
+ {
+ name: "missing audience",
+ claims: map[string]interface{}{
+ "sub": "user123",
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := provider.validateAudience(tt.claims)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateAudience() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+// roundTripperFunc is a helper type for creating custom HTTP transports in tests
+type roundTripperFunc struct {
+ fn func(*http.Request) (*http.Response, error)
+}
+
+func (f *roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return f.fn(req)
+}
diff --git a/internal/auth/types.go b/internal/auth/types.go
new file mode 100644
index 0000000..d89cca6
--- /dev/null
+++ b/internal/auth/types.go
@@ -0,0 +1,140 @@
+package auth
+
+import (
+ "fmt"
+ "time"
+)
+
+// OAuthConfig represents OAuth configuration for AKS-MCP
+type OAuthConfig struct {
+ // Enable OAuth authentication
+ Enabled bool `json:"enabled"`
+
+ // Azure AD tenant ID
+ TenantID string `json:"tenant_id"`
+
+ // Azure AD application (client) ID
+ ClientID string `json:"client_id"`
+
+ // Required OAuth scopes for accessing AKS-MCP
+ RequiredScopes []string `json:"required_scopes"`
+
+ // Allowed redirect URIs for OAuth callback
+ RedirectURIs []string `json:"redirect_uris"`
+
+ // Allowed CORS origins for OAuth endpoints (for security, wildcard "*" should be avoided)
+ AllowedOrigins []string `json:"allowed_origins"`
+
+ // Token validation settings
+ TokenValidation TokenValidationConfig `json:"token_validation"`
+}
+
+// TokenValidationConfig represents token validation configuration
+type TokenValidationConfig struct {
+ // SECURITY CRITICAL: Enable JWT token validation
+ // Setting this to false creates a security vulnerability - for development/testing ONLY
+ // MUST be true in production environments
+ ValidateJWT bool `json:"validate_jwt"`
+
+ // Enable audience validation
+ ValidateAudience bool `json:"validate_audience"`
+
+ // Expected audience for tokens
+ ExpectedAudience string `json:"expected_audience"`
+
+ // Token cache TTL
+ CacheTTL time.Duration `json:"cache_ttl"`
+
+ // Clock skew tolerance for token validation
+ ClockSkew time.Duration `json:"clock_skew"`
+}
+
+// TokenInfo represents validated token information
+type TokenInfo struct {
+ // Access token
+ AccessToken string `json:"access_token"`
+
+ // Token type (usually "Bearer")
+ TokenType string `json:"token_type"`
+
+ // Token expiration time
+ ExpiresAt time.Time `json:"expires_at"`
+
+ // Token scope
+ Scope []string `json:"scope"`
+
+ // Subject (user ID)
+ Subject string `json:"subject"`
+
+ // Audience
+ Audience []string `json:"audience"`
+
+ // Issuer
+ Issuer string `json:"issuer"`
+
+ // Additional claims
+ Claims map[string]interface{} `json:"claims"`
+}
+
+// AuthResult represents the result of authentication
+type AuthResult struct {
+ // Whether authentication was successful
+ Authenticated bool `json:"authenticated"`
+
+ // Token information (if authenticated)
+ TokenInfo *TokenInfo `json:"token_info,omitempty"`
+
+ // Error message (if authentication failed)
+ Error string `json:"error,omitempty"`
+
+ // HTTP status code to return
+ StatusCode int `json:"status_code"`
+}
+
+// Default OAuth configuration values
+const (
+ DefaultTokenCacheTTL = 5 * time.Minute
+ DefaultClockSkew = 1 * time.Minute
+ DefaultExpectedAudience = "https://management.azure.com"
+ AzureADScope = "https://management.azure.com/.default"
+)
+
+// NewDefaultOAuthConfig creates a default OAuth configuration
+func NewDefaultOAuthConfig() *OAuthConfig {
+ return &OAuthConfig{
+ Enabled: false,
+ // Use Azure Management API scope to get v2.0 format tokens
+ // This ensures we get v2.0 issuer format which works with v2.0 JWKS endpoints
+ RequiredScopes: []string{AzureADScope}, // "https://management.azure.com/.default"
+ // RedirectURIs will be populated dynamically based on host/port configuration
+ RedirectURIs: []string{},
+ TokenValidation: TokenValidationConfig{
+ ValidateJWT: true, // SECURITY CRITICAL: Always true in production
+ ValidateAudience: true, // Re-enabled with correct audience
+ ExpectedAudience: DefaultExpectedAudience, // "https://management.azure.com"
+ CacheTTL: DefaultTokenCacheTTL,
+ ClockSkew: DefaultClockSkew,
+ },
+ }
+}
+
+// Validate validates the OAuth configuration
+func (cfg *OAuthConfig) Validate() error {
+ if !cfg.Enabled {
+ return nil
+ }
+
+ if cfg.TenantID == "" {
+ return fmt.Errorf("tenant_id is required when OAuth is enabled")
+ }
+
+ if cfg.ClientID == "" {
+ return fmt.Errorf("client_id is required when OAuth is enabled")
+ }
+
+ // if len(cfg.RequiredScopes) == 0 {
+ // return fmt.Errorf("at least one required scope must be specified")
+ // }
+
+ return nil
+}
diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go
new file mode 100644
index 0000000..d6becde
--- /dev/null
+++ b/internal/auth/types_test.go
@@ -0,0 +1,183 @@
+package auth
+
+import (
+ "os"
+ "testing"
+ "time"
+)
+
+func TestOAuthConfigValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ config *OAuthConfig
+ wantErr bool
+ }{
+ {
+ name: "disabled OAuth should pass validation",
+ config: &OAuthConfig{
+ Enabled: false,
+ },
+ wantErr: false,
+ },
+ {
+ name: "enabled OAuth with missing tenant ID should fail",
+ config: &OAuthConfig{
+ Enabled: true,
+ ClientID: "test-client-id",
+ RequiredScopes: []string{"scope1"},
+ },
+ wantErr: true,
+ },
+ {
+ name: "enabled OAuth with missing client ID should fail",
+ config: &OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant-id",
+ RequiredScopes: []string{"scope1"},
+ },
+ wantErr: true,
+ },
+ {
+ name: "enabled OAuth with empty scopes should pass",
+ config: &OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant-id",
+ ClientID: "test-client-id",
+ RequiredScopes: []string{},
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid enabled OAuth config should pass",
+ config: &OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant-id",
+ ClientID: "test-client-id",
+ RequiredScopes: []string{"scope1"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid enabled OAuth config with full token validation should pass",
+ config: &OAuthConfig{
+ Enabled: true,
+ TenantID: "test-tenant-id",
+ ClientID: "test-client-id",
+ RequiredScopes: []string{"scope1"},
+ TokenValidation: TokenValidationConfig{
+ ValidateJWT: true,
+ ValidateAudience: true,
+ ExpectedAudience: "https://management.azure.com/",
+ CacheTTL: DefaultTokenCacheTTL,
+ ClockSkew: DefaultClockSkew,
+ },
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := tt.config.Validate()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("OAuthConfig.Validate() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestNewDefaultOAuthConfig(t *testing.T) {
+ config := NewDefaultOAuthConfig()
+
+ if config.Enabled {
+ t.Error("Default config should have OAuth disabled")
+ }
+
+ if len(config.RequiredScopes) != 1 || config.RequiredScopes[0] != AzureADScope {
+ t.Errorf("Default config should have Azure AD scope, got %v", config.RequiredScopes)
+ }
+
+ if !config.TokenValidation.ValidateJWT {
+ t.Error("Default config should enable JWT validation for production")
+ }
+
+ if !config.TokenValidation.ValidateAudience {
+ t.Error("Default config should enable audience validation for security")
+ }
+
+ if config.TokenValidation.ExpectedAudience != DefaultExpectedAudience {
+ t.Errorf("Default config should have correct expected audience, got %s", config.TokenValidation.ExpectedAudience)
+ }
+
+ if config.TokenValidation.CacheTTL != DefaultTokenCacheTTL {
+ t.Errorf("Default config should have correct cache TTL, got %v", config.TokenValidation.CacheTTL)
+ }
+
+ if config.TokenValidation.ClockSkew != DefaultClockSkew {
+ t.Errorf("Default config should have correct clock skew, got %v", config.TokenValidation.ClockSkew)
+ }
+}
+
+func TestOAuthConfigConstants(t *testing.T) {
+ if DefaultTokenCacheTTL != 5*time.Minute {
+ t.Errorf("DefaultTokenCacheTTL should be 5 minutes, got %v", DefaultTokenCacheTTL)
+ }
+
+ if DefaultClockSkew != 1*time.Minute {
+ t.Errorf("DefaultClockSkew should be 1 minute, got %v", DefaultClockSkew)
+ }
+
+ if DefaultExpectedAudience != "https://management.azure.com" {
+ t.Errorf("DefaultExpectedAudience should be Azure management, got %s", DefaultExpectedAudience)
+ }
+
+ if AzureADScope != "https://management.azure.com/.default" {
+ t.Errorf("AzureADScope should be Azure management default, got %s", AzureADScope)
+ }
+}
+
+func TestOAuthConfigEnvironmentVariables(t *testing.T) {
+ // Test that environment variables are respected
+ oldTenantID := os.Getenv("AZURE_TENANT_ID")
+ oldClientID := os.Getenv("AZURE_CLIENT_ID")
+
+ defer func() {
+ if err := os.Setenv("AZURE_TENANT_ID", oldTenantID); err != nil {
+ t.Logf("Failed to restore AZURE_TENANT_ID: %v", err)
+ }
+ if err := os.Setenv("AZURE_CLIENT_ID", oldClientID); err != nil {
+ t.Logf("Failed to restore AZURE_CLIENT_ID: %v", err)
+ }
+ }()
+
+ if err := os.Setenv("AZURE_TENANT_ID", "env-tenant-id"); err != nil {
+ t.Fatalf("Failed to set AZURE_TENANT_ID: %v", err)
+ }
+ if err := os.Setenv("AZURE_CLIENT_ID", "env-client-id"); err != nil {
+ t.Fatalf("Failed to set AZURE_CLIENT_ID: %v", err)
+ }
+
+ config := NewDefaultOAuthConfig()
+ config.Enabled = true
+
+ // Simulate the environment variable loading that happens in config parsing
+ if config.TenantID == "" {
+ config.TenantID = os.Getenv("AZURE_TENANT_ID")
+ }
+ if config.ClientID == "" {
+ config.ClientID = os.Getenv("AZURE_CLIENT_ID")
+ }
+
+ if config.TenantID != "env-tenant-id" {
+ t.Errorf("Expected tenant ID from environment, got %s", config.TenantID)
+ }
+
+ if config.ClientID != "env-client-id" {
+ t.Errorf("Expected client ID from environment, got %s", config.ClientID)
+ }
+
+ // Should pass validation with environment variables
+ if err := config.Validate(); err != nil {
+ t.Errorf("Config with environment variables should be valid, got error: %v", err)
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index a653188..b60f068 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -5,15 +5,37 @@ import (
"fmt"
"log"
"os"
+ "regexp"
"strings"
"time"
+ "github.com/Azure/aks-mcp/internal/auth"
"github.com/Azure/aks-mcp/internal/security"
"github.com/Azure/aks-mcp/internal/telemetry"
"github.com/Azure/aks-mcp/internal/version"
flag "github.com/spf13/pflag"
)
+// EnableCache controls whether caching is enabled globally
+// Cache is enabled by default for production performance
+// This affects both web cache headers and AzureOAuthProvider cache
+// Can be disabled via DISABLE_CACHE environment variable
+var EnableCache = os.Getenv("DISABLE_CACHE") != "true"
+
+// validateGUID validates that a value is in valid GUID format
+func validateGUID(value, name string) error {
+ if value == "" {
+ return nil // Empty values are allowed (will be handled by OAuth validation)
+ }
+
+ // GUID pattern: 8-4-4-4-12 hexadecimal digits with hyphens
+ guidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`)
+ if !guidRegex.MatchString(value) {
+ return fmt.Errorf("%s must be a valid GUID format (e.g., 12345678-1234-1234-1234-123456789abc), got: %s", name, value)
+ }
+ return nil
+}
+
// ConfigData holds the global configuration
type ConfigData struct {
// Command execution timeout in seconds
@@ -22,6 +44,8 @@ type ConfigData struct {
CacheTimeout time.Duration
// Security configuration
SecurityConfig *security.SecurityConfig
+ // OAuth configuration
+ OAuthConfig *auth.OAuthConfig
// Command-line specific options
Transport string
@@ -51,6 +75,7 @@ func NewConfig() *ConfigData {
Timeout: 60,
CacheTimeout: 1 * time.Minute,
SecurityConfig: security.NewSecurityConfig(),
+ OAuthConfig: auth.NewDefaultOAuthConfig(),
Transport: "stdio",
Port: 8000,
AccessLevel: "readonly",
@@ -66,9 +91,23 @@ func (cfg *ConfigData) ParseFlags() {
flag.StringVar(&cfg.Host, "host", "127.0.0.1", "Host to listen for the server (only used with transport sse or streamable-http)")
flag.IntVar(&cfg.Port, "port", 8000, "Port to listen for the server (only used with transport sse or streamable-http)")
flag.IntVar(&cfg.Timeout, "timeout", 600, "Timeout for command execution in seconds, default is 600s")
+
// Security settings
flag.StringVar(&cfg.AccessLevel, "access-level", "readonly", "Access level (readonly, readwrite, admin)")
+ // OAuth configuration
+ flag.BoolVar(&cfg.OAuthConfig.Enabled, "oauth-enabled", false, "Enable OAuth authentication")
+ flag.StringVar(&cfg.OAuthConfig.TenantID, "oauth-tenant-id", "", "Azure AD tenant ID for OAuth (fallback to AZURE_TENANT_ID env var)")
+ flag.StringVar(&cfg.OAuthConfig.ClientID, "oauth-client-id", "", "Azure AD client ID for OAuth (fallback to AZURE_CLIENT_ID env var)")
+
+ // OAuth redirect URIs configuration
+ additionalRedirectURIs := flag.String("oauth-redirects", "",
+ "Comma-separated list of additional OAuth redirect URIs (e.g. http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback)")
+
+ // OAuth CORS origins configuration
+ allowedCORSOrigins := flag.String("oauth-cors-origins", "",
+ "Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274). If empty, no cross-origin requests are allowed for security")
+
// Kubernetes-specific settings
additionalTools := flag.String("additional-tools", "",
"Comma-separated list of additional Kubernetes tools to support (kubectl is always enabled). Available: helm,cilium,hubble")
@@ -113,6 +152,12 @@ func (cfg *ConfigData) ParseFlags() {
cfg.SecurityConfig.AccessLevel = cfg.AccessLevel
cfg.SecurityConfig.AllowedNamespaces = cfg.AllowNamespaces
+ // Parse OAuth configuration
+ if err := cfg.parseOAuthConfig(*additionalRedirectURIs, *allowedCORSOrigins); err != nil {
+ fmt.Printf("OAuth configuration error: %v\n", err)
+ os.Exit(1)
+ }
+
// Parse additional tools
if *additionalTools != "" {
tools := strings.Split(*additionalTools, ",")
@@ -122,6 +167,96 @@ func (cfg *ConfigData) ParseFlags() {
}
}
+// parseOAuthConfig parses OAuth-related command line arguments
+func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigins string) error {
+ // Note: OAuth scopes are automatically configured to use "https://management.azure.com/.default"
+ // and are not configurable via command line per design
+
+ // Track configuration sources for logging
+ var tenantIDSource, clientIDSource string
+
+ // Load OAuth configuration from environment variables if not set via CLI
+ if cfg.OAuthConfig.TenantID == "" {
+ if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" {
+ cfg.OAuthConfig.TenantID = tenantID
+ tenantIDSource = "environment variable AZURE_TENANT_ID"
+ log.Printf("OAuth Config: Using tenant ID from environment variable AZURE_TENANT_ID")
+ }
+ } else {
+ tenantIDSource = "command line flag --oauth-tenant-id"
+ log.Printf("OAuth Config: Using tenant ID from command line flag --oauth-tenant-id")
+ }
+
+ if cfg.OAuthConfig.ClientID == "" {
+ if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" {
+ cfg.OAuthConfig.ClientID = clientID
+ clientIDSource = "environment variable AZURE_CLIENT_ID"
+ log.Printf("OAuth Config: Using client ID from environment variable AZURE_CLIENT_ID")
+ }
+ } else {
+ clientIDSource = "command line flag --oauth-client-id"
+ log.Printf("OAuth Config: Using client ID from command line flag --oauth-client-id")
+ }
+
+ // Validate GUID formats for tenant ID and client ID
+ if err := validateGUID(cfg.OAuthConfig.TenantID, "OAuth tenant ID"); err != nil {
+ return fmt.Errorf("invalid OAuth tenant ID from %s: %w", tenantIDSource, err)
+ }
+
+ if err := validateGUID(cfg.OAuthConfig.ClientID, "OAuth client ID"); err != nil {
+ return fmt.Errorf("invalid OAuth client ID from %s: %w", clientIDSource, err)
+ }
+
+ // Set redirect URIs based on configured host and port
+ if cfg.OAuthConfig.Enabled {
+ redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", cfg.Host, cfg.Port)
+ cfg.OAuthConfig.RedirectURIs = []string{redirectURI}
+
+ // Add localhost variant if using 127.0.0.1
+ if cfg.Host == "127.0.0.1" {
+ localhostURI := fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Port)
+ cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, localhostURI)
+ }
+
+ // Add additional redirect URIs from command line flag
+ if additionalRedirectURIs != "" {
+ additionalURIs := strings.Split(additionalRedirectURIs, ",")
+ for _, uri := range additionalURIs {
+ trimmedURI := strings.TrimSpace(uri)
+ if trimmedURI != "" {
+ cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, trimmedURI)
+ }
+ }
+ }
+ }
+
+ // Parse allowed CORS origins for OAuth endpoints
+ if allowedCORSOrigins != "" {
+ log.Printf("OAuth Config: Setting allowed CORS origins from command line flag --oauth-cors-origins")
+ origins := strings.Split(allowedCORSOrigins, ",")
+ for _, origin := range origins {
+ trimmedOrigin := strings.TrimSpace(origin)
+ if trimmedOrigin != "" {
+ cfg.OAuthConfig.AllowedOrigins = append(cfg.OAuthConfig.AllowedOrigins, trimmedOrigin)
+ }
+ }
+ } else {
+ log.Printf("OAuth Config: No CORS origins configured - cross-origin requests will be blocked for security")
+ }
+
+ return nil
+}
+
+// ValidateConfig validates the configuration for incompatible settings
+func (cfg *ConfigData) ValidateConfig() error {
+ // Validate OAuth + transport compatibility
+ if cfg.OAuthConfig.Enabled && cfg.Transport == "stdio" {
+ return fmt.Errorf("OAuth authentication is not supported with stdio transport per MCP specification")
+ }
+
+ return nil
+}
+
// InitializeTelemetry initializes the telemetry service
func (cfg *ConfigData) InitializeTelemetry(ctx context.Context, serviceName, serviceVersion string) {
// Create telemetry configuration
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..56ad2e4
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,239 @@
+package config
+
+import (
+ "testing"
+)
+
+func TestBasicOAuthConfig(t *testing.T) {
+ // Test basic OAuth configuration parsing with valid GUIDs
+ cfg := NewConfig()
+ cfg.OAuthConfig.Enabled = true
+ cfg.OAuthConfig.TenantID = "12345678-1234-1234-1234-123456789abc"
+ cfg.OAuthConfig.ClientID = "87654321-4321-4321-4321-cba987654321"
+
+ // Parse OAuth configuration
+ if err := cfg.parseOAuthConfig("", ""); err != nil {
+ t.Fatalf("Unexpected error in parseOAuthConfig: %v", err)
+ }
+
+ // Verify basic configuration is preserved
+ if !cfg.OAuthConfig.Enabled {
+ t.Error("Expected OAuth to be enabled")
+ }
+ if cfg.OAuthConfig.TenantID != "12345678-1234-1234-1234-123456789abc" {
+ t.Errorf("Expected tenant ID '12345678-1234-1234-1234-123456789abc', got %s", cfg.OAuthConfig.TenantID)
+ }
+ if cfg.OAuthConfig.ClientID != "87654321-4321-4321-4321-cba987654321" {
+ t.Errorf("Expected client ID '87654321-4321-4321-4321-cba987654321', got %s", cfg.OAuthConfig.ClientID)
+ }
+}
+
+func TestOAuthRedirectURIsConfig(t *testing.T) {
+ // Test OAuth redirect URIs configuration with additional URIs
+ cfg := NewConfig()
+ cfg.OAuthConfig.Enabled = true
+ cfg.Host = "127.0.0.1"
+ cfg.Port = 8081
+
+ // Test with additional redirect URIs
+ additionalRedirectURIs := "http://localhost:6274/oauth/callback,http://localhost:8080/oauth/callback"
+ if err := cfg.parseOAuthConfig(additionalRedirectURIs, ""); err != nil {
+ t.Fatalf("Unexpected error in parseOAuthConfig: %v", err)
+ }
+
+ // Should have default URIs plus additional ones
+ expectedURIs := []string{
+ "http://127.0.0.1:8081/oauth/callback",
+ "http://localhost:8081/oauth/callback",
+ "http://localhost:6274/oauth/callback",
+ "http://localhost:8080/oauth/callback",
+ }
+
+ if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) {
+ t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs))
+ }
+
+ for i, expected := range expectedURIs {
+ if i >= len(cfg.OAuthConfig.RedirectURIs) || cfg.OAuthConfig.RedirectURIs[i] != expected {
+ t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i,
+ func() string {
+ if i < len(cfg.OAuthConfig.RedirectURIs) {
+ return cfg.OAuthConfig.RedirectURIs[i]
+ }
+ return "missing"
+ }())
+ }
+ }
+}
+
+func TestOAuthRedirectURIsEmptyAdditional(t *testing.T) {
+ // Test OAuth redirect URIs configuration without additional URIs
+ cfg := NewConfig()
+ cfg.OAuthConfig.Enabled = true
+ cfg.Host = "127.0.0.1"
+ cfg.Port = 8081
+
+ // Test with empty additional redirect URIs
+ if err := cfg.parseOAuthConfig("", ""); err != nil {
+ t.Fatalf("Unexpected error in parseOAuthConfig: %v", err)
+ }
+
+ // Should have only default URIs
+ expectedURIs := []string{
+ "http://127.0.0.1:8081/oauth/callback",
+ "http://localhost:8081/oauth/callback",
+ }
+
+ if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) {
+ t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs))
+ }
+
+ for i, expected := range expectedURIs {
+ if cfg.OAuthConfig.RedirectURIs[i] != expected {
+ t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i, cfg.OAuthConfig.RedirectURIs[i])
+ }
+ }
+}
+
+func TestValidateGUID(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ fieldName string
+ wantErr bool
+ }{
+ {
+ name: "valid GUID",
+ value: "12345678-1234-1234-1234-123456789abc",
+ fieldName: "test field",
+ wantErr: false,
+ },
+ {
+ name: "valid GUID uppercase",
+ value: "12345678-1234-1234-1234-123456789ABC",
+ fieldName: "test field",
+ wantErr: false,
+ },
+ {
+ name: "empty value allowed",
+ value: "",
+ fieldName: "test field",
+ wantErr: false,
+ },
+ {
+ name: "invalid format - missing hyphens",
+ value: "123456781234123412341234567890ab",
+ fieldName: "test field",
+ wantErr: true,
+ },
+ {
+ name: "invalid format - wrong length",
+ value: "12345678-1234-1234-1234-123456789",
+ fieldName: "test field",
+ wantErr: true,
+ },
+ {
+ name: "invalid format - non-hex characters",
+ value: "12345678-1234-1234-1234-123456789abg",
+ fieldName: "test field",
+ wantErr: true,
+ },
+ {
+ name: "invalid format - extra hyphens",
+ value: "12345678-1234-1234-1234-1234-56789abc",
+ fieldName: "test field",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateGUID(tt.value, tt.fieldName)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("validateGUID() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr && err != nil {
+ // Verify error message contains the field name and value
+ errorMsg := err.Error()
+ if !contains(errorMsg, tt.fieldName) {
+ t.Errorf("Error message should contain field name '%s', got: %s", tt.fieldName, errorMsg)
+ }
+ if tt.value != "" && !contains(errorMsg, tt.value) {
+ t.Errorf("Error message should contain value '%s', got: %s", tt.value, errorMsg)
+ }
+ }
+ })
+ }
+}
+
+func TestOAuthGUIDValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ tenantID string
+ clientID string
+ wantErr bool
+ }{
+ {
+ name: "valid GUIDs",
+ tenantID: "12345678-1234-1234-1234-123456789abc",
+ clientID: "87654321-4321-4321-4321-cba987654321",
+ wantErr: false,
+ },
+ {
+ name: "empty values allowed",
+ tenantID: "",
+ clientID: "",
+ wantErr: false,
+ },
+ {
+ name: "invalid tenant ID",
+ tenantID: "invalid-tenant-id",
+ clientID: "87654321-4321-4321-4321-cba987654321",
+ wantErr: true,
+ },
+ {
+ name: "invalid client ID",
+ tenantID: "12345678-1234-1234-1234-123456789abc",
+ clientID: "invalid-client-id",
+ wantErr: true,
+ },
+ {
+ name: "both invalid",
+ tenantID: "invalid-tenant",
+ clientID: "invalid-client",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := NewConfig()
+ cfg.OAuthConfig.Enabled = true
+ cfg.OAuthConfig.TenantID = tt.tenantID
+ cfg.OAuthConfig.ClientID = tt.clientID
+ cfg.Host = "127.0.0.1"
+ cfg.Port = 8081
+
+ err := cfg.parseOAuthConfig("", "")
+ if (err != nil) != tt.wantErr {
+ t.Errorf("parseOAuthConfig() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ })
+ }
+}
+
+// contains is a helper function to check if a string contains a substring
+func contains(s, substr string) bool {
+ return len(substr) == 0 || (len(s) >= len(substr) && findSubstring(s, substr))
+}
+
+func findSubstring(s, substr string) bool {
+ for i := 0; i <= len(s)-len(substr); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/config/validator.go b/internal/config/validator.go
index 3499192..2f5214f 100644
--- a/internal/config/validator.go
+++ b/internal/config/validator.go
@@ -64,12 +64,22 @@ func (v *Validator) validateCli() bool {
return valid
}
+// validateConfig checks configuration compatibility
+func (v *Validator) validateConfig() bool {
+ if err := v.config.ValidateConfig(); err != nil {
+ v.errors = append(v.errors, err.Error())
+ return false
+ }
+ return true
+}
+
// Validate runs all validation checks
func (v *Validator) Validate() bool {
// Run all validation checks
validCli := v.validateCli()
+ validConfig := v.validateConfig()
- return validCli
+ return validCli && validConfig
}
// GetErrors returns all errors found during validation
diff --git a/internal/server/server.go b/internal/server/server.go
index 555ca37..39b4513 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -7,6 +7,7 @@ import (
"net/http"
"time"
+ "github.com/Azure/aks-mcp/internal/auth/oauth"
"github.com/Azure/aks-mcp/internal/azcli"
"github.com/Azure/aks-mcp/internal/azureclient"
"github.com/Azure/aks-mcp/internal/components/advisor"
@@ -36,6 +37,9 @@ type Service struct {
mcpServer *server.MCPServer
azClient *azureclient.AzureClient
azcliProcFactory func(timeout int) azcli.Proc
+ oauthProvider *oauth.AzureOAuthProvider
+ authMiddleware *oauth.AuthMiddleware
+ endpointManager *oauth.EndpointManager
}
// ServiceOption defines a function that configures the AKS MCP service
@@ -83,6 +87,14 @@ func (s *Service) initializeInfrastructure() error {
s.azClient = azClient
log.Println("Azure client initialized successfully")
+ // Initialize OAuth components if enabled and transport is not stdio
+ // OAuth is not supported with stdio transport per MCP specification
+ if s.cfg.OAuthConfig.Enabled && s.cfg.Transport != "stdio" {
+ if err := s.initializeOAuth(); err != nil {
+ return fmt.Errorf("failed to initialize OAuth: %w", err)
+ }
+ }
+
// Ensure Azure CLI exists and is logged in
if s.azcliProcFactory != nil {
// Use injected factory to create an azcli.Proc
@@ -114,6 +126,35 @@ func (s *Service) initializeInfrastructure() error {
return nil
}
+// initializeOAuth initializes OAuth authentication components
+func (s *Service) initializeOAuth() error {
+ log.Println("Initializing OAuth authentication...")
+
+ // Validate OAuth configuration
+ if err := s.cfg.OAuthConfig.Validate(); err != nil {
+ return fmt.Errorf("invalid OAuth configuration: %w", err)
+ }
+
+ // Create OAuth provider
+ provider, err := oauth.NewAzureOAuthProvider(s.cfg.OAuthConfig)
+ if err != nil {
+ return fmt.Errorf("failed to create OAuth provider: %w", err)
+ }
+ s.oauthProvider = provider
+
+ // Create server URL for OAuth metadata
+ serverURL := fmt.Sprintf("http://%s:%d", s.cfg.Host, s.cfg.Port)
+
+ // Create auth middleware
+ s.authMiddleware = oauth.NewAuthMiddleware(provider, serverURL)
+
+ // Create endpoint manager
+ s.endpointManager = oauth.NewEndpointManager(provider, s.cfg)
+
+ log.Printf("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID)
+ return nil
+}
+
// registerAllComponents registers all component tools organized by category
func (s *Service) registerAllComponents() {
// Azure Components
@@ -142,6 +183,15 @@ func (s *Service) registerPrompts() {
func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server {
mux := http.NewServeMux()
+ // Register OAuth endpoints if OAuth is enabled
+ if s.cfg.OAuthConfig.Enabled {
+ if s.endpointManager == nil {
+ log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization")
+ }
+ log.Println("Registering OAuth endpoints...")
+ s.endpointManager.RegisterEndpoints(mux)
+ }
+
// Handle all other paths with a helpful 404 response
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/mcp" {
@@ -159,6 +209,20 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server {
},
}
+ // Add OAuth endpoints to the response if enabled
+ if s.cfg.OAuthConfig.Enabled {
+ oauthEndpoints := map[string]string{
+ "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata",
+ "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata",
+ "client-registration": "POST /oauth/register - Dynamic client registration",
+ "token-introspection": "POST /oauth/introspect - Token introspection",
+ "health": "GET /health - Health check",
+ }
+ for k, v := range oauthEndpoints {
+ response["endpoints"].(map[string]string)[k] = v
+ }
+ }
+
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
@@ -177,9 +241,28 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server {
func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, addr string) *http.Server {
mux := http.NewServeMux()
- // Register SSE and Message handlers
- mux.Handle("/sse", sseServer.SSEHandler())
- mux.Handle("/message", sseServer.MessageHandler())
+ // Register OAuth endpoints if OAuth is enabled
+ if s.cfg.OAuthConfig.Enabled {
+ if s.endpointManager == nil {
+ log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization")
+ }
+ log.Println("Registering OAuth endpoints for SSE server...")
+ s.endpointManager.RegisterEndpoints(mux)
+ }
+
+ // Register SSE and Message handlers with authentication if enabled
+ if s.cfg.OAuthConfig.Enabled {
+ if s.authMiddleware == nil {
+ log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization")
+ }
+ // Apply authentication middleware to SSE and Message endpoints
+ mux.Handle("/sse", s.authMiddleware.Middleware(sseServer.SSEHandler()))
+ mux.Handle("/message", s.authMiddleware.Middleware(sseServer.MessageHandler()))
+ } else {
+ // Register without authentication
+ mux.Handle("/sse", sseServer.SSEHandler())
+ mux.Handle("/message", sseServer.MessageHandler())
+ }
// Handle all other paths with a helpful 404 response
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
@@ -196,6 +279,26 @@ func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer,
},
}
+ // Add OAuth endpoints and authentication info if enabled
+ if s.cfg.OAuthConfig.Enabled {
+ response["authentication"] = map[string]interface{}{
+ "required": true,
+ "type": "Bearer",
+ "note": "Include 'Authorization: Bearer ' header for authenticated endpoints",
+ }
+
+ oauthEndpoints := map[string]string{
+ "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata",
+ "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata",
+ "client-registration": "POST /oauth/register - Dynamic client registration",
+ "token-introspection": "POST /oauth/introspect - Token introspection",
+ "health": "GET /health - Health check",
+ }
+ for k, v := range oauthEndpoints {
+ response["endpoints"].(map[string]string)[k] = v
+ }
+ }
+
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
@@ -231,6 +334,10 @@ func (s *Service) Run() error {
log.Printf("SSE endpoint available at: http://%s/sse", addr)
log.Printf("Message endpoint available at: http://%s/message", addr)
log.Printf("Connect to /sse for real-time events, send JSON-RPC to /message")
+ if s.cfg.OAuthConfig.Enabled {
+ log.Printf("OAuth authentication enabled - Bearer token required for SSE and Message endpoints")
+ log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr)
+ }
return customServer.ListenAndServe()
case "streamable-http":
@@ -247,12 +354,25 @@ func (s *Service) Run() error {
// Update the mux to use the actual streamable server as the MCP handler
if mux, ok := customServer.Handler.(*http.ServeMux); ok {
- mux.Handle("/mcp", streamableServer)
+ if s.cfg.OAuthConfig.Enabled {
+ if s.authMiddleware == nil {
+ log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization")
+ }
+ // Apply authentication middleware to MCP endpoint
+ mux.Handle("/mcp", s.authMiddleware.Middleware(streamableServer))
+ } else {
+ // Register without authentication
+ mux.Handle("/mcp", streamableServer)
+ }
}
log.Printf("Streamable HTTP server listening on %s", addr)
log.Printf("MCP endpoint available at: http://%s/mcp", addr)
log.Printf("Send POST requests to /mcp to initialize session and obtain Mcp-Session-Id")
+ if s.cfg.OAuthConfig.Enabled {
+ log.Printf("OAuth authentication enabled - Bearer token required for MCP endpoint")
+ log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr)
+ }
return customServer.ListenAndServe()
default: