diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md new file mode 100644 index 0000000..344279f --- /dev/null +++ b/docs/oauth-authentication.md @@ -0,0 +1,474 @@ +# OAuth Authentication for AKS-MCP + +This document describes how to configure and use OAuth authentication with AKS-MCP. + +## Overview + +AKS-MCP now supports OAuth 2.1 authentication using Azure Active Directory as the authorization server. When enabled, OAuth authentication provides secure access control for MCP endpoints using Bearer tokens. + +## Features + +- **Azure AD Integration**: Uses Azure Active Directory as the OAuth authorization server +- **JWT Token Validation**: Validates JWT tokens with Azure AD signing keys +- **OAuth 2.0 Metadata Endpoints**: Provides standard OAuth metadata discovery endpoints +- **Dynamic Client Registration**: Supports RFC 7591 dynamic client registration +- **Token Introspection**: Implements RFC 7662 token introspection +- **Transport Support**: Works with both SSE and HTTP Streamable transports +- **Flexible Configuration**: Supports environment variables and command-line configuration + +## Environment Setup and Azure AD Configuration + +### Prerequisites + +Before setting up OAuth authentication, ensure you have: + +- Azure CLI installed and configured (`az login`) +- An Azure subscription with appropriate permissions to create applications +- Azure Active Directory tenant access + +### Important: Environment Variables Shared Between OAuth and Azure CLI + +AKS-MCP uses the same environment variables (`AZURE_TENANT_ID`, `AZURE_CLIENT_ID`) for both OAuth authentication and Azure CLI operations. This design provides configuration simplicity but requires careful permission setup: + +**When `AZURE_CLIENT_ID` is set:** +- OAuth: Used for validating user tokens accessing the MCP server +- Azure CLI: Used for managed identity/workload identity authentication to access Azure resources + +**Permission Requirements:** +- The Azure AD application must have both **OAuth permissions** (for user authentication) AND **Azure resource permissions** (for az CLI operations) +- Missing either set of permissions will cause authentication failures + +### Step 1: Create Azure AD Application + +#### Using Azure Portal (Recommended) + +1. **Navigate to Azure Portal** + - Go to https://portal.azure.com + - Sign in with your Azure account + +2. **Create App Registration** + ``` + Navigation: Azure Active Directory → App registrations → New registration + ``` + + Configure the following: + - **Name**: `AKS-MCP-OAuth` (or your preferred name) + - **Supported account types**: "Accounts in this organizational directory only" + - **Redirect URI Platform Options**: + +#### Supported Platform Types + +**✅ Mobile and desktop applications (Recommended)** +- **Platform**: "Mobile and desktop applications" +- **Redirect URIs**: + - `http://localhost:8000/oauth/callback` +- **Benefits**: + - Native support for PKCE (required by OAuth 2.1) + - No client secret required (public client) + - Better security for localhost redirects +- **Status**: ✅ **Confirmed working** + +**❌ Single-page application (SPA) - Not Recommended** +- **Platform**: "Single-page application (SPA)" +- **Redirect URIs**: Same as above +- **Benefits**: + - Designed for PKCE flow + - No client secret required +- **Critical Limitations**: + - **Token exchange restriction**: Azure AD error AADSTS9002327 - "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests" + - **Architecture mismatch**: SPA platform expects frontend JavaScript to handle token exchange, but AKS-MCP performs backend token exchange + - **CORS requirements**: Requires complex frontend-backend coordination for OAuth flow +- **Status**: ❌ **Not compatible with AKS-MCP's backend OAuth implementation** + +**❌ Web application** +- **Platform**: "Web" +- **Why not supported**: + - Requires client secret (confidential client) + - AKS-MCP implements public client flow without secrets + - PKCE handling may differ + +**Choose Platform Recommendation:** +1. **Primary**: Use "Mobile and desktop applications" (✅ confirmed working) +2. **Avoid**: "Single-page application" - incompatible with backend OAuth implementation (AADSTS9002327 error) +3. **Avoid**: "Web" platform due to client secret requirements + +3. **Record Essential Information** + From the "Overview" page, note: + - **Application (client) ID** - This is your `CLIENT_ID` + - **Directory (tenant) ID** - This is your `TENANT_ID` + +#### Using Azure CLI (Alternative) + +**For Mobile and desktop applications platform:** +```bash +# Create Azure AD application with public client platform +az ad app create --display-name "AKS-MCP-OAuth" \ + --public-client-redirect-uris "http://localhost:8000/oauth/callback" + +# Get application details +az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}" + +# Get your tenant ID +az account show --query "tenantId" -o tsv +``` + +### Step 2: Configure API Permissions + +**Critical: Both OAuth and Azure CLI require proper permissions** + +1. **Add Required API Permissions** + ``` + Navigation: Azure Active Directory → App registrations → [Your App] → API permissions + ``` + +2. **Add Azure Service Management Permission (Required for OAuth)** + - Click "Add a permission" + - Select "Microsoft APIs" → "Azure Service Management" + - Choose "Delegated permissions" + - Select `user_impersonation` + - Click "Add permissions" + +3. **Add Azure Resource Management Permissions (Required for Azure CLI)** + + When `AZURE_CLIENT_ID` is set, Azure CLI will use this application for authentication. Add these permissions based on your AKS-MCP access level: + + **For readonly access:** + - Microsoft Graph → Application permissions → `Directory.Read.All` + - Azure Service Management → Delegated permissions → `user_impersonation` + + **For readwrite/admin access:** + - Microsoft Graph → Application permissions → `Directory.Read.All` + - Azure Service Management → Delegated permissions → `user_impersonation` + - Consider adding specific Azure resource permissions based on your needs + +4. **Grant Admin Consent (Required)** + - Click "Grant admin consent for [Your Organization]" + - Confirm the consent + +**⚠️ Important Notes:** +- Without proper Azure CLI permissions, you'll see "Insufficient privileges" errors when AKS-MCP tries to access Azure resources +- The same application serves both OAuth authentication (user access to MCP) and Azure CLI authentication (MCP access to Azure) +- Test both OAuth flow AND Azure resource access after permission changes + +### Step 3: Environment Configuration + +Set the required environment variables: + +```bash +# Replace with your actual values from Step 1 +export AZURE_TENANT_ID="your-tenant-id" +export AZURE_CLIENT_ID="your-client-id" +export AZURE_SUBSCRIPTION_ID="your-subscription-id" # Optional, for AKS operations +``` + +**⚠️ Important: Dual Authentication Impact** + +When you set `AZURE_CLIENT_ID`, it affects both OAuth and Azure CLI authentication: + +1. **OAuth Authentication**: Validates user tokens for MCP server access +2. **Azure CLI Authentication**: AKS-MCP uses this client ID for managed identity authentication when accessing Azure resources + +**Common Issues:** +- If you only configured OAuth permissions, Azure CLI operations will fail with "Insufficient privileges" +- If you only configured Azure resource permissions, OAuth token validation may fail +- Solution: Ensure your Azure AD application has BOTH sets of permissions (see Step 2) + +**Testing Both Authentication Paths:** +```bash +# Test OAuth (should work after proper setup) +curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp + +# Test Azure CLI access (should work after proper permissions) +# This happens automatically when AKS-MCP tries to access Azure resources +./aks-mcp --oauth-enabled --access-level=readonly +``` + +### Step 4: Start AKS-MCP with OAuth + +```bash +# Using HTTP Streamable transport with OAuth (recommended) +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-tenant-id="$AZURE_TENANT_ID" \ + --oauth-client-id="$AZURE_CLIENT_ID" \ + --oauth-redirects="http://localhost:8000/oauth/callback" \ + --access-level=readonly + +# Using SSE transport with OAuth (alternative) +./aks-mcp \ + --transport=sse \ + --port=8000 \ + --oauth-enabled \ + --oauth-tenant-id="$AZURE_TENANT_ID" \ + --oauth-client-id="$AZURE_CLIENT_ID" \ + --oauth-redirects="http://localhost:8000/oauth/callback" \ + --access-level=readonly + +# Environment variables are automatically used if set +# You can also just use: +./aks-mcp --transport=streamable-http --port=8000 --oauth-enabled --access-level=readonly +``` + +## Configuration Options + +### Command Line Flags + +- `--oauth-enabled`: Enable OAuth authentication (default: false) +- `--oauth-tenant-id`: Azure AD tenant ID (or use AZURE_TENANT_ID env var) +- `--oauth-client-id`: Azure AD client ID (or use AZURE_CLIENT_ID env var) +- `--oauth-redirects`: Comma-separated list of allowed redirect URIs (required when OAuth enabled) +- `--oauth-cors-origins`: Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274 for MCP Inspector). If empty, no cross-origin requests are allowed for security + +**Note**: OAuth scopes are automatically configured to use `https://management.azure.com/.default` for optimal Azure AD compatibility. Custom scopes are not currently configurable via command line. + +### Example with Command Line Flags + +```bash +./aks-mcp --transport=sse --oauth-enabled=true \ + --oauth-tenant-id="12345678-1234-1234-1234-123456789012" \ + --oauth-client-id="87654321-4321-4321-4321-210987654321" +``` + +**Note**: Scopes are automatically set to `https://management.azure.com/.default` and cannot be customized via command line. + +## OAuth Endpoints + +When OAuth is enabled, the following endpoints are available: + +### Metadata Endpoints (Unauthenticated) + +- `GET /.well-known/oauth-protected-resource` - OAuth 2.0 Protected Resource Metadata (RFC 9728) +- `GET /.well-known/oauth-authorization-server` - OAuth 2.0 Authorization Server Metadata (RFC 8414) +- `GET /.well-known/openid-configuration` - OpenID Connect Discovery (alias for authorization server metadata) +- `GET /health` - Health check endpoint + +### OAuth Flow Endpoints (Unauthenticated) + +- `GET /oauth2/v2.0/authorize` - Authorization endpoint proxy to Azure AD +- `POST /oauth2/v2.0/token` - Token exchange endpoint proxy to Azure AD +- `GET /oauth/callback` - Authorization Code flow callback handler +- `POST /oauth/register` - Dynamic Client Registration (RFC 7591) + +### Token Management (Unauthenticated for simplicity) + +- `POST /oauth/introspect` - Token Introspection (RFC 7662) + +### Authenticated MCP Endpoints + +When OAuth is enabled, these endpoints require Bearer token authentication: + +- **SSE Transport**: `GET /sse`, `POST /message` +- **HTTP Streamable Transport**: `POST /mcp` + +## Client Integration + +### Obtaining an Access Token + +Use the Azure AD OAuth flow to obtain an access token: + +```bash +# Example using Azure CLI (for testing) +az account get-access-token --resource https://management.azure.com/ --query accessToken -o tsv +``` + +### Making Authenticated Requests + +Include the Bearer token in the Authorization header: + +```bash +# Example authenticated request to SSE endpoint +curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \ + -H "Accept: text/event-stream" \ + http://localhost:8000/sse + +# Example authenticated request to HTTP Streamable endpoint +curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8000/mcp \ + -d '{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}' +``` + +## Testing OAuth Integration + +### 1. Test OAuth Metadata + +```bash +# Get protected resource metadata +curl http://localhost:8000/.well-known/oauth-protected-resource + +# Get authorization server metadata +curl http://localhost:8000/.well-known/oauth-authorization-server +``` + +### 2. Test Dynamic Client Registration + +```bash +curl -X POST http://localhost:8000/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "redirect_uris": ["http://localhost:3000/oauth/callback"], + "client_name": "Test MCP Client", + "grant_types": ["authorization_code"] + }' +``` + +### 3. Test Token Introspection + +```bash +curl -X POST http://localhost:8000/oauth/introspect \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "token=YOUR_ACCESS_TOKEN" +``` + +## Security Considerations + +1. **HTTPS in Production**: Always use HTTPS in production environments +2. **Token Validation**: JWT tokens are validated against Azure AD signing keys +3. **Scope Validation**: Tokens must include required scopes +4. **Audience Validation**: Tokens must have the correct audience claim +5. **Redirect URI Validation**: Only configured redirect URIs are allowed + +## Troubleshooting + +### Common Issues + +#### Authentication and Token Issues +1. **Invalid Token**: Ensure the token is valid and not expired +2. **Wrong Audience**: Verify the token audience matches `https://management.azure.com` +3. **Missing Scopes**: Ensure the token includes `https://management.azure.com/.default` scope +4. **JWT Signature Validation Failed**: + - Check that Azure AD application platform is set correctly + - Verify tenant ID matches the issuer in the token + - Ensure token is using v2.0 format (from Azure Management API scope) + +#### Azure AD Application Configuration Issues +5. **Client ID Not Found**: Verify the Application (client) ID is correct +6. **Redirect URI Mismatch**: Ensure redirect URIs match exactly in Azure AD app registration +7. **Wrong Platform Type**: Use "Mobile and desktop applications", NOT "Web" or "Single-page application" +8. **Insufficient Permissions**: Verify both OAuth and Azure resource permissions are configured +9. **SPA Platform Incompatibility (AADSTS9002327)**: + - Error: "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests" + - Solution: Change Azure AD app platform to "Mobile and desktop applications" + - Cause: SPA platform requires frontend token exchange, incompatible with AKS-MCP's backend implementation + +#### Network and Endpoint Issues +10. **CORS Errors**: Check that redirect URIs are properly configured for localhost +11. **Network Issues**: Check connectivity to Azure AD endpoints +11. **Port Conflicts**: Ensure the configured port (default 8000) is available + +#### Scope and Permission Issues +12. **Scope Mixing Error**: + - Error: "scope can't be combined with resource-specific scopes" + - Solution: Our implementation automatically handles this by using only Azure Management API scope +13. **Resource Parameter Issues**: + - Azure AD doesn't support RFC 8707 resource parameter + - Our implementation works around this limitation automatically + +### Debug Logging + +Enable verbose logging for OAuth debugging: + +```bash +./aks-mcp --oauth-enabled=true --verbose +``` + +### Health Check + +Use the health endpoint to verify OAuth configuration: + +```bash +curl http://localhost:8000/health +``` + +Expected response with OAuth enabled: +```json +{ + "status": "healthy", + "oauth": { + "enabled": true + } +} +``` + +### Testing OAuth Flow Step by Step + +1. **Test Metadata Discovery**: +```bash +# Should return authorization server URLs +curl http://localhost:8000/.well-known/oauth-protected-resource + +# Should return PKCE support and endpoints +curl http://localhost:8000/.well-known/oauth-authorization-server +``` + +2. **Test Client Registration**: +```bash +curl -X POST http://localhost:8000/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "redirect_uris": ["http://localhost:8000/oauth/callback"], + "client_name": "Test Client" + }' +``` + +3. **Test Authorization Flow**: + - Open browser to: `http://localhost:8000/oauth2/v2.0/authorize?response_type=code&client_id=YOUR_CLIENT_ID&redirect_uri=http://localhost:8000/oauth/callback&scope=https://management.azure.com/.default&code_challenge=CHALLENGE&code_challenge_method=S256&state=STATE` + +4. **Verify Token Validation**: +```bash +# Use a valid Azure AD token +curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp +``` + +## Migration from Non-OAuth + +To migrate from a non-OAuth AKS-MCP deployment: + +1. Update clients to obtain and include Bearer tokens +2. Enable OAuth on the server with `--oauth-enabled=true` +3. Configure Azure AD application and credentials +4. Test with a subset of clients before full migration +5. Monitor logs for authentication errors + +## Integration with MCP Inspector + +The MCP Inspector tool can be used to test OAuth-enabled AKS-MCP servers. Configure the Inspector's OAuth settings to match your AKS-MCP OAuth configuration for testing. + +### Important: Redirect URI Configuration for MCP Inspector + +When using MCP Inspector with OAuth authentication, you need to add the Inspector's proxy redirect URI to your OAuth configuration: + +```bash +# Add Inspector's redirect URI (typically http://localhost:6274/oauth/callback) +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \ + --access-level=readonly +``` + +**Key Points:** +- MCP Inspector typically runs on port 6274 by default +- The Inspector creates a proxy redirect URI at `/oauth/callback` +- You must include both your server's redirect URI AND the Inspector's redirect URI +- You must also configure CORS origins to allow the Inspector's web interface to make requests +- Comma-separate multiple redirect URIs in the `--oauth-redirects` parameter +- Comma-separate multiple CORS origins in the `--oauth-cors-origins` parameter +- Without the Inspector's redirect URI, OAuth authentication will fail with "redirect_uri not registered" error +- Without the Inspector's CORS origin, the web interface will be blocked by browser CORS policy + +**Example with MCP Inspector configuration:** +```bash +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \ + --oauth-cors-origins="http://localhost:6274" \ + --access-level=readonly +``` + +For more information, see the MCP OAuth specification and Azure AD documentation. \ No newline at end of file diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go new file mode 100644 index 0000000..af49899 --- /dev/null +++ b/internal/auth/oauth/endpoints.go @@ -0,0 +1,1021 @@ +package oauth + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/config" +) + +// validateAzureADURL validates that the URL is a legitimate Azure AD endpoint +func validateAzureADURL(tokenURL string) error { + parsedURL, err := url.Parse(tokenURL) + if err != nil { + return fmt.Errorf("invalid URL format: %w", err) + } + + // Only allow HTTPS for security + if parsedURL.Scheme != "https" { + return fmt.Errorf("only HTTPS URLs are allowed") + } + + // Only allow Azure AD endpoints + if parsedURL.Host != "login.microsoftonline.com" { + return fmt.Errorf("only Azure AD endpoints are allowed") + } + + // Validate path format for token endpoint (should be /{tenantId}/oauth2/v2.0/token) + if !strings.Contains(parsedURL.Path, "/oauth2/v2.0/token") { + return fmt.Errorf("invalid Azure AD token endpoint path") + } + + return nil +} + +// EndpointManager manages OAuth-related HTTP endpoints +type EndpointManager struct { + provider *AzureOAuthProvider + cfg *config.ConfigData +} + +// NewEndpointManager creates a new OAuth endpoint manager +func NewEndpointManager(provider *AzureOAuthProvider, cfg *config.ConfigData) *EndpointManager { + return &EndpointManager{ + provider: provider, + cfg: cfg, + } +} + +// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting +func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + requestOrigin := r.Header.Get("Origin") + + // Check if the request origin is in the allowed list + var allowedOrigin string + for _, allowed := range em.provider.config.AllowedOrigins { + if requestOrigin == allowed { + allowedOrigin = requestOrigin + break + } + } + + // Only set CORS headers if origin is allowed + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") + } else if requestOrigin != "" { + log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + } +} + +// setCacheHeaders sets cache control headers based on EnableCache configuration +func (em *EndpointManager) setCacheHeaders(w http.ResponseWriter) { + if config.EnableCache { + // Enable caching for 1 hour when cache is enabled + w.Header().Set("Cache-Control", "max-age=3600") + } else { + // Disable all caching when cache is disabled (for debugging) + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Expires", "0") + } +} + +// RegisterEndpoints registers OAuth endpoints with the provided HTTP mux +func (em *EndpointManager) RegisterEndpoints(mux *http.ServeMux) { + // OAuth 2.0 Protected Resource Metadata endpoint (RFC 9728) + mux.HandleFunc("/.well-known/oauth-protected-resource", em.protectedResourceMetadataHandler()) + + // OAuth 2.0 Authorization Server Metadata endpoint (RFC 8414) + // Note: This would typically be served by Azure AD, but we provide a proxy for convenience + mux.HandleFunc("/.well-known/oauth-authorization-server", em.authServerMetadataProxyHandler()) + + // OpenID Connect Discovery endpoint (compatibility with MCP Inspector) + mux.HandleFunc("/.well-known/openid-configuration", em.authServerMetadataProxyHandler()) + + // Authorization endpoint proxy to handle Azure AD compatibility + mux.HandleFunc("/oauth2/v2.0/authorize", em.authorizationProxyHandler()) + + // Dynamic Client Registration endpoint (RFC 7591) + mux.HandleFunc("/oauth/register", em.clientRegistrationHandler()) + + // Token introspection endpoint (RFC 7662) - optional + mux.HandleFunc("/oauth/introspect", em.tokenIntrospectionHandler()) + + // OAuth 2.0 callback endpoint for Authorization Code flow + mux.HandleFunc("/oauth/callback", em.callbackHandler()) + + // OAuth 2.0 token endpoint for Authorization Code exchange + mux.HandleFunc("/oauth2/v2.0/token", em.tokenHandler()) + + // Health check endpoint (unauthenticated) + mux.HandleFunc("/health", em.healthHandler()) +} + +// authServerMetadataProxyHandler proxies authorization server metadata from Azure AD +func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for metadata endpoint", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Get metadata from Azure AD + provider := em.provider + + // Build server URL based on the request + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + // Use the Host header from the request + host := r.Host + if host == "" { + host = r.URL.Host + } + + serverURL := fmt.Sprintf("%s://%s", scheme, host) + + metadata, err := provider.GetAuthorizationServerMetadata(serverURL) + if err != nil { + log.Printf("Failed to fetch authorization server metadata: %v\n", err) + http.Error(w, fmt.Sprintf("Failed to fetch authorization server metadata: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + em.setCacheHeaders(w) + + if err := json.NewEncoder(w).Encode(metadata); err != nil { + log.Printf("Failed to encode response: %v\n", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// clientRegistrationHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591) +func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + log.Printf("OAuth ERROR: Invalid method %s for client registration endpoint, only POST allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse client registration request + var registrationRequest ClientRegistrationRequest + + if err := json.NewDecoder(r.Body).Decode(®istrationRequest); err != nil { + log.Printf("OAuth ERROR: Failed to parse client registration JSON: %v", err) + em.writeErrorResponse(w, "invalid_request", "Invalid JSON in request body", http.StatusBadRequest) + return + } + + log.Printf("OAuth DEBUG: Client registration request parsed - client_name: %s, redirect_uris: %v", registrationRequest.ClientName, registrationRequest.RedirectURIs) + + // Validate registration request + if err := em.validateClientRegistration(®istrationRequest); err != nil { + log.Printf("OAuth ERROR: Client registration validation failed: %v", err) + em.writeErrorResponse(w, "invalid_client_metadata", err.Error(), http.StatusBadRequest) + return + } + + // Use client-requested grant types if provided and valid, otherwise use defaults + grantTypes := registrationRequest.GrantTypes + if len(grantTypes) == 0 { + grantTypes = []string{"authorization_code", "refresh_token"} + } + + // Use client-requested response types if provided and valid, otherwise use defaults + responseTypes := registrationRequest.ResponseTypes + if len(responseTypes) == 0 { + responseTypes = []string{"code"} + } + + // For Azure AD compatibility, use the configured client ID + // In a full RFC 7591 implementation, each registration would get a unique ID + // But since Azure AD requires pre-registered client IDs, we return the configured one + clientID := em.cfg.OAuthConfig.ClientID + + clientInfo := map[string]interface{}{ + "client_id": clientID, // Use configured Azure AD client ID + "client_id_issued_at": time.Now().Unix(), // RFC 7591: timestamp of issuance + "redirect_uris": registrationRequest.RedirectURIs, + "token_endpoint_auth_method": "none", // Public client (PKCE required) + "grant_types": grantTypes, + "response_types": responseTypes, + "client_name": registrationRequest.ClientName, + "client_uri": registrationRequest.ClientURI, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + + if err := json.NewEncoder(w).Encode(clientInfo); err != nil { + log.Printf("OAuth ERROR: Failed to encode client registration response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// validateClientRegistration validates a client registration request +func (em *EndpointManager) validateClientRegistration(req *ClientRegistrationRequest) error { + // Validate redirect URIs - require at least one + if len(req.RedirectURIs) == 0 { + return fmt.Errorf("at least one redirect_uri is required") + } + + // Basic URL validation for redirect URIs + for _, redirectURI := range req.RedirectURIs { + if _, err := url.Parse(redirectURI); err != nil { + return fmt.Errorf("invalid redirect_uri format: %s", redirectURI) + } + } + + // Validate grant types + validGrantTypes := map[string]bool{ + "authorization_code": true, + "refresh_token": true, + } + + for _, grantType := range req.GrantTypes { + if !validGrantTypes[grantType] { + return fmt.Errorf("unsupported grant_type: %s", grantType) + } + } + + // Validate response types + validResponseTypes := map[string]bool{ + "code": true, + } + + for _, responseType := range req.ResponseTypes { + if !validResponseTypes[responseType] { + return fmt.Errorf("unsupported response_type: %s", responseType) + } + } + + return nil +} + +// validateRedirectURI validates that a redirect URI is registered and allowed +func (em *EndpointManager) validateRedirectURI(redirectURI string) error { + if len(em.cfg.OAuthConfig.RedirectURIs) == 0 { + return fmt.Errorf("no redirect URIs configured") + } + + for _, allowed := range em.cfg.OAuthConfig.RedirectURIs { + if redirectURI == allowed { + return nil + } + } + + log.Printf("OAuth SECURITY WARNING: Invalid redirect URI attempted: %s, allowed: %v", + redirectURI, em.cfg.OAuthConfig.RedirectURIs) + return fmt.Errorf("redirect_uri not registered: %s", redirectURI) +} + +// tokenIntrospectionHandler implements RFC 7662 OAuth 2.0 Token Introspection +func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // This endpoint should be protected with client authentication + // For simplicity, we'll skip client auth in this implementation + + token := r.FormValue("token") + if token == "" { + em.writeErrorResponse(w, "invalid_request", "Missing token parameter", http.StatusBadRequest) + return + } + + // Validate the token + provider := em.provider + + tokenInfo, err := provider.ValidateToken(r.Context(), token) + if err != nil { + // Return inactive token response + response := map[string]interface{}{ + "active": false, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode introspection response: %v", err) + } + return + } + + // Return active token response + response := map[string]interface{}{ + "active": true, + "client_id": em.cfg.OAuthConfig.ClientID, + "scope": strings.Join(tokenInfo.Scope, " "), + "sub": tokenInfo.Subject, + "aud": tokenInfo.Audience, + "iss": tokenInfo.Issuer, + "exp": tokenInfo.ExpiresAt.Unix(), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// healthHandler provides a simple health check endpoint +func (em *EndpointManager) healthHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + response := map[string]interface{}{ + "status": "healthy", + "oauth": map[string]interface{}{ + "enabled": em.cfg.OAuthConfig.Enabled, + }, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// protectedResourceMetadataHandler handles OAuth 2.0 Protected Resource Metadata requests +func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for protected resource metadata endpoint", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Build resource URL based on the request + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + // Use the Host header from the request + host := r.Host + if host == "" { + host = r.URL.Host + } + + // Build the resource URL + resourceURL := fmt.Sprintf("%s://%s", scheme, host) + log.Printf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL) + + provider := em.provider + + metadata, err := provider.GetProtectedResourceMetadata(resourceURL) + if err != nil { + log.Printf("OAuth ERROR: Failed to get protected resource metadata: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Printf("OAuth DEBUG: Successfully generated protected resource metadata with %d authorization servers", len(metadata.AuthorizationServers)) + + w.Header().Set("Content-Type", "application/json") + em.setCacheHeaders(w) + + if err := json.NewEncoder(w).Encode(metadata); err != nil { + log.Printf("OAuth ERROR: Failed to encode protected resource metadata response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// writeErrorResponse writes an OAuth error response +func (em *EndpointManager) writeErrorResponse(w http.ResponseWriter, errorCode, description string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + response := map[string]interface{}{ + "error": errorCode, + "error_description": description, + } + + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode error response: %v", err) + } +} + +// authorizationProxyHandler proxies authorization requests to Azure AD with resource parameter filtering +func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for authorization endpoint, only GET allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse query parameters + query := r.URL.Query() + + // Validate redirect_uri parameter for security and better user experience + redirectURI := query.Get("redirect_uri") + if redirectURI == "" { + log.Printf("OAuth ERROR: Missing redirect_uri parameter in authorization request") + log.Printf("OAuth HELP: To fix this error, configure redirect URIs using --oauth-redirects flag") + log.Printf("OAuth HELP: For MCP Inspector, use: --oauth-redirects=\"http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback\"") + em.writeErrorResponse(w, "invalid_request", "redirect_uri parameter is required", http.StatusBadRequest) + return + } + + // Validate that the redirect_uri is registered and allowed + if err := em.validateRedirectURI(redirectURI); err != nil { + log.Printf("OAuth ERROR: redirect_uri %s not registered - requests will be blocked for security", redirectURI) + em.writeErrorResponse(w, "invalid_request", fmt.Sprintf("redirect_uri not registered: %s", redirectURI), http.StatusBadRequest) + return + } + + // Enforce PKCE for OAuth 2.1 compliance (MCP requirement) + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + if codeChallenge == "" { + log.Printf("OAuth ERROR: Missing PKCE code_challenge parameter (required for OAuth 2.1)") + em.writeErrorResponse(w, "invalid_request", "PKCE code_challenge is required", http.StatusBadRequest) + return + } + + if codeChallengeMethod == "" { + // Default to S256 if not specified + query.Set("code_challenge_method", "S256") + log.Printf("OAuth DEBUG: Setting default code_challenge_method to S256") + } else if codeChallengeMethod != "S256" { + log.Printf("OAuth ERROR: Unsupported code_challenge_method: %s (only S256 supported)", codeChallengeMethod) + em.writeErrorResponse(w, "invalid_request", "Only S256 code_challenge_method is supported", http.StatusBadRequest) + return + } + + // Resource parameter handling for MCP compliance + // requestedScopes := strings.Split(query.Get("scope"), " ") + + // Azure AD v2.0 doesn't support RFC 8707 Resource Indicators in authorization requests + // Remove the resource parameter if present for Azure AD compatibility + resourceParam := query.Get("resource") + if resourceParam != "" { + log.Printf("OAuth DEBUG: Removing resource parameter for Azure AD compatibility: %s", resourceParam) + query.Del("resource") + } + + // Use only server-required scopes for Azure AD compatibility + // Azure AD .default scopes cannot be mixed with OpenID Connect scopes + // We prioritize Azure Management API access over OpenID Connect user info + finalScopes := em.cfg.OAuthConfig.RequiredScopes + + finalScopeString := strings.Join(finalScopes, " ") + query.Set("scope", finalScopeString) + log.Printf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString) + + // Build the Azure AD authorization URL + azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.cfg.OAuthConfig.TenantID) + + // Create the redirect URL with filtered parameters + redirectURL := fmt.Sprintf("%s?%s", azureAuthURL, query.Encode()) + log.Printf("OAuth DEBUG: Redirecting to Azure AD authorization endpoint: %s", azureAuthURL) + + // Redirect to Azure AD + http.Redirect(w, r, redirectURL, http.StatusFound) + } +} + +// callbackHandler handles OAuth 2.0 Authorization Code flow callback +func (em *EndpointManager) callbackHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for callback endpoint, only GET allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse query parameters + query := r.URL.Query() + + // Check for error response from authorization server + if authError := query.Get("error"); authError != "" { + errorDesc := query.Get("error_description") + log.Printf("OAuth ERROR: Authorization server returned error: %s - %s", authError, errorDesc) + em.writeCallbackErrorResponse(w, fmt.Sprintf("Authorization failed: %s - %s", authError, errorDesc)) + return + } + + // Get authorization code + code := query.Get("code") + if code == "" { + log.Printf("OAuth ERROR: Missing authorization code in callback") + em.writeCallbackErrorResponse(w, "Missing authorization code") + return + } + + // Get state parameter for CSRF protection + state := query.Get("state") + if state == "" { + log.Printf("OAuth ERROR: Missing state parameter in callback") + em.writeCallbackErrorResponse(w, "Missing state parameter") + return + } + + log.Printf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state) + + // Validate redirect URI for security - construct expected URI and validate it + expectedRedirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port) + if err := em.validateRedirectURI(expectedRedirectURI); err != nil { + log.Printf("OAuth ERROR: Redirect URI validation failed: %v", err) + em.writeCallbackErrorResponse(w, "Invalid redirect URI") + return + } + + // Exchange authorization code for access token + tokenResponse, err := em.exchangeCodeForToken(code, state) + if err != nil { + log.Printf("OAuth ERROR: Failed to exchange authorization code for token: %v", err) + em.writeCallbackErrorResponse(w, fmt.Sprintf("Failed to exchange code for token: %v", err)) + return + } + + // Skip token validation in callback - validation happens during MCP requests + // Create minimal token info for callback success page + tokenInfo := &auth.TokenInfo{ + AccessToken: tokenResponse.AccessToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration + Scope: em.cfg.OAuthConfig.RequiredScopes, // Use configured scopes + Subject: "authenticated_user", // Placeholder + Audience: []string{fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID)}, + Issuer: fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID), + Claims: make(map[string]interface{}), + } + + // Return success response with token information + em.writeCallbackSuccessResponse(w, tokenResponse, tokenInfo) + } +} + +// TokenResponse represents the response from token exchange +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// exchangeCodeForToken exchanges authorization code for access token +func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenResponse, error) { + // Prepare token exchange request + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID) + + // Validate URL for security + if err := validateAzureADURL(tokenURL); err != nil { + return nil, fmt.Errorf("invalid token URL: %w", err) + } + + // Use default callback redirect URI for token exchange + redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port) + + // Prepare form data + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("client_id", em.cfg.OAuthConfig.ClientID) + data.Set("code", code) + data.Set("redirect_uri", redirectURI) + data.Set("scope", strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ")) + + // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests + // It uses scope-based resource identification instead + // For MCP compliance, we handle resource binding through audience validation + + // Make token exchange request + resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse token response + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResponse, nil +} + +// writeCallbackErrorResponse writes an error response for callback +func (em *EndpointManager) writeCallbackErrorResponse(w http.ResponseWriter, message string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + + html := fmt.Sprintf(` + + + + OAuth Authentication Error + + + +
+

Authentication Error

+

%s

+

Please try again or contact your administrator.

+
+ +`, message) + + if _, err := w.Write([]byte(html)); err != nil { + log.Printf("Failed to write error response: %v", err) + } +} + +// writeCallbackSuccessResponse writes a success response for callback +func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, tokenResponse *TokenResponse, tokenInfo *auth.TokenInfo) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + + // Generate a secure session token for the client to use + _, err := em.generateSessionToken() + if err != nil { + em.writeCallbackErrorResponse(w, "Failed to generate session token") + return + } + + html := fmt.Sprintf(` + + + + OAuth Authentication Success + + + +
+

Authentication Successful

+

You have been successfully authenticated with Azure AD.

+ +
+

Access Token (use as Bearer token):

+
%s
+ +
+ +
+

Token Information:

+ +
+ +
+

For MCP Client Usage:

+

Use this token in the Authorization header:

+
Authorization: Bearer %s
+ +
+
+ + + +`, + tokenResponse.AccessToken, + tokenInfo.Subject, + strings.Join(tokenInfo.Audience, ", "), + strings.Join(tokenInfo.Scope, ", "), + tokenInfo.ExpiresAt.Format("2006-01-02 15:04:05 UTC"), + tokenResponse.AccessToken, + tokenResponse.AccessToken) + + if _, err := w.Write([]byte(html)); err != nil { + log.Printf("Failed to write success response: %v", err) + } +} + +// isValidClientID validates if a client ID is acceptable +func (em *EndpointManager) isValidClientID(clientID string) bool { + // Accept configured client ID (primary method for Azure AD) + if clientID == em.cfg.OAuthConfig.ClientID { + return true + } + + // For future extensibility, could accept other registered client IDs + // But for Azure AD integration, we primarily use the configured client ID + + return false +} + +// generateSessionToken generates a secure random session token +func (em *EndpointManager) generateSessionToken() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(bytes), nil +} + +// tokenHandler handles OAuth 2.0 token endpoint requests (Authorization Code exchange) +func (em *EndpointManager) tokenHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w, r) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + log.Printf("OAuth ERROR: Invalid method %s for token endpoint, only POST allowed", r.Method) + em.writeErrorResponse(w, "invalid_request", "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + log.Printf("OAuth ERROR: Failed to parse form data: %v", err) + em.writeErrorResponse(w, "invalid_request", "Failed to parse form data", http.StatusBadRequest) + return + } + + // Validate grant type + grantType := r.FormValue("grant_type") + if grantType != "authorization_code" { + log.Printf("OAuth ERROR: Unsupported grant type: %s", grantType) + em.writeErrorResponse(w, "unsupported_grant_type", fmt.Sprintf("Unsupported grant type: %s", grantType), http.StatusBadRequest) + return + } + + // Extract required parameters + code := r.FormValue("code") + clientID := r.FormValue("client_id") + redirectURI := r.FormValue("redirect_uri") + codeVerifier := r.FormValue("code_verifier") // PKCE parameter + + if code == "" { + log.Printf("OAuth ERROR: Missing authorization code in token request") + em.writeErrorResponse(w, "invalid_request", "Missing authorization code", http.StatusBadRequest) + return + } + + if clientID == "" { + log.Printf("OAuth ERROR: Missing client_id in token request") + em.writeErrorResponse(w, "invalid_request", "Missing client_id", http.StatusBadRequest) + return + } + + if redirectURI == "" { + log.Printf("OAuth ERROR: Missing redirect_uri in token request") + em.writeErrorResponse(w, "invalid_request", "Missing redirect_uri", http.StatusBadRequest) + return + } + + // Enforce PKCE code_verifier for OAuth 2.1 compliance + if codeVerifier == "" { + log.Printf("OAuth ERROR: Missing PKCE code_verifier (required for OAuth 2.1)") + em.writeErrorResponse(w, "invalid_request", "PKCE code_verifier is required", http.StatusBadRequest) + return + } + + // Validate client ID (accept both configured and dynamically registered clients) + if !em.isValidClientID(clientID) { + log.Printf("OAuth ERROR: Invalid client_id: %s", clientID) + em.writeErrorResponse(w, "invalid_client", "Invalid client_id", http.StatusBadRequest) + return + } + + // Validate redirect URI for security + if err := em.validateRedirectURI(redirectURI); err != nil { + log.Printf("OAuth ERROR: Redirect URI validation failed in token endpoint: %v", err) + em.writeErrorResponse(w, "invalid_request", "Invalid redirect_uri", http.StatusBadRequest) + return + } + + // Extract scope from the token request (MCP client should send the same scope) + requestedScope := r.FormValue("scope") + if requestedScope == "" { + // Fallback to server required scopes if not provided + requestedScope = strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ") + } + + log.Printf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope) + + // Exchange authorization code for access token with Azure AD + tokenResponse, err := em.exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, requestedScope) + if err != nil { + log.Printf("OAuth ERROR: Token exchange with Azure AD failed: %v", err) + em.writeErrorResponse(w, "invalid_grant", fmt.Sprintf("Authorization code exchange failed: %v", err), http.StatusBadRequest) + return + } + + // Return token response + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + if err := json.NewEncoder(w).Encode(tokenResponse); err != nil { + log.Printf("OAuth ERROR: Failed to encode token response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// exchangeCodeForTokenDirect exchanges authorization code for access token directly with Azure AD +func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, scope string) (*TokenResponse, error) { + // Prepare token exchange request to Azure AD + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID) + + // Validate URL for security + if err := validateAzureADURL(tokenURL); err != nil { + return nil, fmt.Errorf("invalid token URL: %w", err) + } + + // Prepare form data + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("client_id", em.cfg.OAuthConfig.ClientID) + data.Set("code", code) + data.Set("redirect_uri", redirectURI) + data.Set("scope", scope) // Use the scope provided by the client + + // Add PKCE code_verifier if present + if codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + log.Printf("Including PKCE code_verifier in Azure AD token request") + } else { + log.Printf("No PKCE code_verifier provided - this may cause PKCE verification to fail") + } + + // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests + // It uses scope-based resource identification instead + // For MCP compliance, we handle resource binding through audience validation + log.Printf("Azure AD token request with scope: %s", scope) + + // Make token exchange request to Azure AD + resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse token response + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + log.Printf("Token exchange successful: access_token received (length: %d)", len(tokenResponse.AccessToken)) + + return &tokenResponse, nil +} diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go new file mode 100644 index 0000000..f457d5c --- /dev/null +++ b/internal/auth/oauth/endpoints_test.go @@ -0,0 +1,601 @@ +package oauth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/config" +) + +// createTestConfig creates a test ConfigData with OAuth configuration +func createTestConfig() *config.ConfigData { + cfg := config.NewConfig() + cfg.Host = "127.0.0.1" + cfg.Port = 8000 + cfg.OAuthConfig = &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + RedirectURIs: []string{"http://127.0.0.1:8000/oauth/callback", "http://localhost:8000/oauth/callback"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + }, + } + return cfg +} + +func TestEndpointManager_RegisterEndpoints(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + mux := http.NewServeMux() + manager.RegisterEndpoints(mux) + + // Test that endpoints are registered by making requests + testCases := []struct { + method string + path string + status int + }{ + {"GET", "/.well-known/oauth-protected-resource", http.StatusOK}, + {"GET", "/.well-known/oauth-authorization-server", http.StatusInternalServerError}, // Will fail without real Azure AD + {"POST", "/oauth/register", http.StatusBadRequest}, // Missing required data + {"POST", "/oauth/introspect", http.StatusBadRequest}, // Missing token param + {"GET", "/oauth/callback", http.StatusBadRequest}, // Missing required params + {"GET", "/health", http.StatusOK}, + } + + for _, tc := range testCases { + t.Run(tc.method+" "+tc.path, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + if w.Code != tc.status { + t.Errorf("Expected status %d for %s %s, got %d", tc.status, tc.method, tc.path, w.Code) + } + }) + } +} + +func TestProtectedResourceMetadataEndpoint(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) + w := httptest.NewRecorder() + + handler := manager.protectedResourceMetadataHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var metadata ProtectedResourceMetadata + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + expectedAuthServer := "http://example.com" + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer { + t.Errorf("Expected auth server %s, got %v", expectedAuthServer, metadata.AuthorizationServers) + } + + if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" { + t.Errorf("Expected scopes %v, got %v", cfg.OAuthConfig.RequiredScopes, metadata.ScopesSupported) + } +} + +func TestClientRegistrationEndpoint(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test valid registration request + registrationRequest := map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "token_endpoint_auth_method": "none", + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + "scope": "https://management.azure.com/.default", + "client_name": "Test Client", + } + + reqBody, _ := json.Marshal(registrationRequest) + req := httptest.NewRequest("POST", "/oauth/register", strings.NewReader(string(reqBody))) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler := manager.clientRegistrationHandler() + handler(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if response["client_id"] == "" { + t.Error("Expected client_id in response") + } + + redirectURIs, ok := response["redirect_uris"].([]interface{}) + if !ok || len(redirectURIs) != 1 { + t.Errorf("Expected redirect URIs in response") + } +} + +func TestTokenIntrospectionEndpoint(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test with valid token (since JWT validation is disabled, any token works) + // Note: Must use a token that looks like a JWT (has dots) to pass initial format checks + req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("token=header.payload.signature")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + handler := manager.tokenIntrospectionHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if active, ok := response["active"].(bool); !ok || !active { + t.Error("Expected active token") + } +} + +func TestTokenIntrospectionEndpointMissingToken(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test without token parameter + req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + handler := manager.tokenIntrospectionHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing token, got %d", w.Code) + } +} + +func TestHealthEndpoint(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + + handler := manager.healthHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if response["status"] != "healthy" { + t.Errorf("Expected status healthy, got %v", response["status"]) + } + + oauth, ok := response["oauth"].(map[string]interface{}) + if !ok { + t.Error("Expected oauth object in response") + } + + if oauth["enabled"] != true { + t.Errorf("Expected oauth enabled true, got %v", oauth["enabled"]) + } +} + +func TestValidateClientRegistration(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + request map[string]interface{} + wantErr bool + }{ + { + name: "valid request", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + }, + wantErr: false, + }, + { + name: "missing redirect URIs", + request: map[string]interface{}{ + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + }, + wantErr: true, + }, + { + name: "invalid grant type", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"client_credentials"}, + "response_types": []string{"code"}, + }, + wantErr: true, + }, + { + name: "invalid response type", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"token"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Convert test request to the expected struct format + req := &ClientRegistrationRequest{} + + if redirectURIs, ok := tt.request["redirect_uris"].([]string); ok { + req.RedirectURIs = redirectURIs + } + if grantTypes, ok := tt.request["grant_types"].([]string); ok { + req.GrantTypes = grantTypes + } + if responseTypes, ok := tt.request["response_types"].([]string); ok { + req.ResponseTypes = responseTypes + } + + err := manager.validateClientRegistration(req) + if (err != nil) != tt.wantErr { + t.Errorf("validateClientRegistration() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCallbackEndpointMissingCode(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test callback without authorization code + req := httptest.NewRequest("GET", "/oauth/callback?state=test-state", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing code, got %d", w.Code) + } + + // Check that response contains HTML error page + contentType := w.Header().Get("Content-Type") + if !strings.Contains(contentType, "text/html") { + t.Errorf("Expected HTML content type, got %s", contentType) + } + + body := w.Body.String() + if !strings.Contains(body, "Missing authorization code") { + t.Error("Expected error message about missing authorization code") + } +} + +func TestCallbackEndpointMissingState(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test callback without state parameter + req := httptest.NewRequest("GET", "/oauth/callback?code=test-code", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing state, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "Missing state parameter") { + t.Error("Expected error message about missing state parameter") + } +} + +func TestCallbackEndpointAuthError(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test callback with authorization error + req := httptest.NewRequest("GET", "/oauth/callback?error=access_denied&error_description=User%20denied%20access", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for auth error, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "Authorization failed") { + t.Error("Expected error message about authorization failure") + } + if !strings.Contains(body, "access_denied") { + t.Error("Expected specific error code in response") + } +} + +func TestCallbackEndpointMethodNotAllowed(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + // Test callback with POST method (should only accept GET) + req := httptest.NewRequest("POST", "/oauth/callback", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405 for POST method, got %d", w.Code) + } +} + +func TestValidateRedirectURI(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + redirectURI string + wantErr bool + }{ + { + name: "valid redirect URI - 127.0.0.1", + redirectURI: "http://127.0.0.1:8000/oauth/callback", + wantErr: false, + }, + { + name: "valid redirect URI - localhost", + redirectURI: "http://localhost:8000/oauth/callback", + wantErr: false, + }, + { + name: "invalid redirect URI - wrong port", + redirectURI: "http://127.0.0.1:9000/oauth/callback", + wantErr: true, + }, + { + name: "invalid redirect URI - wrong path", + redirectURI: "http://127.0.0.1:8000/oauth/malicious", + wantErr: true, + }, + { + name: "invalid redirect URI - external domain", + redirectURI: "http://malicious.com:8000/oauth/callback", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := manager.validateRedirectURI(tt.redirectURI) + if (err != nil) != tt.wantErr { + t.Errorf("validateRedirectURI() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + + // Test with empty redirect URIs configuration + cfgEmpty := createTestConfig() + cfgEmpty.OAuthConfig.RedirectURIs = []string{} + managerEmpty := NewEndpointManager(provider, cfgEmpty) + + err := managerEmpty.validateRedirectURI("http://127.0.0.1:8000/oauth/callback") + if err == nil { + t.Error("Expected error when no redirect URIs are configured") + } +} + +// TestAuthorizationProxyRedirectURIValidation tests the authorization endpoint redirect URI validation +func TestCORSHeaders(t *testing.T) { + cfg := createTestConfig() + cfg.OAuthConfig.AllowedOrigins = []string{"http://localhost:6274"} + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + origin string + expectCORSSet bool + expectOrigin string + }{ + { + name: "allowed origin", + origin: "http://localhost:6274", + expectCORSSet: true, + expectOrigin: "http://localhost:6274", + }, + { + name: "disallowed origin", + origin: "http://malicious.com", + expectCORSSet: false, + expectOrigin: "", + }, + { + name: "no origin header", + origin: "", + expectCORSSet: false, + expectOrigin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/health", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + w := httptest.NewRecorder() + + handler := manager.healthHandler() + handler(w, req) + + corsOrigin := w.Header().Get("Access-Control-Allow-Origin") + if tt.expectCORSSet { + if corsOrigin != tt.expectOrigin { + t.Errorf("Expected CORS origin %s, got %s", tt.expectOrigin, corsOrigin) + } + } else { + if corsOrigin != "" { + t.Errorf("Expected no CORS headers, but got Access-Control-Allow-Origin: %s", corsOrigin) + } + } + }) + } +} + +func TestAuthorizationProxyRedirectURIValidation(t *testing.T) { + cfg := createTestConfig() + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + redirectURI string + expectError bool + expectCode int + }{ + { + name: "missing redirect_uri", + redirectURI: "", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "valid redirect_uri - 127.0.0.1", + redirectURI: "http://127.0.0.1:8000/oauth/callback", + expectError: false, + expectCode: http.StatusFound, // Should redirect to Azure AD + }, + { + name: "valid redirect_uri - localhost", + redirectURI: "http://localhost:8000/oauth/callback", + expectError: false, + expectCode: http.StatusFound, // Should redirect to Azure AD + }, + { + name: "invalid redirect_uri - wrong port", + redirectURI: "http://127.0.0.1:9000/oauth/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "invalid redirect_uri - wrong path", + redirectURI: "http://127.0.0.1:8000/oauth/malicious", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "invalid redirect_uri - external domain", + redirectURI: "http://malicious.com:8000/oauth/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create request URL with redirect_uri parameter if provided + requestURL := "/oauth2/v2.0/authorize?response_type=code&client_id=test-client&code_challenge=test&code_challenge_method=S256&state=test" + if tt.redirectURI != "" { + requestURL += "&redirect_uri=" + tt.redirectURI + } + + req := httptest.NewRequest("GET", requestURL, nil) + w := httptest.NewRecorder() + + handler := manager.authorizationProxyHandler() + handler(w, req) + + if tt.expectError { + if w.Code != tt.expectCode { + t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code) + } + + // Check that error response contains helpful information + body := w.Body.String() + if !strings.Contains(body, "redirect_uri") { + t.Errorf("Error response should mention redirect_uri, got: %s", body) + } + } else { + if w.Code != tt.expectCode { + t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code) + } + + // For successful cases, check redirect location contains expected parameters + location := w.Header().Get("Location") + if location == "" { + t.Errorf("Expected redirect location header, got empty") + } + if !strings.Contains(location, "login.microsoftonline.com") { + t.Errorf("Expected redirect to Azure AD, got: %s", location) + } + } + }) + } +} diff --git a/internal/auth/oauth/middleware.go b/internal/auth/oauth/middleware.go new file mode 100644 index 0000000..5170385 --- /dev/null +++ b/internal/auth/oauth/middleware.go @@ -0,0 +1,299 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + + "github.com/Azure/aks-mcp/internal/auth" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const tokenInfoKey contextKey = "token_info" + +// AuthMiddleware handles OAuth authentication for HTTP requests +type AuthMiddleware struct { + provider *AzureOAuthProvider + serverURL string +} + +// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting +func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + requestOrigin := r.Header.Get("Origin") + + // Check if the request origin is in the allowed list + var allowedOrigin string + for _, allowed := range m.provider.config.AllowedOrigins { + if requestOrigin == allowed { + allowedOrigin = requestOrigin + break + } + } + + // Only set CORS headers if origin is allowed + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") + } else if requestOrigin != "" { + log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + } +} + +// NewAuthMiddleware creates a new authentication middleware +func NewAuthMiddleware(provider *AzureOAuthProvider, serverURL string) *AuthMiddleware { + return &AuthMiddleware{ + provider: provider, + serverURL: serverURL, + } +} + +// Middleware returns an HTTP middleware function for OAuth authentication +func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // Skip authentication for specific endpoints + if m.shouldSkipAuth(r) { + log.Printf("Skipping auth for path: %s\n", r.URL.Path) + next.ServeHTTP(w, r) + return + } + + // Perform authentication + authResult := m.authenticateRequest(r) + + if !authResult.Authenticated { + log.Printf("Authentication FAILED - handling error\n") + m.handleAuthError(w, r, authResult) + return + } + + // Add token info to request context + ctx := context.WithValue(r.Context(), tokenInfoKey, authResult.TokenInfo) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} + +// shouldSkipAuth determines if authentication should be skipped for this request +func (m *AuthMiddleware) shouldSkipAuth(r *http.Request) bool { + // Skip auth for OAuth metadata endpoints + path := r.URL.Path + + skipPaths := []string{ + "/.well-known/oauth-protected-resource", + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", + "/oauth2/v2.0/authorize", + "/oauth/register", + "/oauth/callback", + "/oauth2/v2.0/token", + "/oauth/introspect", + "/health", + "/ping", + } + + for _, skipPath := range skipPaths { + if path == skipPath { + return true + } + } + + return false +} + +// authenticateRequest performs OAuth authentication on the request +func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { + // Extract Bearer token from Authorization header + authHeader := r.Header.Get("Authorization") + + if authHeader == "" { + log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path) + log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header) + return &auth.AuthResult{ + Authenticated: false, + Error: "missing authorization header", + StatusCode: http.StatusUnauthorized, + } + } + + // Check for Bearer token format + const bearerPrefix = "Bearer " + if !strings.HasPrefix(authHeader, bearerPrefix) { + log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n") + return &auth.AuthResult{ + Authenticated: false, + Error: "invalid authorization header format", + StatusCode: http.StatusUnauthorized, + } + } + + token := strings.TrimPrefix(authHeader, bearerPrefix) + if token == "" { + log.Printf("FAILED - Empty bearer token\n") + return &auth.AuthResult{ + Authenticated: false, + Error: "empty bearer token", + StatusCode: http.StatusUnauthorized, + } + } + + // Basic JWT structure validation + tokenParts := strings.Split(token, ".") + if len(tokenParts) != 3 { + log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts)) + return &auth.AuthResult{ + Authenticated: false, + Error: "invalid JWT structure", + StatusCode: http.StatusUnauthorized, + } + } + + // Validate the token + tokenInfo, err := m.provider.ValidateToken(r.Context(), token) + if err != nil { + log.Printf("FAILED - Provider token validation failed: %v\n", err) + return &auth.AuthResult{ + Authenticated: false, + Error: fmt.Sprintf("token validation failed: %v", err), + StatusCode: http.StatusUnauthorized, + } + } + + // Validate required scopes - strict enforcement for security + if !m.validateScopes(tokenInfo.Scope) { + log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes) + return &auth.AuthResult{ + Authenticated: false, + Error: "insufficient scope", + StatusCode: http.StatusForbidden, + } + } + + return &auth.AuthResult{ + Authenticated: true, + TokenInfo: tokenInfo, + StatusCode: http.StatusOK, + } +} + +// validateScopes checks if the token has required scopes +func (m *AuthMiddleware) validateScopes(tokenScopes []string) bool { + requiredScopes := m.provider.config.RequiredScopes + if len(requiredScopes) == 0 { + return true // No scopes required + } + + // Check if token has at least one required scope + for _, required := range requiredScopes { + if m.hasScopePermission(required, tokenScopes) { + return true + } + } + + return false +} + +// hasScopePermission checks if the token scopes satisfy the required scope +func (m *AuthMiddleware) hasScopePermission(requiredScope string, tokenScopes []string) bool { + // Direct scope match + for _, tokenScope := range tokenScopes { + if tokenScope == requiredScope { + return true + } + } + + // Azure resource scope mapping + azureResourceMappings := map[string][]string{ + "https://management.azure.com/.default": { + "user_impersonation", + "https://management.azure.com/user_impersonation", + "https://management.azure.com/.default", + "https://management.core.windows.net/", + "https://management.azure.com/", + }, + "https://graph.microsoft.com/.default": { + "User.Read", + "https://graph.microsoft.com/User.Read", + }, + } + + if allowedScopes, exists := azureResourceMappings[requiredScope]; exists { + for _, allowedScope := range allowedScopes { + for _, tokenScope := range tokenScopes { + if tokenScope == allowedScope { + return true + } + } + } + } + + return false +} + +// handleAuthError handles authentication errors +func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, authResult *auth.AuthResult) { + // Set CORS headers + m.setCORSHeaders(w, r) + w.Header().Set("Content-Type", "application/json") + + // Add WWW-Authenticate header for 401 responses (RFC 9728 Section 5.1) + if authResult.StatusCode == http.StatusUnauthorized { + // Build the resource metadata URL + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + host := r.Host + if host == "" { + host = r.URL.Host + } + serverURL := fmt.Sprintf("%s://%s", scheme, host) + resourceMetadataURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", serverURL) + + // RFC 9728 compliant WWW-Authenticate header + wwwAuth := fmt.Sprintf(`Bearer realm="%s", resource_metadata="%s"`, serverURL, resourceMetadataURL) + + // Add error information if available + if authResult.Error != "" { + wwwAuth += fmt.Sprintf(`, error="invalid_token", error_description="%s"`, authResult.Error) + } + + w.Header().Set("WWW-Authenticate", wwwAuth) + } + + w.WriteHeader(authResult.StatusCode) + + errorResponse := map[string]interface{}{ + "error": getOAuthErrorCode(authResult.StatusCode), + "error_description": authResult.Error, + } + + if err := json.NewEncoder(w).Encode(errorResponse); err != nil { + log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err) + } else { + log.Printf("MIDDLEWARE ERROR: Error response sent\n") + } +} + +// getOAuthErrorCode returns appropriate OAuth error code for HTTP status +func getOAuthErrorCode(statusCode int) string { + switch statusCode { + case http.StatusUnauthorized: + return "invalid_token" + case http.StatusForbidden: + return "insufficient_scope" + case http.StatusBadRequest: + return "invalid_request" + default: + return "server_error" + } +} diff --git a/internal/auth/oauth/middleware_test.go b/internal/auth/oauth/middleware_test.go new file mode 100644 index 0000000..358cb1e --- /dev/null +++ b/internal/auth/oauth/middleware_test.go @@ -0,0 +1,253 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Azure/aks-mcp/internal/auth" +) + +// GetTokenInfo extracts token information from request context (test helper) +func GetTokenInfo(r *http.Request) (*auth.TokenInfo, bool) { + tokenInfo, ok := r.Context().Value(tokenInfoKey).(*auth.TokenInfo) + return tokenInfo, ok +} + +func TestAuthMiddleware(t *testing.T) { + // Create test config with minimal required scopes for testing + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. This is intentional + // to prevent security misconfigurations in production environments. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + // Create a test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("success")); err != nil { + t.Errorf("Failed to write test response: %v", err) + } + }) + + wrappedHandler := middleware.Middleware(testHandler) + + tests := []struct { + name string + authHeader string + expectedStatus int + path string + }{ + { + name: "valid bearer token", + authHeader: "Bearer header.payload.signature", + expectedStatus: http.StatusOK, + path: "/test", + }, + { + name: "missing authorization header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "invalid token format", + authHeader: "InvalidFormat", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "non-bearer token", + authHeader: "Basic dXNlcjpwYXNz", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "skip auth for metadata endpoint", + authHeader: "", + expectedStatus: http.StatusOK, + path: "/.well-known/oauth-protected-resource", + }, + { + name: "skip auth for health endpoint", + authHeader: "", + expectedStatus: http.StatusOK, + path: "/health", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + // Check WWW-Authenticate header for 401 responses + if w.Code == http.StatusUnauthorized { + wwwAuth := w.Header().Get("WWW-Authenticate") + if wwwAuth == "" { + t.Error("Expected WWW-Authenticate header for 401 response") + } + } + }) + } +} + +func TestAuthMiddlewareContextPropagation(t *testing.T) { + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if token info is available in context + tokenInfo, ok := GetTokenInfo(r) + if !ok { + t.Error("Token info not found in context") + return + } + + if tokenInfo.AccessToken != "header.payload.signature" { + t.Errorf("Expected token header.payload.signature, got %s", tokenInfo.AccessToken) + } + + w.WriteHeader(http.StatusOK) + }) + + wrappedHandler := middleware.Middleware(testHandler) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer header.payload.signature") + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } +} + +func TestShouldSkipAuth(t *testing.T) { + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, // Minimal scope for testing + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + tests := []struct { + path string + expected bool + }{ + {"/.well-known/oauth-protected-resource", true}, + {"/.well-known/oauth-authorization-server", true}, + {"/.well-known/openid-configuration", true}, + {"/oauth2/v2.0/authorize", true}, + {"/oauth/register", true}, + {"/oauth/callback", true}, + {"/oauth2/v2.0/token", true}, + {"/oauth/introspect", true}, + {"/health", true}, + {"/ping", true}, + {"/test", false}, + {"/mcp", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + result := middleware.shouldSkipAuth(req) + if result != tt.expected { + t.Errorf("Expected %v for path %s, got %v", tt.expected, tt.path, result) + } + }) + } +} + +func TestGetTokenInfo(t *testing.T) { + // Test with valid token info + tokenInfo := &auth.TokenInfo{ + AccessToken: "test-token", + TokenType: "Bearer", + Subject: "user123", + } + + ctx := context.WithValue(context.Background(), tokenInfoKey, tokenInfo) + req := httptest.NewRequest("GET", "/test", nil) + req = req.WithContext(ctx) + + retrievedTokenInfo, ok := GetTokenInfo(req) + if !ok { + t.Error("Expected to find token info in context") + } + + if retrievedTokenInfo.AccessToken != "test-token" { + t.Errorf("Expected access token test-token, got %s", retrievedTokenInfo.AccessToken) + } + + // Test without token info + req = httptest.NewRequest("GET", "/test", nil) + _, ok = GetTokenInfo(req) + if ok { + t.Error("Expected not to find token info in empty context") + } +} diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go new file mode 100644 index 0000000..d6ec9bc --- /dev/null +++ b/internal/auth/oauth/provider.go @@ -0,0 +1,523 @@ +package oauth + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "math/big" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Azure/aks-mcp/internal/auth" + internalConfig "github.com/Azure/aks-mcp/internal/config" + "github.com/golang-jwt/jwt/v5" +) + +// AzureOAuthProvider implements OAuth authentication for Azure AD +type AzureOAuthProvider struct { + config *auth.OAuthConfig + httpClient *http.Client + keyCache *keyCache + enableCache bool +} + +// keyCache caches Azure AD signing keys +type keyCache struct { + keys map[string]*rsa.PublicKey + expiresAt time.Time + mu sync.RWMutex +} + +// AzureADMetadata represents Azure AD OAuth metadata +type AzureADMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + JWKSUri string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` +} + +// ProtectedResourceMetadata represents MCP protected resource metadata (RFC 9728 compliant) +type ProtectedResourceMetadata struct { + AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + ScopesSupported []string `json:"scopes_supported"` +} + +// ClientRegistrationRequest represents OAuth 2.0 Dynamic Client Registration request (RFC 7591) +type ClientRegistrationRequest struct { + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientName string `json:"client_name"` + ClientURI string `json:"client_uri"` + Scope string `json:"scope"` +} + +// NewAzureOAuthProvider creates a new Azure OAuth provider +func NewAzureOAuthProvider(config *auth.OAuthConfig) (*AzureOAuthProvider, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid OAuth config: %w", err) + } + + return &AzureOAuthProvider{ + config: config, + enableCache: internalConfig.EnableCache, // Use config constant for cache control + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + keyCache: &keyCache{ + keys: make(map[string]*rsa.PublicKey), + }, + }, nil +} + +// GetProtectedResourceMetadata returns OAuth 2.0 Protected Resource Metadata (RFC 9728) +func (p *AzureOAuthProvider) GetProtectedResourceMetadata(serverURL string) (*ProtectedResourceMetadata, error) { + // For MCP compliance, point to our local authorization server proxy + // which properly advertises PKCE support + parsedURL, err := url.Parse(serverURL) + if err != nil { + return nil, fmt.Errorf("invalid server URL: %v", err) + } + + // Use the same scheme and host as the server URL + authServerURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + + // RFC 9728 requires the resource field to identify this MCP server + return &ProtectedResourceMetadata{ + AuthorizationServers: []string{authServerURL}, + Resource: serverURL, // Required by MCP spec + ScopesSupported: p.config.RequiredScopes, + }, nil +} + +// GetAuthorizationServerMetadata returns OAuth 2.0 Authorization Server Metadata (RFC 8414) +func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*AzureADMetadata, error) { + metadataURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", p.config.TenantID) + log.Printf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL) + + resp, err := p.httpClient.Get(metadataURL) + if err != nil { + log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err) + return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode == http.StatusNotFound { + log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID) + return nil, fmt.Errorf("tenant ID '%s' not found (HTTP 404). Please verify your Azure AD tenant ID is correct", p.config.TenantID) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Printf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("OAuth ERROR: Failed to read metadata response: %v", err) + return nil, fmt.Errorf("failed to read metadata response: %w", err) + } + + var metadata AzureADMetadata + if err := json.Unmarshal(body, &metadata); err != nil { + log.Printf("OAuth ERROR: Failed to parse metadata JSON: %v", err) + return nil, fmt.Errorf("failed to parse metadata: %w", err) + } + + log.Printf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported) + + // Ensure grant_types_supported is populated for MCP Inspector compatibility + if len(metadata.GrantTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)") + metadata.GrantTypesSupported = []string{"authorization_code", "refresh_token"} + } + + // Ensure response_types_supported is populated for MCP Inspector compatibility + if len(metadata.ResponseTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)") + metadata.ResponseTypesSupported = []string{"code"} + } + + // Ensure subject_types_supported is populated for MCP Inspector compatibility + if len(metadata.SubjectTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)") + metadata.SubjectTypesSupported = []string{"public"} + } + + // Ensure token_endpoint_auth_methods_supported is populated for MCP Inspector compatibility + if len(metadata.TokenEndpointAuthMethodsSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)") + metadata.TokenEndpointAuthMethodsSupported = []string{"none"} + } + + // Add S256 code challenge method support (Azure AD supports this but may not advertise it) + // MCP specification requires S256 support, so we always ensure it's present + log.Printf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)") + metadata.CodeChallengeMethodsSupported = []string{"S256"} + + // Azure AD v2.0 has limited support for RFC 8707 Resource Indicators + // - Authorization endpoint: doesn't support resource parameter + // - Token endpoint: doesn't support resource parameter + // - Uses scope-based resource identification instead + // Our proxy handles MCP resource parameter translation + parsedURL, err := url.Parse(serverURL) + if err == nil { + // If the server URL includes /mcp path, include it in the proxy endpoint + proxyPath := "/oauth2/v2.0/authorize" + tokenPath := "/oauth2/v2.0/token" // #nosec G101 -- This is an OAuth endpoint path, not credentials + registrationPath := "/oauth/register" + proxyAuthURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, proxyPath) + tokenURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, tokenPath) + registrationURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, registrationPath) + + metadata.AuthorizationEndpoint = proxyAuthURL + metadata.TokenEndpoint = tokenURL + // Add dynamic client registration endpoint + metadata.RegistrationEndpoint = registrationURL + } + + log.Printf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v", + metadata.GrantTypesSupported, metadata.ResponseTypesSupported, metadata.CodeChallengeMethodsSupported) + + return &metadata, nil +} + +// ValidateToken validates an OAuth access token +func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { + // JWTs have three parts (header.payload.signature) separated by two dots. + const jwtExpectedDotCount = 2 + + dotCount := strings.Count(tokenString, ".") + if dotCount != jwtExpectedDotCount { + return nil, fmt.Errorf("invalid JWT token format: expected 3 parts separated by dots, got %d dots", dotCount) + } + + // SECURITY WARNING: JWT validation bypass - for development and testing ONLY + // ValidateJWT should ALWAYS be true in production environments + // This bypass creates a significant security vulnerability if enabled in production + if !p.config.TokenValidation.ValidateJWT { + log.Printf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing") + return &auth.TokenInfo{ + AccessToken: tokenString, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration + Scope: p.config.RequiredScopes, // Use configured scopes + Subject: "unknown", // Cannot extract without parsing + Audience: []string{p.config.TokenValidation.ExpectedAudience}, + Issuer: fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID), + Claims: make(map[string]interface{}), + }, nil + } + + // Parse and validate JWT token + + // Parse token structure and check expiration + parserUnsafe := jwt.NewParser(jwt.WithoutClaimsValidation()) + tokenUnsafe, _, err := parserUnsafe.ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("invalid token structure: %w", err) + } + + // Check claims and expiration + if claims, ok := tokenUnsafe.Claims.(jwt.MapClaims); ok { + if exp, ok := claims["exp"].(float64); ok { + expTime := time.Unix(int64(exp), 0) + if time.Now().After(expTime) { + return nil, fmt.Errorf("token expired at %v", expTime) + } + } + } + + // JWT signature validation + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, err := parser.ParseWithClaims(tokenString, jwt.MapClaims{}, p.getKeyFunc) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + if !token.Valid { + return nil, fmt.Errorf("invalid token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid token claims") + } + + // Validate issuer - with Azure Management API scope, we should get v2.0 format + issuer, ok := claims["iss"].(string) + if !ok { + return nil, fmt.Errorf("missing issuer claim") + } + + expectedIssuerV2 := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID) + expectedIssuerV1 := fmt.Sprintf("https://sts.windows.net/%s/", p.config.TenantID) + + if issuer != expectedIssuerV2 && issuer != expectedIssuerV1 { + return nil, fmt.Errorf("invalid issuer: expected %s (preferred) or %s (fallback), got %s", expectedIssuerV2, expectedIssuerV1, issuer) + } + + // Azure AD may return v1.0 or v2.0 issuer format depending on token scope + + // Validate audience and resource binding + if p.config.TokenValidation.ValidateAudience { + if err := p.validateAudience(claims); err != nil { + return nil, err + } + } + + // Extract token information + tokenInfo := &auth.TokenInfo{ + AccessToken: tokenString, + TokenType: "Bearer", + Claims: claims, + } + + // Extract subject + if sub, ok := claims["sub"].(string); ok { + tokenInfo.Subject = sub + } + + // Extract audience + if aud, ok := claims["aud"].(string); ok { + tokenInfo.Audience = []string{aud} + } else if audSlice, ok := claims["aud"].([]interface{}); ok { + for _, a := range audSlice { + if audStr, ok := a.(string); ok { + tokenInfo.Audience = append(tokenInfo.Audience, audStr) + } + } + } + + // Extract scope from Azure AD token + // Check for 'scp' claim (Azure AD v2.0) + if scp, ok := claims["scp"].(string); ok { + tokenInfo.Scope = strings.Split(scp, " ") + } else if scope, ok := claims["scope"].(string); ok { + // Check for 'scope' claim (alternative) + tokenInfo.Scope = strings.Split(scope, " ") + } + + // Check for 'roles' claim (Azure AD app roles) + if roles, ok := claims["roles"].([]interface{}); ok { + for _, role := range roles { + if roleStr, ok := role.(string); ok { + tokenInfo.Scope = append(tokenInfo.Scope, roleStr) + } + } + } + + // Extract expiration + if exp, ok := claims["exp"].(float64); ok { + tokenInfo.ExpiresAt = time.Unix(int64(exp), 0) + } + + // Set issuer + tokenInfo.Issuer = issuer + + return tokenInfo, nil +} + +// validateAudience validates the audience claim and resource binding (RFC 8707) +func (p *AzureOAuthProvider) validateAudience(claims jwt.MapClaims) error { + expectedAudience := p.config.TokenValidation.ExpectedAudience + + // Normalize expected audience - remove trailing slash for comparison + normalizedExpected := strings.TrimSuffix(expectedAudience, "/") + + // Check single audience + if aud, ok := claims["aud"].(string); ok { + normalizedAud := strings.TrimSuffix(aud, "/") + if normalizedAud == normalizedExpected || aud == p.config.ClientID { + return nil + } + return fmt.Errorf("invalid audience: expected %s or %s, got %s", expectedAudience, p.config.ClientID, aud) + } + + // Check audience array + if audSlice, ok := claims["aud"].([]interface{}); ok { + for _, a := range audSlice { + if audStr, ok := a.(string); ok { + normalizedAud := strings.TrimSuffix(audStr, "/") + if normalizedAud == normalizedExpected || audStr == p.config.ClientID { + return nil + } + } + } + return fmt.Errorf("invalid audience: expected %s or %s in audience list", expectedAudience, p.config.ClientID) + } + + return fmt.Errorf("missing audience claim") +} + +// getKeyFunc returns a function to retrieve JWT signing keys +func (p *AzureOAuthProvider) getKeyFunc(token *jwt.Token) (interface{}, error) { + // Validate signing method + if token.Method.Alg() != "RS256" { + return nil, fmt.Errorf("unexpected signing method: expected RS256, got %v", token.Method.Alg()) + } + + // Also verify it's an RSA method + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("signing method is not RSA: %T", token.Method) + } + + // Get key ID from token header + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("missing key ID in token header") + } + + // Extract issuer from token to determine the correct JWKS endpoint + var issuer string + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if iss, ok := claims["iss"].(string); ok { + issuer = iss + } + } + + // Get the public key for this key ID using the appropriate issuer + key, err := p.getPublicKey(kid, issuer) + if err != nil { + log.Printf("PUBLIC KEY RETRIEVAL FAILED: %s\n", err) + return nil, fmt.Errorf("failed to get public key: %w", err) + } + + return key, nil +} + +// getPublicKey retrieves and caches Azure AD public keys +func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.PublicKey, error) { + // Generate cache key based on both kid and issuer to avoid conflicts between v1.0 and v2.0 keys + cacheKey := fmt.Sprintf("%s_%s", kid, issuer) + + // Check cache first if caching is enabled + if p.enableCache { + p.keyCache.mu.RLock() + if key, exists := p.keyCache.keys[cacheKey]; exists && time.Now().Before(p.keyCache.expiresAt) { + p.keyCache.mu.RUnlock() + return key, nil + } + p.keyCache.mu.RUnlock() + } + + // With Azure Management API scope, we should always get v2.0 format tokens + // Force using v2.0 JWKS endpoint for consistency + jwksURL := fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", p.config.TenantID) + + resp, err := p.httpClient.Get(jwksURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS from %s: %w", jwksURL, err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read JWKS response: %w", err) + } + + var jwks struct { + Keys []struct { + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` + Kty string `json:"kty"` + } `json:"keys"` + } + + if err := json.Unmarshal(body, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS: %w", err) + } + + log.Printf("JWKS Contains %d keys, searching for kid=%s\n", len(jwks.Keys), kid) + + // Parse keys and find the target key + var targetKey *rsa.PublicKey + var foundKeyIds []string + + for _, key := range jwks.Keys { + foundKeyIds = append(foundKeyIds, key.Kid) + + if key.Kty == "RSA" && key.Kid == kid { + pubKey, err := parseRSAPublicKey(key.N, key.E) + if err != nil { + log.Printf("JWKS Failed to parse RSA key %s: %v\n", key.Kid, err) + continue + } + targetKey = pubKey + break + } + } + + // Cache the retrieved key and return it (only if caching is enabled) + if targetKey != nil { + if p.enableCache { + p.keyCache.mu.Lock() + if p.keyCache.keys == nil { + p.keyCache.keys = make(map[string]*rsa.PublicKey) + } + p.keyCache.keys[cacheKey] = targetKey + p.keyCache.expiresAt = time.Now().Add(24 * time.Hour) // Cache for 24 hours + p.keyCache.mu.Unlock() + } + return targetKey, nil + } + + return nil, fmt.Errorf("key with ID %s not found in JWKS (available: %v)", kid, foundKeyIds) +} + +// parseRSAPublicKey parses RSA public key from JWK format +func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { + // Decode base64url-encoded modulus + nBytes, err := base64.RawURLEncoding.DecodeString(nStr) + if err != nil { + return nil, fmt.Errorf("failed to decode modulus: %w", err) + } + + // Decode base64url-encoded exponent + eBytes, err := base64.RawURLEncoding.DecodeString(eStr) + if err != nil { + return nil, fmt.Errorf("failed to decode exponent: %w", err) + } + + // Convert bytes to big integers + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + + // Create RSA public key + pubKey := &rsa.PublicKey{ + N: n, + E: int(e.Int64()), + } + + return pubKey, nil +} diff --git a/internal/auth/oauth/provider_test.go b/internal/auth/oauth/provider_test.go new file mode 100644 index 0000000..657f2ba --- /dev/null +++ b/internal/auth/oauth/provider_test.go @@ -0,0 +1,388 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Azure/aks-mcp/internal/auth" +) + +func TestNewAzureOAuthProvider(t *testing.T) { + tests := []struct { + name string + config *auth.OAuthConfig + wantErr bool + }{ + { + name: "valid config should create provider", + config: &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + }, + wantErr: false, + }, + { + name: "invalid config should fail", + config: &auth.OAuthConfig{ + Enabled: true, + // Missing required fields + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewAzureOAuthProvider(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewAzureOAuthProvider() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && provider == nil { + t.Error("NewAzureOAuthProvider() returned nil provider") + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + serverURL := "http://localhost:8000" + metadata, err := provider.GetProtectedResourceMetadata(serverURL) + if err != nil { + t.Fatalf("GetProtectedResourceMetadata() error = %v", err) + } + + expectedAuthServer := "http://localhost:8000" + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer { + t.Errorf("Expected authorization server %s, got %v", expectedAuthServer, metadata.AuthorizationServers) + } + + // Note: AzureADProtectedResourceMetadata doesn't include a Resource field. + // The resource URL is implied by the context of the request endpoint. + + if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" { + t.Errorf("Expected scopes %v, got %v", config.RequiredScopes, metadata.ScopesSupported) + } +} + +func TestGetAuthorizationServerMetadataWithDefaults(t *testing.T) { + // Create a mock Azure AD metadata endpoint that's missing some fields + // This simulates the case where Azure AD doesn't provide all required fields + mockMetadata := AzureADMetadata{ + Issuer: "https://login.microsoftonline.com/test-tenant/v2.0", + AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize", + TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token", + JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys", + ScopesSupported: []string{"openid", "profile", "email"}, + // Intentionally omit GrantTypesSupported, ResponseTypesSupported, etc. + // to test our default value logic + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockMetadata); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Override the HTTP client to use our test server + provider.httpClient = &http.Client{ + Transport: &roundTripperFunc{ + fn: func(req *http.Request) (*http.Response, error) { + // Redirect all requests to our test server + req.URL.Scheme = "http" + req.URL.Host = server.URL[7:] // Remove "http://" + req.URL.Path = "/" + return http.DefaultTransport.RoundTrip(req) + }, + }, + } + + metadata, err := provider.GetAuthorizationServerMetadata(server.URL) + if err != nil { + t.Fatalf("GetAuthorizationServerMetadata() error = %v", err) + } + + // Verify that default values were populated for missing fields + expectedGrantTypes := []string{"authorization_code", "refresh_token"} + if len(metadata.GrantTypesSupported) != len(expectedGrantTypes) { + t.Errorf("Expected %d grant types, got %d", len(expectedGrantTypes), len(metadata.GrantTypesSupported)) + } + for i, expected := range expectedGrantTypes { + if i >= len(metadata.GrantTypesSupported) || metadata.GrantTypesSupported[i] != expected { + t.Errorf("Expected grant type %s at index %d, got %v", expected, i, metadata.GrantTypesSupported) + } + } + + expectedResponseTypes := []string{"code"} + if len(metadata.ResponseTypesSupported) != len(expectedResponseTypes) { + t.Errorf("Expected %d response types, got %d", len(expectedResponseTypes), len(metadata.ResponseTypesSupported)) + } + if len(metadata.ResponseTypesSupported) > 0 && metadata.ResponseTypesSupported[0] != "code" { + t.Errorf("Expected response type 'code', got %s", metadata.ResponseTypesSupported[0]) + } + + expectedSubjectTypes := []string{"public"} + if len(metadata.SubjectTypesSupported) != len(expectedSubjectTypes) { + t.Errorf("Expected %d subject types, got %d", len(expectedSubjectTypes), len(metadata.SubjectTypesSupported)) + } + if len(metadata.SubjectTypesSupported) > 0 && metadata.SubjectTypesSupported[0] != "public" { + t.Errorf("Expected subject type 'public', got %s", metadata.SubjectTypesSupported[0]) + } + + expectedTokenEndpointAuthMethods := []string{"none"} + if len(metadata.TokenEndpointAuthMethodsSupported) != len(expectedTokenEndpointAuthMethods) { + t.Errorf("Expected %d auth methods, got %d", len(expectedTokenEndpointAuthMethods), len(metadata.TokenEndpointAuthMethodsSupported)) + } + if len(metadata.TokenEndpointAuthMethodsSupported) > 0 && metadata.TokenEndpointAuthMethodsSupported[0] != "none" { + t.Errorf("Expected auth method 'none', got %s", metadata.TokenEndpointAuthMethodsSupported[0]) + } + + // Verify that PKCE is properly configured + expectedCodeChallengeMethods := []string{"S256"} + if len(metadata.CodeChallengeMethodsSupported) != len(expectedCodeChallengeMethods) { + t.Errorf("Expected %d code challenge methods, got %d", len(expectedCodeChallengeMethods), len(metadata.CodeChallengeMethodsSupported)) + } + if len(metadata.CodeChallengeMethodsSupported) > 0 && metadata.CodeChallengeMethodsSupported[0] != "S256" { + t.Errorf("Expected code challenge method 'S256', got %s", metadata.CodeChallengeMethodsSupported[0]) + } +} + +func TestGetAuthorizationServerMetadata(t *testing.T) { + // Create a mock Azure AD metadata endpoint + mockMetadata := AzureADMetadata{ + Issuer: "https://login.microsoftonline.com/test-tenant/v2.0", + AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize", + TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token", + JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys", + ScopesSupported: []string{"openid", "profile", "email"}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockMetadata); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Override the HTTP client to use our test server + provider.httpClient = &http.Client{ + Transport: &roundTripperFunc{ + fn: func(req *http.Request) (*http.Response, error) { + // Redirect all requests to our test server + req.URL.Scheme = "http" + req.URL.Host = server.URL[7:] // Remove "http://" + req.URL.Path = "/" + return http.DefaultTransport.RoundTrip(req) + }, + }, + } + + metadata, err := provider.GetAuthorizationServerMetadata(server.URL) + if err != nil { + t.Fatalf("GetAuthorizationServerMetadata() error = %v", err) + } + + if metadata.Issuer != mockMetadata.Issuer { + t.Errorf("Expected issuer %s, got %s", mockMetadata.Issuer, metadata.Issuer) + } + + expectedAuthEndpoint := fmt.Sprintf("%s/oauth2/v2.0/authorize", server.URL) + if metadata.AuthorizationEndpoint != expectedAuthEndpoint { + t.Errorf("Expected auth endpoint %s, got %s", expectedAuthEndpoint, metadata.AuthorizationEndpoint) + } +} + +func TestValidateTokenWithoutJWT(t *testing.T) { + // SECURITY WARNING: This test verifies the JWT validation bypass functionality + // ValidateJWT=false should ONLY be used in development/testing environments + // This functionality should NEVER be enabled in production + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, // Disable JWT validation + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + ctx := context.Background() + // Use a token that looks like a JWT to pass initial format checks + testToken := "header.payload.signature" + tokenInfo, err := provider.ValidateToken(ctx, testToken) + if err != nil { + t.Fatalf("ValidateToken() error = %v", err) + } + + if tokenInfo.AccessToken != testToken { + t.Errorf("Expected access token %s, got %s", testToken, tokenInfo.AccessToken) + } + + if tokenInfo.TokenType != "Bearer" { + t.Errorf("Expected token type Bearer, got %s", tokenInfo.TokenType) + } +} + +func TestValidateAudience(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client-id", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + tests := []struct { + name string + claims map[string]interface{} + wantErr bool + }{ + { + name: "valid audience string", + claims: map[string]interface{}{ + "aud": "https://management.azure.com/", + }, + wantErr: false, + }, + { + name: "valid client ID audience", + claims: map[string]interface{}{ + "aud": "test-client-id", + }, + wantErr: false, + }, + { + name: "valid audience array", + claims: map[string]interface{}{ + "aud": []interface{}{"https://management.azure.com/", "other-aud"}, + }, + wantErr: false, + }, + { + name: "invalid audience", + claims: map[string]interface{}{ + "aud": "invalid-audience", + }, + wantErr: true, + }, + { + name: "missing audience", + claims: map[string]interface{}{ + "sub": "user123", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := provider.validateAudience(tt.claims) + if (err != nil) != tt.wantErr { + t.Errorf("validateAudience() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// roundTripperFunc is a helper type for creating custom HTTP transports in tests +type roundTripperFunc struct { + fn func(*http.Request) (*http.Response, error) +} + +func (f *roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f.fn(req) +} diff --git a/internal/auth/types.go b/internal/auth/types.go new file mode 100644 index 0000000..d89cca6 --- /dev/null +++ b/internal/auth/types.go @@ -0,0 +1,140 @@ +package auth + +import ( + "fmt" + "time" +) + +// OAuthConfig represents OAuth configuration for AKS-MCP +type OAuthConfig struct { + // Enable OAuth authentication + Enabled bool `json:"enabled"` + + // Azure AD tenant ID + TenantID string `json:"tenant_id"` + + // Azure AD application (client) ID + ClientID string `json:"client_id"` + + // Required OAuth scopes for accessing AKS-MCP + RequiredScopes []string `json:"required_scopes"` + + // Allowed redirect URIs for OAuth callback + RedirectURIs []string `json:"redirect_uris"` + + // Allowed CORS origins for OAuth endpoints (for security, wildcard "*" should be avoided) + AllowedOrigins []string `json:"allowed_origins"` + + // Token validation settings + TokenValidation TokenValidationConfig `json:"token_validation"` +} + +// TokenValidationConfig represents token validation configuration +type TokenValidationConfig struct { + // SECURITY CRITICAL: Enable JWT token validation + // Setting this to false creates a security vulnerability - for development/testing ONLY + // MUST be true in production environments + ValidateJWT bool `json:"validate_jwt"` + + // Enable audience validation + ValidateAudience bool `json:"validate_audience"` + + // Expected audience for tokens + ExpectedAudience string `json:"expected_audience"` + + // Token cache TTL + CacheTTL time.Duration `json:"cache_ttl"` + + // Clock skew tolerance for token validation + ClockSkew time.Duration `json:"clock_skew"` +} + +// TokenInfo represents validated token information +type TokenInfo struct { + // Access token + AccessToken string `json:"access_token"` + + // Token type (usually "Bearer") + TokenType string `json:"token_type"` + + // Token expiration time + ExpiresAt time.Time `json:"expires_at"` + + // Token scope + Scope []string `json:"scope"` + + // Subject (user ID) + Subject string `json:"subject"` + + // Audience + Audience []string `json:"audience"` + + // Issuer + Issuer string `json:"issuer"` + + // Additional claims + Claims map[string]interface{} `json:"claims"` +} + +// AuthResult represents the result of authentication +type AuthResult struct { + // Whether authentication was successful + Authenticated bool `json:"authenticated"` + + // Token information (if authenticated) + TokenInfo *TokenInfo `json:"token_info,omitempty"` + + // Error message (if authentication failed) + Error string `json:"error,omitempty"` + + // HTTP status code to return + StatusCode int `json:"status_code"` +} + +// Default OAuth configuration values +const ( + DefaultTokenCacheTTL = 5 * time.Minute + DefaultClockSkew = 1 * time.Minute + DefaultExpectedAudience = "https://management.azure.com" + AzureADScope = "https://management.azure.com/.default" +) + +// NewDefaultOAuthConfig creates a default OAuth configuration +func NewDefaultOAuthConfig() *OAuthConfig { + return &OAuthConfig{ + Enabled: false, + // Use Azure Management API scope to get v2.0 format tokens + // This ensures we get v2.0 issuer format which works with v2.0 JWKS endpoints + RequiredScopes: []string{AzureADScope}, // "https://management.azure.com/.default" + // RedirectURIs will be populated dynamically based on host/port configuration + RedirectURIs: []string{}, + TokenValidation: TokenValidationConfig{ + ValidateJWT: true, // SECURITY CRITICAL: Always true in production + ValidateAudience: true, // Re-enabled with correct audience + ExpectedAudience: DefaultExpectedAudience, // "https://management.azure.com" + CacheTTL: DefaultTokenCacheTTL, + ClockSkew: DefaultClockSkew, + }, + } +} + +// Validate validates the OAuth configuration +func (cfg *OAuthConfig) Validate() error { + if !cfg.Enabled { + return nil + } + + if cfg.TenantID == "" { + return fmt.Errorf("tenant_id is required when OAuth is enabled") + } + + if cfg.ClientID == "" { + return fmt.Errorf("client_id is required when OAuth is enabled") + } + + // if len(cfg.RequiredScopes) == 0 { + // return fmt.Errorf("at least one required scope must be specified") + // } + + return nil +} diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go new file mode 100644 index 0000000..d6becde --- /dev/null +++ b/internal/auth/types_test.go @@ -0,0 +1,183 @@ +package auth + +import ( + "os" + "testing" + "time" +) + +func TestOAuthConfigValidation(t *testing.T) { + tests := []struct { + name string + config *OAuthConfig + wantErr bool + }{ + { + name: "disabled OAuth should pass validation", + config: &OAuthConfig{ + Enabled: false, + }, + wantErr: false, + }, + { + name: "enabled OAuth with missing tenant ID should fail", + config: &OAuthConfig{ + Enabled: true, + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: true, + }, + { + name: "enabled OAuth with missing client ID should fail", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: true, + }, + { + name: "enabled OAuth with empty scopes should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{}, + }, + wantErr: false, + }, + { + name: "valid enabled OAuth config should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: false, + }, + { + name: "valid enabled OAuth config with full token validation should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + TokenValidation: TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: DefaultTokenCacheTTL, + ClockSkew: DefaultClockSkew, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("OAuthConfig.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNewDefaultOAuthConfig(t *testing.T) { + config := NewDefaultOAuthConfig() + + if config.Enabled { + t.Error("Default config should have OAuth disabled") + } + + if len(config.RequiredScopes) != 1 || config.RequiredScopes[0] != AzureADScope { + t.Errorf("Default config should have Azure AD scope, got %v", config.RequiredScopes) + } + + if !config.TokenValidation.ValidateJWT { + t.Error("Default config should enable JWT validation for production") + } + + if !config.TokenValidation.ValidateAudience { + t.Error("Default config should enable audience validation for security") + } + + if config.TokenValidation.ExpectedAudience != DefaultExpectedAudience { + t.Errorf("Default config should have correct expected audience, got %s", config.TokenValidation.ExpectedAudience) + } + + if config.TokenValidation.CacheTTL != DefaultTokenCacheTTL { + t.Errorf("Default config should have correct cache TTL, got %v", config.TokenValidation.CacheTTL) + } + + if config.TokenValidation.ClockSkew != DefaultClockSkew { + t.Errorf("Default config should have correct clock skew, got %v", config.TokenValidation.ClockSkew) + } +} + +func TestOAuthConfigConstants(t *testing.T) { + if DefaultTokenCacheTTL != 5*time.Minute { + t.Errorf("DefaultTokenCacheTTL should be 5 minutes, got %v", DefaultTokenCacheTTL) + } + + if DefaultClockSkew != 1*time.Minute { + t.Errorf("DefaultClockSkew should be 1 minute, got %v", DefaultClockSkew) + } + + if DefaultExpectedAudience != "https://management.azure.com" { + t.Errorf("DefaultExpectedAudience should be Azure management, got %s", DefaultExpectedAudience) + } + + if AzureADScope != "https://management.azure.com/.default" { + t.Errorf("AzureADScope should be Azure management default, got %s", AzureADScope) + } +} + +func TestOAuthConfigEnvironmentVariables(t *testing.T) { + // Test that environment variables are respected + oldTenantID := os.Getenv("AZURE_TENANT_ID") + oldClientID := os.Getenv("AZURE_CLIENT_ID") + + defer func() { + if err := os.Setenv("AZURE_TENANT_ID", oldTenantID); err != nil { + t.Logf("Failed to restore AZURE_TENANT_ID: %v", err) + } + if err := os.Setenv("AZURE_CLIENT_ID", oldClientID); err != nil { + t.Logf("Failed to restore AZURE_CLIENT_ID: %v", err) + } + }() + + if err := os.Setenv("AZURE_TENANT_ID", "env-tenant-id"); err != nil { + t.Fatalf("Failed to set AZURE_TENANT_ID: %v", err) + } + if err := os.Setenv("AZURE_CLIENT_ID", "env-client-id"); err != nil { + t.Fatalf("Failed to set AZURE_CLIENT_ID: %v", err) + } + + config := NewDefaultOAuthConfig() + config.Enabled = true + + // Simulate the environment variable loading that happens in config parsing + if config.TenantID == "" { + config.TenantID = os.Getenv("AZURE_TENANT_ID") + } + if config.ClientID == "" { + config.ClientID = os.Getenv("AZURE_CLIENT_ID") + } + + if config.TenantID != "env-tenant-id" { + t.Errorf("Expected tenant ID from environment, got %s", config.TenantID) + } + + if config.ClientID != "env-client-id" { + t.Errorf("Expected client ID from environment, got %s", config.ClientID) + } + + // Should pass validation with environment variables + if err := config.Validate(); err != nil { + t.Errorf("Config with environment variables should be valid, got error: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index a653188..b60f068 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,15 +5,37 @@ import ( "fmt" "log" "os" + "regexp" "strings" "time" + "github.com/Azure/aks-mcp/internal/auth" "github.com/Azure/aks-mcp/internal/security" "github.com/Azure/aks-mcp/internal/telemetry" "github.com/Azure/aks-mcp/internal/version" flag "github.com/spf13/pflag" ) +// EnableCache controls whether caching is enabled globally +// Cache is enabled by default for production performance +// This affects both web cache headers and AzureOAuthProvider cache +// Can be disabled via DISABLE_CACHE environment variable +var EnableCache = os.Getenv("DISABLE_CACHE") != "true" + +// validateGUID validates that a value is in valid GUID format +func validateGUID(value, name string) error { + if value == "" { + return nil // Empty values are allowed (will be handled by OAuth validation) + } + + // GUID pattern: 8-4-4-4-12 hexadecimal digits with hyphens + guidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + if !guidRegex.MatchString(value) { + return fmt.Errorf("%s must be a valid GUID format (e.g., 12345678-1234-1234-1234-123456789abc), got: %s", name, value) + } + return nil +} + // ConfigData holds the global configuration type ConfigData struct { // Command execution timeout in seconds @@ -22,6 +44,8 @@ type ConfigData struct { CacheTimeout time.Duration // Security configuration SecurityConfig *security.SecurityConfig + // OAuth configuration + OAuthConfig *auth.OAuthConfig // Command-line specific options Transport string @@ -51,6 +75,7 @@ func NewConfig() *ConfigData { Timeout: 60, CacheTimeout: 1 * time.Minute, SecurityConfig: security.NewSecurityConfig(), + OAuthConfig: auth.NewDefaultOAuthConfig(), Transport: "stdio", Port: 8000, AccessLevel: "readonly", @@ -66,9 +91,23 @@ func (cfg *ConfigData) ParseFlags() { flag.StringVar(&cfg.Host, "host", "127.0.0.1", "Host to listen for the server (only used with transport sse or streamable-http)") flag.IntVar(&cfg.Port, "port", 8000, "Port to listen for the server (only used with transport sse or streamable-http)") flag.IntVar(&cfg.Timeout, "timeout", 600, "Timeout for command execution in seconds, default is 600s") + // Security settings flag.StringVar(&cfg.AccessLevel, "access-level", "readonly", "Access level (readonly, readwrite, admin)") + // OAuth configuration + flag.BoolVar(&cfg.OAuthConfig.Enabled, "oauth-enabled", false, "Enable OAuth authentication") + flag.StringVar(&cfg.OAuthConfig.TenantID, "oauth-tenant-id", "", "Azure AD tenant ID for OAuth (fallback to AZURE_TENANT_ID env var)") + flag.StringVar(&cfg.OAuthConfig.ClientID, "oauth-client-id", "", "Azure AD client ID for OAuth (fallback to AZURE_CLIENT_ID env var)") + + // OAuth redirect URIs configuration + additionalRedirectURIs := flag.String("oauth-redirects", "", + "Comma-separated list of additional OAuth redirect URIs (e.g. http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback)") + + // OAuth CORS origins configuration + allowedCORSOrigins := flag.String("oauth-cors-origins", "", + "Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274). If empty, no cross-origin requests are allowed for security") + // Kubernetes-specific settings additionalTools := flag.String("additional-tools", "", "Comma-separated list of additional Kubernetes tools to support (kubectl is always enabled). Available: helm,cilium,hubble") @@ -113,6 +152,12 @@ func (cfg *ConfigData) ParseFlags() { cfg.SecurityConfig.AccessLevel = cfg.AccessLevel cfg.SecurityConfig.AllowedNamespaces = cfg.AllowNamespaces + // Parse OAuth configuration + if err := cfg.parseOAuthConfig(*additionalRedirectURIs, *allowedCORSOrigins); err != nil { + fmt.Printf("OAuth configuration error: %v\n", err) + os.Exit(1) + } + // Parse additional tools if *additionalTools != "" { tools := strings.Split(*additionalTools, ",") @@ -122,6 +167,96 @@ func (cfg *ConfigData) ParseFlags() { } } +// parseOAuthConfig parses OAuth-related command line arguments +func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigins string) error { + // Note: OAuth scopes are automatically configured to use "https://management.azure.com/.default" + // and are not configurable via command line per design + + // Track configuration sources for logging + var tenantIDSource, clientIDSource string + + // Load OAuth configuration from environment variables if not set via CLI + if cfg.OAuthConfig.TenantID == "" { + if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" { + cfg.OAuthConfig.TenantID = tenantID + tenantIDSource = "environment variable AZURE_TENANT_ID" + log.Printf("OAuth Config: Using tenant ID from environment variable AZURE_TENANT_ID") + } + } else { + tenantIDSource = "command line flag --oauth-tenant-id" + log.Printf("OAuth Config: Using tenant ID from command line flag --oauth-tenant-id") + } + + if cfg.OAuthConfig.ClientID == "" { + if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" { + cfg.OAuthConfig.ClientID = clientID + clientIDSource = "environment variable AZURE_CLIENT_ID" + log.Printf("OAuth Config: Using client ID from environment variable AZURE_CLIENT_ID") + } + } else { + clientIDSource = "command line flag --oauth-client-id" + log.Printf("OAuth Config: Using client ID from command line flag --oauth-client-id") + } + + // Validate GUID formats for tenant ID and client ID + if err := validateGUID(cfg.OAuthConfig.TenantID, "OAuth tenant ID"); err != nil { + return fmt.Errorf("invalid OAuth tenant ID from %s: %w", tenantIDSource, err) + } + + if err := validateGUID(cfg.OAuthConfig.ClientID, "OAuth client ID"); err != nil { + return fmt.Errorf("invalid OAuth client ID from %s: %w", clientIDSource, err) + } + + // Set redirect URIs based on configured host and port + if cfg.OAuthConfig.Enabled { + redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", cfg.Host, cfg.Port) + cfg.OAuthConfig.RedirectURIs = []string{redirectURI} + + // Add localhost variant if using 127.0.0.1 + if cfg.Host == "127.0.0.1" { + localhostURI := fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Port) + cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, localhostURI) + } + + // Add additional redirect URIs from command line flag + if additionalRedirectURIs != "" { + additionalURIs := strings.Split(additionalRedirectURIs, ",") + for _, uri := range additionalURIs { + trimmedURI := strings.TrimSpace(uri) + if trimmedURI != "" { + cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, trimmedURI) + } + } + } + } + + // Parse allowed CORS origins for OAuth endpoints + if allowedCORSOrigins != "" { + log.Printf("OAuth Config: Setting allowed CORS origins from command line flag --oauth-cors-origins") + origins := strings.Split(allowedCORSOrigins, ",") + for _, origin := range origins { + trimmedOrigin := strings.TrimSpace(origin) + if trimmedOrigin != "" { + cfg.OAuthConfig.AllowedOrigins = append(cfg.OAuthConfig.AllowedOrigins, trimmedOrigin) + } + } + } else { + log.Printf("OAuth Config: No CORS origins configured - cross-origin requests will be blocked for security") + } + + return nil +} + +// ValidateConfig validates the configuration for incompatible settings +func (cfg *ConfigData) ValidateConfig() error { + // Validate OAuth + transport compatibility + if cfg.OAuthConfig.Enabled && cfg.Transport == "stdio" { + return fmt.Errorf("OAuth authentication is not supported with stdio transport per MCP specification") + } + + return nil +} + // InitializeTelemetry initializes the telemetry service func (cfg *ConfigData) InitializeTelemetry(ctx context.Context, serviceName, serviceVersion string) { // Create telemetry configuration diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..56ad2e4 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,239 @@ +package config + +import ( + "testing" +) + +func TestBasicOAuthConfig(t *testing.T) { + // Test basic OAuth configuration parsing with valid GUIDs + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.OAuthConfig.TenantID = "12345678-1234-1234-1234-123456789abc" + cfg.OAuthConfig.ClientID = "87654321-4321-4321-4321-cba987654321" + + // Parse OAuth configuration + if err := cfg.parseOAuthConfig("", ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } + + // Verify basic configuration is preserved + if !cfg.OAuthConfig.Enabled { + t.Error("Expected OAuth to be enabled") + } + if cfg.OAuthConfig.TenantID != "12345678-1234-1234-1234-123456789abc" { + t.Errorf("Expected tenant ID '12345678-1234-1234-1234-123456789abc', got %s", cfg.OAuthConfig.TenantID) + } + if cfg.OAuthConfig.ClientID != "87654321-4321-4321-4321-cba987654321" { + t.Errorf("Expected client ID '87654321-4321-4321-4321-cba987654321', got %s", cfg.OAuthConfig.ClientID) + } +} + +func TestOAuthRedirectURIsConfig(t *testing.T) { + // Test OAuth redirect URIs configuration with additional URIs + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + // Test with additional redirect URIs + additionalRedirectURIs := "http://localhost:6274/oauth/callback,http://localhost:8080/oauth/callback" + if err := cfg.parseOAuthConfig(additionalRedirectURIs, ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } + + // Should have default URIs plus additional ones + expectedURIs := []string{ + "http://127.0.0.1:8081/oauth/callback", + "http://localhost:8081/oauth/callback", + "http://localhost:6274/oauth/callback", + "http://localhost:8080/oauth/callback", + } + + if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) { + t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs)) + } + + for i, expected := range expectedURIs { + if i >= len(cfg.OAuthConfig.RedirectURIs) || cfg.OAuthConfig.RedirectURIs[i] != expected { + t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i, + func() string { + if i < len(cfg.OAuthConfig.RedirectURIs) { + return cfg.OAuthConfig.RedirectURIs[i] + } + return "missing" + }()) + } + } +} + +func TestOAuthRedirectURIsEmptyAdditional(t *testing.T) { + // Test OAuth redirect URIs configuration without additional URIs + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + // Test with empty additional redirect URIs + if err := cfg.parseOAuthConfig("", ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } + + // Should have only default URIs + expectedURIs := []string{ + "http://127.0.0.1:8081/oauth/callback", + "http://localhost:8081/oauth/callback", + } + + if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) { + t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs)) + } + + for i, expected := range expectedURIs { + if cfg.OAuthConfig.RedirectURIs[i] != expected { + t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i, cfg.OAuthConfig.RedirectURIs[i]) + } + } +} + +func TestValidateGUID(t *testing.T) { + tests := []struct { + name string + value string + fieldName string + wantErr bool + }{ + { + name: "valid GUID", + value: "12345678-1234-1234-1234-123456789abc", + fieldName: "test field", + wantErr: false, + }, + { + name: "valid GUID uppercase", + value: "12345678-1234-1234-1234-123456789ABC", + fieldName: "test field", + wantErr: false, + }, + { + name: "empty value allowed", + value: "", + fieldName: "test field", + wantErr: false, + }, + { + name: "invalid format - missing hyphens", + value: "123456781234123412341234567890ab", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - wrong length", + value: "12345678-1234-1234-1234-123456789", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - non-hex characters", + value: "12345678-1234-1234-1234-123456789abg", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - extra hyphens", + value: "12345678-1234-1234-1234-1234-56789abc", + fieldName: "test field", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGUID(tt.value, tt.fieldName) + if (err != nil) != tt.wantErr { + t.Errorf("validateGUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err != nil { + // Verify error message contains the field name and value + errorMsg := err.Error() + if !contains(errorMsg, tt.fieldName) { + t.Errorf("Error message should contain field name '%s', got: %s", tt.fieldName, errorMsg) + } + if tt.value != "" && !contains(errorMsg, tt.value) { + t.Errorf("Error message should contain value '%s', got: %s", tt.value, errorMsg) + } + } + }) + } +} + +func TestOAuthGUIDValidation(t *testing.T) { + tests := []struct { + name string + tenantID string + clientID string + wantErr bool + }{ + { + name: "valid GUIDs", + tenantID: "12345678-1234-1234-1234-123456789abc", + clientID: "87654321-4321-4321-4321-cba987654321", + wantErr: false, + }, + { + name: "empty values allowed", + tenantID: "", + clientID: "", + wantErr: false, + }, + { + name: "invalid tenant ID", + tenantID: "invalid-tenant-id", + clientID: "87654321-4321-4321-4321-cba987654321", + wantErr: true, + }, + { + name: "invalid client ID", + tenantID: "12345678-1234-1234-1234-123456789abc", + clientID: "invalid-client-id", + wantErr: true, + }, + { + name: "both invalid", + tenantID: "invalid-tenant", + clientID: "invalid-client", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.OAuthConfig.TenantID = tt.tenantID + cfg.OAuthConfig.ClientID = tt.clientID + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + err := cfg.parseOAuthConfig("", "") + if (err != nil) != tt.wantErr { + t.Errorf("parseOAuthConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +// contains is a helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(substr) == 0 || (len(s) >= len(substr) && findSubstring(s, substr)) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/config/validator.go b/internal/config/validator.go index 3499192..2f5214f 100644 --- a/internal/config/validator.go +++ b/internal/config/validator.go @@ -64,12 +64,22 @@ func (v *Validator) validateCli() bool { return valid } +// validateConfig checks configuration compatibility +func (v *Validator) validateConfig() bool { + if err := v.config.ValidateConfig(); err != nil { + v.errors = append(v.errors, err.Error()) + return false + } + return true +} + // Validate runs all validation checks func (v *Validator) Validate() bool { // Run all validation checks validCli := v.validateCli() + validConfig := v.validateConfig() - return validCli + return validCli && validConfig } // GetErrors returns all errors found during validation diff --git a/internal/server/server.go b/internal/server/server.go index 555ca37..39b4513 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/Azure/aks-mcp/internal/auth/oauth" "github.com/Azure/aks-mcp/internal/azcli" "github.com/Azure/aks-mcp/internal/azureclient" "github.com/Azure/aks-mcp/internal/components/advisor" @@ -36,6 +37,9 @@ type Service struct { mcpServer *server.MCPServer azClient *azureclient.AzureClient azcliProcFactory func(timeout int) azcli.Proc + oauthProvider *oauth.AzureOAuthProvider + authMiddleware *oauth.AuthMiddleware + endpointManager *oauth.EndpointManager } // ServiceOption defines a function that configures the AKS MCP service @@ -83,6 +87,14 @@ func (s *Service) initializeInfrastructure() error { s.azClient = azClient log.Println("Azure client initialized successfully") + // Initialize OAuth components if enabled and transport is not stdio + // OAuth is not supported with stdio transport per MCP specification + if s.cfg.OAuthConfig.Enabled && s.cfg.Transport != "stdio" { + if err := s.initializeOAuth(); err != nil { + return fmt.Errorf("failed to initialize OAuth: %w", err) + } + } + // Ensure Azure CLI exists and is logged in if s.azcliProcFactory != nil { // Use injected factory to create an azcli.Proc @@ -114,6 +126,35 @@ func (s *Service) initializeInfrastructure() error { return nil } +// initializeOAuth initializes OAuth authentication components +func (s *Service) initializeOAuth() error { + log.Println("Initializing OAuth authentication...") + + // Validate OAuth configuration + if err := s.cfg.OAuthConfig.Validate(); err != nil { + return fmt.Errorf("invalid OAuth configuration: %w", err) + } + + // Create OAuth provider + provider, err := oauth.NewAzureOAuthProvider(s.cfg.OAuthConfig) + if err != nil { + return fmt.Errorf("failed to create OAuth provider: %w", err) + } + s.oauthProvider = provider + + // Create server URL for OAuth metadata + serverURL := fmt.Sprintf("http://%s:%d", s.cfg.Host, s.cfg.Port) + + // Create auth middleware + s.authMiddleware = oauth.NewAuthMiddleware(provider, serverURL) + + // Create endpoint manager + s.endpointManager = oauth.NewEndpointManager(provider, s.cfg) + + log.Printf("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID) + return nil +} + // registerAllComponents registers all component tools organized by category func (s *Service) registerAllComponents() { // Azure Components @@ -142,6 +183,15 @@ func (s *Service) registerPrompts() { func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { mux := http.NewServeMux() + // Register OAuth endpoints if OAuth is enabled + if s.cfg.OAuthConfig.Enabled { + if s.endpointManager == nil { + log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + } + log.Println("Registering OAuth endpoints...") + s.endpointManager.RegisterEndpoints(mux) + } + // Handle all other paths with a helpful 404 response mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/mcp" { @@ -159,6 +209,20 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { }, } + // Add OAuth endpoints to the response if enabled + if s.cfg.OAuthConfig.Enabled { + oauthEndpoints := map[string]string{ + "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata", + "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata", + "client-registration": "POST /oauth/register - Dynamic client registration", + "token-introspection": "POST /oauth/introspect - Token introspection", + "health": "GET /health - Health check", + } + for k, v := range oauthEndpoints { + response["endpoints"].(map[string]string)[k] = v + } + } + if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } @@ -177,9 +241,28 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, addr string) *http.Server { mux := http.NewServeMux() - // Register SSE and Message handlers - mux.Handle("/sse", sseServer.SSEHandler()) - mux.Handle("/message", sseServer.MessageHandler()) + // Register OAuth endpoints if OAuth is enabled + if s.cfg.OAuthConfig.Enabled { + if s.endpointManager == nil { + log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + } + log.Println("Registering OAuth endpoints for SSE server...") + s.endpointManager.RegisterEndpoints(mux) + } + + // Register SSE and Message handlers with authentication if enabled + if s.cfg.OAuthConfig.Enabled { + if s.authMiddleware == nil { + log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + } + // Apply authentication middleware to SSE and Message endpoints + mux.Handle("/sse", s.authMiddleware.Middleware(sseServer.SSEHandler())) + mux.Handle("/message", s.authMiddleware.Middleware(sseServer.MessageHandler())) + } else { + // Register without authentication + mux.Handle("/sse", sseServer.SSEHandler()) + mux.Handle("/message", sseServer.MessageHandler()) + } // Handle all other paths with a helpful 404 response mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -196,6 +279,26 @@ func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, }, } + // Add OAuth endpoints and authentication info if enabled + if s.cfg.OAuthConfig.Enabled { + response["authentication"] = map[string]interface{}{ + "required": true, + "type": "Bearer", + "note": "Include 'Authorization: Bearer ' header for authenticated endpoints", + } + + oauthEndpoints := map[string]string{ + "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata", + "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata", + "client-registration": "POST /oauth/register - Dynamic client registration", + "token-introspection": "POST /oauth/introspect - Token introspection", + "health": "GET /health - Health check", + } + for k, v := range oauthEndpoints { + response["endpoints"].(map[string]string)[k] = v + } + } + if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } @@ -231,6 +334,10 @@ func (s *Service) Run() error { log.Printf("SSE endpoint available at: http://%s/sse", addr) log.Printf("Message endpoint available at: http://%s/message", addr) log.Printf("Connect to /sse for real-time events, send JSON-RPC to /message") + if s.cfg.OAuthConfig.Enabled { + log.Printf("OAuth authentication enabled - Bearer token required for SSE and Message endpoints") + log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + } return customServer.ListenAndServe() case "streamable-http": @@ -247,12 +354,25 @@ func (s *Service) Run() error { // Update the mux to use the actual streamable server as the MCP handler if mux, ok := customServer.Handler.(*http.ServeMux); ok { - mux.Handle("/mcp", streamableServer) + if s.cfg.OAuthConfig.Enabled { + if s.authMiddleware == nil { + log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + } + // Apply authentication middleware to MCP endpoint + mux.Handle("/mcp", s.authMiddleware.Middleware(streamableServer)) + } else { + // Register without authentication + mux.Handle("/mcp", streamableServer) + } } log.Printf("Streamable HTTP server listening on %s", addr) log.Printf("MCP endpoint available at: http://%s/mcp", addr) log.Printf("Send POST requests to /mcp to initialize session and obtain Mcp-Session-Id") + if s.cfg.OAuthConfig.Enabled { + log.Printf("OAuth authentication enabled - Bearer token required for MCP endpoint") + log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + } return customServer.ListenAndServe() default: