Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ Usage of ./aks-mcp:
--port int Port to listen for the server (only used with transport sse or streamable-http) (default 8000)
--timeout int Timeout for command execution in seconds, default is 600s (default 600)
--transport string Transport mechanism to use (stdio, sse or streamable-http) (default "stdio")
-v, --verbose Enable verbose logging
--log-level string Log level (debug, info, warn, error) (default "info")
```

**Environment variables:**
Expand Down
14 changes: 11 additions & 3 deletions cmd/aks-mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"

"github.com/Azure/aks-mcp/internal/config"
"github.com/Azure/aks-mcp/internal/logger"
"github.com/Azure/aks-mcp/internal/server"
"github.com/Azure/aks-mcp/internal/version"
)
Expand All @@ -18,6 +18,13 @@ func main() {
cfg := config.NewConfig()
cfg.ParseFlags()

// Initialize logger with configured level
if err := logger.SetLevel(cfg.LogLevel); err != nil {
fmt.Fprintf(os.Stderr, "Invalid log level '%s': %v\n", cfg.LogLevel, err)
os.Exit(1)
}
logger.Debugf("Log level set to: %s", cfg.LogLevel)

// Create validator and run validation checks
v := config.NewValidator(cfg)
if !v.Validate() {
Expand All @@ -41,7 +48,7 @@ func main() {
defer func() {
if cfg.TelemetryService != nil {
if err := cfg.TelemetryService.Shutdown(context.Background()); err != nil {
log.Printf("Failed to shutdown telemetry: %v", err)
logger.Errorf("Failed to shutdown telemetry: %v", err)
}
}
}()
Expand All @@ -65,7 +72,8 @@ func main() {
cancel()
case err := <-errChan:
if err != nil {
log.Fatalf("Service error: %v\n", err)
logger.Errorf("Service error: %v", err)
os.Exit(1)
}
}
}
124 changes: 62 additions & 62 deletions internal/auth/oauth/endpoints.go

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions internal/auth/oauth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"

"github.com/Azure/aks-mcp/internal/auth"
"github.com/Azure/aks-mcp/internal/logger"
)

// contextKey is a custom type for context keys to avoid collisions
Expand Down Expand Up @@ -43,7 +43,7 @@ func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request)
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
w.Header().Set("Access-Control-Allow-Credentials", "false")
} else if requestOrigin != "" {
log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin)
logger.Errorf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin)
}
}

Expand All @@ -61,7 +61,7 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {

// Skip authentication for specific endpoints
if m.shouldSkipAuth(r) {
log.Printf("Skipping auth for path: %s\n", r.URL.Path)
logger.Debugf("Skipping auth for path: %s", r.URL.Path)
next.ServeHTTP(w, r)
return
}
Expand All @@ -70,7 +70,7 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
authResult := m.authenticateRequest(r)

if !authResult.Authenticated {
log.Printf("Authentication FAILED - handling error\n")
logger.Errorf("Authentication FAILED - handling error")
m.handleAuthError(w, r, authResult)
return
}
Expand Down Expand Up @@ -116,8 +116,8 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
authHeader := r.Header.Get("Authorization")

if authHeader == "" {
log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path)
log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header)
logger.Debugf("OAuth DEBUG - Missing authorization header for %s %s", r.Method, r.URL.Path)
logger.Debugf("OAuth DEBUG - Request headers: %+v", r.Header)
return &auth.AuthResult{
Authenticated: false,
Error: "missing authorization header",
Expand All @@ -128,7 +128,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
// Check for Bearer token format
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n")
logger.Errorf("FAILED - Invalid authorization header format (missing Bearer prefix)")
return &auth.AuthResult{
Authenticated: false,
Error: "invalid authorization header format",
Expand All @@ -138,7 +138,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {

token := strings.TrimPrefix(authHeader, bearerPrefix)
if token == "" {
log.Printf("FAILED - Empty bearer token\n")
logger.Errorf("FAILED - Empty bearer token")
return &auth.AuthResult{
Authenticated: false,
Error: "empty bearer token",
Expand All @@ -149,7 +149,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
// Basic JWT structure validation
tokenParts := strings.Split(token, ".")
if len(tokenParts) != 3 {
log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts))
logger.Errorf("FAILED - JWT structure validation (has %d parts, expected 3)", len(tokenParts))
return &auth.AuthResult{
Authenticated: false,
Error: "invalid JWT structure",
Expand All @@ -160,7 +160,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {
// Validate the token
tokenInfo, err := m.provider.ValidateToken(r.Context(), token)
if err != nil {
log.Printf("FAILED - Provider token validation failed: %v\n", err)
logger.Errorf("FAILED - Provider token validation failed: %v", err)
return &auth.AuthResult{
Authenticated: false,
Error: fmt.Sprintf("token validation failed: %v", err),
Expand All @@ -170,7 +170,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult {

// Validate required scopes - strict enforcement for security
if !m.validateScopes(tokenInfo.Scope) {
log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes)
logger.Errorf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes)
return &auth.AuthResult{
Authenticated: false,
Error: "insufficient scope",
Expand Down Expand Up @@ -278,9 +278,9 @@ func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request,
}

if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err)
logger.Errorf("MIDDLEWARE ERROR: Failed to encode error response: %v", err)
} else {
log.Printf("MIDDLEWARE ERROR: Error response sent\n")
logger.Errorf("MIDDLEWARE ERROR: Error response sent")
}
}

Expand Down
40 changes: 20 additions & 20 deletions internal/auth/oauth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"net/url"
Expand All @@ -17,6 +16,7 @@ import (

"github.com/Azure/aks-mcp/internal/auth"
internalConfig "github.com/Azure/aks-mcp/internal/config"
"github.com/Azure/aks-mcp/internal/logger"
"github.com/golang-jwt/jwt/v5"
)

Expand Down Expand Up @@ -109,71 +109,71 @@ func (p *AzureOAuthProvider) GetProtectedResourceMetadata(serverURL string) (*Pr
// GetAuthorizationServerMetadata returns OAuth 2.0 Authorization Server Metadata (RFC 8414)
func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*AzureADMetadata, error) {
metadataURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", p.config.TenantID)
log.Printf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL)
logger.Debugf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL)

resp, err := p.httpClient.Get(metadataURL)
if err != nil {
log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err)
logger.Errorf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err)
return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("Failed to close response body: %v", err)
logger.Errorf("Failed to close response body: %v", err)
}
}()

if resp.StatusCode == http.StatusNotFound {
log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID)
logger.Errorf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID)
return nil, fmt.Errorf("tenant ID '%s' not found (HTTP 404). Please verify your Azure AD tenant ID is correct", p.config.TenantID)
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
log.Printf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body))
logger.Errorf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("metadata endpoint returned status %d: %s", resp.StatusCode, string(body))
}

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("OAuth ERROR: Failed to read metadata response: %v", err)
logger.Errorf("OAuth ERROR: Failed to read metadata response: %v", err)
return nil, fmt.Errorf("failed to read metadata response: %w", err)
}

var metadata AzureADMetadata
if err := json.Unmarshal(body, &metadata); err != nil {
log.Printf("OAuth ERROR: Failed to parse metadata JSON: %v", err)
logger.Errorf("OAuth ERROR: Failed to parse metadata JSON: %v", err)
return nil, fmt.Errorf("failed to parse metadata: %w", err)
}

log.Printf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported)
logger.Debugf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported)

// Ensure grant_types_supported is populated for MCP Inspector compatibility
if len(metadata.GrantTypesSupported) == 0 {
log.Printf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)")
logger.Debugf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)")
metadata.GrantTypesSupported = []string{"authorization_code", "refresh_token"}
}

// Ensure response_types_supported is populated for MCP Inspector compatibility
if len(metadata.ResponseTypesSupported) == 0 {
log.Printf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)")
logger.Debugf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)")
metadata.ResponseTypesSupported = []string{"code"}
}

// Ensure subject_types_supported is populated for MCP Inspector compatibility
if len(metadata.SubjectTypesSupported) == 0 {
log.Printf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)")
logger.Debugf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)")
metadata.SubjectTypesSupported = []string{"public"}
}

// Ensure token_endpoint_auth_methods_supported is populated for MCP Inspector compatibility
if len(metadata.TokenEndpointAuthMethodsSupported) == 0 {
log.Printf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)")
logger.Debugf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)")
metadata.TokenEndpointAuthMethodsSupported = []string{"none"}
}

// Add S256 code challenge method support (Azure AD supports this but may not advertise it)
// MCP specification requires S256 support, so we always ensure it's present
log.Printf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)")
logger.Debugf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)")
metadata.CodeChallengeMethodsSupported = []string{"S256"}

// Azure AD v2.0 has limited support for RFC 8707 Resource Indicators
Expand All @@ -197,7 +197,7 @@ func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*
metadata.RegistrationEndpoint = registrationURL
}

log.Printf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v",
logger.Debugf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v",
metadata.GrantTypesSupported, metadata.ResponseTypesSupported, metadata.CodeChallengeMethodsSupported)

return &metadata, nil
Expand All @@ -217,7 +217,7 @@ func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString stri
// ValidateJWT should ALWAYS be true in production environments
// This bypass creates a significant security vulnerability if enabled in production
if !p.config.TokenValidation.ValidateJWT {
log.Printf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing")
logger.Warnf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing")
return &auth.TokenInfo{
AccessToken: tokenString,
TokenType: "Bearer",
Expand Down Expand Up @@ -400,7 +400,7 @@ func (p *AzureOAuthProvider) getKeyFunc(token *jwt.Token) (interface{}, error) {
// Get the public key for this key ID using the appropriate issuer
key, err := p.getPublicKey(kid, issuer)
if err != nil {
log.Printf("PUBLIC KEY RETRIEVAL FAILED: %s\n", err)
logger.Errorf("PUBLIC KEY RETRIEVAL FAILED: %s", err)
return nil, fmt.Errorf("failed to get public key: %w", err)
}

Expand Down Expand Up @@ -432,7 +432,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("Failed to close response body: %v", err)
logger.Errorf("Failed to close response body: %v", err)
}
}()

Expand All @@ -458,7 +458,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi
return nil, fmt.Errorf("failed to parse JWKS: %w", err)
}

log.Printf("JWKS Contains %d keys, searching for kid=%s\n", len(jwks.Keys), kid)
logger.Debugf("JWKS Contains %d keys, searching for kid=%s", len(jwks.Keys), kid)

// Parse keys and find the target key
var targetKey *rsa.PublicKey
Expand All @@ -470,7 +470,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi
if key.Kty == "RSA" && key.Kid == kid {
pubKey, err := parseRSAPublicKey(key.N, key.E)
if err != nil {
log.Printf("JWKS Failed to parse RSA key %s: %v\n", key.Kid, err)
logger.Errorf("JWKS Failed to parse RSA key %s: %v", key.Kid, err)
continue
}
targetKey = pubKey
Expand Down
4 changes: 2 additions & 2 deletions internal/azureclient/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"

"github.com/Azure/aks-mcp/internal/logger"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

Expand Down Expand Up @@ -66,7 +66,7 @@ func ParseAKSResourceID(resourceID string) (subscriptionID, resourceGroup, clust
func HandleDetectorAPIResponse(resp *http.Response) ([]byte, error) {
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("Warning: failed to close response body: %v", err)
logger.Warnf("Warning: failed to close response body: %v", err)
}
}()

Expand Down
Loading
Loading