diff --git a/README.md b/README.md index 30ad092..2541ef8 100644 --- a/README.md +++ b/README.md @@ -641,7 +641,7 @@ Usage of ./aks-mcp: --port int Port to listen for the server (only used with transport sse or streamable-http) (default 8000) --timeout int Timeout for command execution in seconds, default is 600s (default 600) --transport string Transport mechanism to use (stdio, sse or streamable-http) (default "stdio") - -v, --verbose Enable verbose logging + --log-level string Log level (debug, info, warn, error) (default "info") ``` **Environment variables:** diff --git a/cmd/aks-mcp/main.go b/cmd/aks-mcp/main.go index febc19c..701b334 100644 --- a/cmd/aks-mcp/main.go +++ b/cmd/aks-mcp/main.go @@ -3,12 +3,12 @@ package main import ( "context" "fmt" - "log" "os" "os/signal" "syscall" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" "github.com/Azure/aks-mcp/internal/server" "github.com/Azure/aks-mcp/internal/version" ) @@ -18,6 +18,13 @@ func main() { cfg := config.NewConfig() cfg.ParseFlags() + // Initialize logger with configured level + if err := logger.SetLevel(cfg.LogLevel); err != nil { + fmt.Fprintf(os.Stderr, "Invalid log level '%s': %v\n", cfg.LogLevel, err) + os.Exit(1) + } + logger.Debugf("Log level set to: %s", cfg.LogLevel) + // Create validator and run validation checks v := config.NewValidator(cfg) if !v.Validate() { @@ -41,7 +48,7 @@ func main() { defer func() { if cfg.TelemetryService != nil { if err := cfg.TelemetryService.Shutdown(context.Background()); err != nil { - log.Printf("Failed to shutdown telemetry: %v", err) + logger.Errorf("Failed to shutdown telemetry: %v", err) } } }() @@ -65,7 +72,8 @@ func main() { cancel() case err := <-errChan: if err != nil { - log.Fatalf("Service error: %v\n", err) + logger.Errorf("Service error: %v", err) + os.Exit(1) } } } diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go index af49899..a3ea8c0 100644 --- a/internal/auth/oauth/endpoints.go +++ b/internal/auth/oauth/endpoints.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -14,6 +13,7 @@ import ( "github.com/Azure/aks-mcp/internal/auth" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" ) // validateAzureADURL validates that the URL is a legitimate Azure AD endpoint @@ -76,7 +76,7 @@ func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter, r *http.Request w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours w.Header().Set("Access-Control-Allow-Credentials", "false") } else if requestOrigin != "" { - log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + logger.Errorf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) } } @@ -127,7 +127,7 @@ func (em *EndpointManager) RegisterEndpoints(mux *http.ServeMux) { // authServerMetadataProxyHandler proxies authorization server metadata from Azure AD func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -139,7 +139,7 @@ func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { } if r.Method != http.MethodGet { - log.Printf("OAuth ERROR: Invalid method %s for metadata endpoint", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for metadata endpoint", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -163,7 +163,7 @@ func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { metadata, err := provider.GetAuthorizationServerMetadata(serverURL) if err != nil { - log.Printf("Failed to fetch authorization server metadata: %v\n", err) + logger.Errorf("Failed to fetch authorization server metadata: %v", err) http.Error(w, fmt.Sprintf("Failed to fetch authorization server metadata: %v", err), http.StatusInternalServerError) return } @@ -172,7 +172,7 @@ func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { em.setCacheHeaders(w) if err := json.NewEncoder(w).Encode(metadata); err != nil { - log.Printf("Failed to encode response: %v\n", err) + logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -182,7 +182,7 @@ func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { // clientRegistrationHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591) func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -194,7 +194,7 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { } if r.Method != http.MethodPost { - log.Printf("OAuth ERROR: Invalid method %s for client registration endpoint, only POST allowed", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for client registration endpoint, only POST allowed", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -203,16 +203,16 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { var registrationRequest ClientRegistrationRequest if err := json.NewDecoder(r.Body).Decode(®istrationRequest); err != nil { - log.Printf("OAuth ERROR: Failed to parse client registration JSON: %v", err) + logger.Errorf("OAuth ERROR: Failed to parse client registration JSON: %v", err) em.writeErrorResponse(w, "invalid_request", "Invalid JSON in request body", http.StatusBadRequest) return } - log.Printf("OAuth DEBUG: Client registration request parsed - client_name: %s, redirect_uris: %v", registrationRequest.ClientName, registrationRequest.RedirectURIs) + logger.Debugf("OAuth DEBUG: Client registration request parsed - client_name: %s, redirect_uris: %v", registrationRequest.ClientName, registrationRequest.RedirectURIs) // Validate registration request if err := em.validateClientRegistration(®istrationRequest); err != nil { - log.Printf("OAuth ERROR: Client registration validation failed: %v", err) + logger.Errorf("OAuth ERROR: Client registration validation failed: %v", err) em.writeErrorResponse(w, "invalid_client_metadata", err.Error(), http.StatusBadRequest) return } @@ -249,7 +249,7 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(clientInfo); err != nil { - log.Printf("OAuth ERROR: Failed to encode client registration response: %v", err) + logger.Errorf("OAuth ERROR: Failed to encode client registration response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -308,7 +308,7 @@ func (em *EndpointManager) validateRedirectURI(redirectURI string) error { } } - log.Printf("OAuth SECURITY WARNING: Invalid redirect URI attempted: %s, allowed: %v", + logger.Warnf("OAuth SECURITY WARNING: Invalid redirect URI attempted: %s, allowed: %v", redirectURI, em.cfg.OAuthConfig.RedirectURIs) return fmt.Errorf("redirect_uri not registered: %s", redirectURI) } @@ -351,7 +351,7 @@ func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("Failed to encode introspection response: %v", err) + logger.Errorf("Failed to encode introspection response: %v", err) } return } @@ -410,7 +410,7 @@ func (em *EndpointManager) healthHandler() http.HandlerFunc { // protectedResourceMetadataHandler handles OAuth 2.0 Protected Resource Metadata requests func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -422,7 +422,7 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { } if r.Method != http.MethodGet { - log.Printf("OAuth ERROR: Invalid method %s for protected resource metadata endpoint", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for protected resource metadata endpoint", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -441,24 +441,24 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { // Build the resource URL resourceURL := fmt.Sprintf("%s://%s", scheme, host) - log.Printf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL) + logger.Debugf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL) provider := em.provider metadata, err := provider.GetProtectedResourceMetadata(resourceURL) if err != nil { - log.Printf("OAuth ERROR: Failed to get protected resource metadata: %v", err) + logger.Errorf("OAuth ERROR: Failed to get protected resource metadata: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } - log.Printf("OAuth DEBUG: Successfully generated protected resource metadata with %d authorization servers", len(metadata.AuthorizationServers)) + logger.Debugf("OAuth DEBUG: Successfully generated protected resource metadata with %d authorization servers", len(metadata.AuthorizationServers)) w.Header().Set("Content-Type", "application/json") em.setCacheHeaders(w) if err := json.NewEncoder(w).Encode(metadata); err != nil { - log.Printf("OAuth ERROR: Failed to encode protected resource metadata response: %v", err) + logger.Errorf("OAuth ERROR: Failed to encode protected resource metadata response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -476,14 +476,14 @@ func (em *EndpointManager) writeErrorResponse(w http.ResponseWriter, errorCode, } if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("Failed to encode error response: %v", err) + logger.Errorf("Failed to encode error response: %v", err) } } // authorizationProxyHandler proxies authorization requests to Azure AD with resource parameter filtering func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -495,7 +495,7 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { } if r.Method != http.MethodGet { - log.Printf("OAuth ERROR: Invalid method %s for authorization endpoint, only GET allowed", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for authorization endpoint, only GET allowed", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -506,16 +506,16 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { // Validate redirect_uri parameter for security and better user experience redirectURI := query.Get("redirect_uri") if redirectURI == "" { - log.Printf("OAuth ERROR: Missing redirect_uri parameter in authorization request") - log.Printf("OAuth HELP: To fix this error, configure redirect URIs using --oauth-redirects flag") - log.Printf("OAuth HELP: For MCP Inspector, use: --oauth-redirects=\"http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback\"") + logger.Errorf("OAuth ERROR: Missing redirect_uri parameter in authorization request") + logger.Infof("OAuth HELP: To fix this error, configure redirect URIs using --oauth-redirects flag") + logger.Infof("OAuth HELP: For MCP Inspector, use: --oauth-redirects=\"http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback\"") em.writeErrorResponse(w, "invalid_request", "redirect_uri parameter is required", http.StatusBadRequest) return } // Validate that the redirect_uri is registered and allowed if err := em.validateRedirectURI(redirectURI); err != nil { - log.Printf("OAuth ERROR: redirect_uri %s not registered - requests will be blocked for security", redirectURI) + logger.Errorf("OAuth ERROR: redirect_uri %s not registered - requests will be blocked for security", redirectURI) em.writeErrorResponse(w, "invalid_request", fmt.Sprintf("redirect_uri not registered: %s", redirectURI), http.StatusBadRequest) return } @@ -525,7 +525,7 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { codeChallengeMethod := query.Get("code_challenge_method") if codeChallenge == "" { - log.Printf("OAuth ERROR: Missing PKCE code_challenge parameter (required for OAuth 2.1)") + logger.Errorf("OAuth ERROR: Missing PKCE code_challenge parameter (required for OAuth 2.1)") em.writeErrorResponse(w, "invalid_request", "PKCE code_challenge is required", http.StatusBadRequest) return } @@ -533,9 +533,9 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { if codeChallengeMethod == "" { // Default to S256 if not specified query.Set("code_challenge_method", "S256") - log.Printf("OAuth DEBUG: Setting default code_challenge_method to S256") + logger.Debugf("OAuth DEBUG: Setting default code_challenge_method to S256") } else if codeChallengeMethod != "S256" { - log.Printf("OAuth ERROR: Unsupported code_challenge_method: %s (only S256 supported)", codeChallengeMethod) + logger.Errorf("OAuth ERROR: Unsupported code_challenge_method: %s (only S256 supported)", codeChallengeMethod) em.writeErrorResponse(w, "invalid_request", "Only S256 code_challenge_method is supported", http.StatusBadRequest) return } @@ -547,7 +547,7 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { // Remove the resource parameter if present for Azure AD compatibility resourceParam := query.Get("resource") if resourceParam != "" { - log.Printf("OAuth DEBUG: Removing resource parameter for Azure AD compatibility: %s", resourceParam) + logger.Debugf("OAuth DEBUG: Removing resource parameter for Azure AD compatibility: %s", resourceParam) query.Del("resource") } @@ -558,14 +558,14 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { finalScopeString := strings.Join(finalScopes, " ") query.Set("scope", finalScopeString) - log.Printf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString) + logger.Debugf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString) // Build the Azure AD authorization URL azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.cfg.OAuthConfig.TenantID) // Create the redirect URL with filtered parameters redirectURL := fmt.Sprintf("%s?%s", azureAuthURL, query.Encode()) - log.Printf("OAuth DEBUG: Redirecting to Azure AD authorization endpoint: %s", azureAuthURL) + logger.Debugf("OAuth DEBUG: Redirecting to Azure AD authorization endpoint: %s", azureAuthURL) // Redirect to Azure AD http.Redirect(w, r, redirectURL, http.StatusFound) @@ -575,7 +575,7 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { // callbackHandler handles OAuth 2.0 Authorization Code flow callback func (em *EndpointManager) callbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -587,7 +587,7 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { } if r.Method != http.MethodGet { - log.Printf("OAuth ERROR: Invalid method %s for callback endpoint, only GET allowed", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for callback endpoint, only GET allowed", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } @@ -598,7 +598,7 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { // Check for error response from authorization server if authError := query.Get("error"); authError != "" { errorDesc := query.Get("error_description") - log.Printf("OAuth ERROR: Authorization server returned error: %s - %s", authError, errorDesc) + logger.Errorf("OAuth ERROR: Authorization server returned error: %s - %s", authError, errorDesc) em.writeCallbackErrorResponse(w, fmt.Sprintf("Authorization failed: %s - %s", authError, errorDesc)) return } @@ -606,7 +606,7 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { // Get authorization code code := query.Get("code") if code == "" { - log.Printf("OAuth ERROR: Missing authorization code in callback") + logger.Errorf("OAuth ERROR: Missing authorization code in callback") em.writeCallbackErrorResponse(w, "Missing authorization code") return } @@ -614,17 +614,17 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { // Get state parameter for CSRF protection state := query.Get("state") if state == "" { - log.Printf("OAuth ERROR: Missing state parameter in callback") + logger.Errorf("OAuth ERROR: Missing state parameter in callback") em.writeCallbackErrorResponse(w, "Missing state parameter") return } - log.Printf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state) + logger.Debugf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state) // Validate redirect URI for security - construct expected URI and validate it expectedRedirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port) if err := em.validateRedirectURI(expectedRedirectURI); err != nil { - log.Printf("OAuth ERROR: Redirect URI validation failed: %v", err) + logger.Errorf("OAuth ERROR: Redirect URI validation failed: %v", err) em.writeCallbackErrorResponse(w, "Invalid redirect URI") return } @@ -632,7 +632,7 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { // Exchange authorization code for access token tokenResponse, err := em.exchangeCodeForToken(code, state) if err != nil { - log.Printf("OAuth ERROR: Failed to exchange authorization code for token: %v", err) + logger.Errorf("OAuth ERROR: Failed to exchange authorization code for token: %v", err) em.writeCallbackErrorResponse(w, fmt.Sprintf("Failed to exchange code for token: %v", err)) return } @@ -696,7 +696,7 @@ func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenRespo } defer func() { if err := resp.Body.Close(); err != nil { - log.Printf("Failed to close response body: %v", err) + logger.Errorf("Failed to close response body: %v", err) } }() @@ -740,7 +740,7 @@ func (em *EndpointManager) writeCallbackErrorResponse(w http.ResponseWriter, mes `, message) if _, err := w.Write([]byte(html)); err != nil { - log.Printf("Failed to write error response: %v", err) + logger.Errorf("Failed to write error response: %v", err) } } @@ -826,7 +826,7 @@ func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, t tokenResponse.AccessToken) if _, err := w.Write([]byte(html)); err != nil { - log.Printf("Failed to write success response: %v", err) + logger.Errorf("Failed to write success response: %v", err) } } @@ -855,7 +855,7 @@ func (em *EndpointManager) generateSessionToken() (string, error) { // tokenHandler handles OAuth 2.0 token endpoint requests (Authorization Code exchange) func (em *EndpointManager) tokenHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - log.Printf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests em.setCORSHeaders(w, r) @@ -867,14 +867,14 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { } if r.Method != http.MethodPost { - log.Printf("OAuth ERROR: Invalid method %s for token endpoint, only POST allowed", r.Method) + logger.Errorf("OAuth ERROR: Invalid method %s for token endpoint, only POST allowed", r.Method) em.writeErrorResponse(w, "invalid_request", "Only POST method is allowed", http.StatusMethodNotAllowed) return } // Parse form data if err := r.ParseForm(); err != nil { - log.Printf("OAuth ERROR: Failed to parse form data: %v", err) + logger.Errorf("OAuth ERROR: Failed to parse form data: %v", err) em.writeErrorResponse(w, "invalid_request", "Failed to parse form data", http.StatusBadRequest) return } @@ -882,7 +882,7 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { // Validate grant type grantType := r.FormValue("grant_type") if grantType != "authorization_code" { - log.Printf("OAuth ERROR: Unsupported grant type: %s", grantType) + logger.Errorf("OAuth ERROR: Unsupported grant type: %s", grantType) em.writeErrorResponse(w, "unsupported_grant_type", fmt.Sprintf("Unsupported grant type: %s", grantType), http.StatusBadRequest) return } @@ -894,40 +894,40 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { codeVerifier := r.FormValue("code_verifier") // PKCE parameter if code == "" { - log.Printf("OAuth ERROR: Missing authorization code in token request") + logger.Errorf("OAuth ERROR: Missing authorization code in token request") em.writeErrorResponse(w, "invalid_request", "Missing authorization code", http.StatusBadRequest) return } if clientID == "" { - log.Printf("OAuth ERROR: Missing client_id in token request") + logger.Errorf("OAuth ERROR: Missing client_id in token request") em.writeErrorResponse(w, "invalid_request", "Missing client_id", http.StatusBadRequest) return } if redirectURI == "" { - log.Printf("OAuth ERROR: Missing redirect_uri in token request") + logger.Errorf("OAuth ERROR: Missing redirect_uri in token request") em.writeErrorResponse(w, "invalid_request", "Missing redirect_uri", http.StatusBadRequest) return } // Enforce PKCE code_verifier for OAuth 2.1 compliance if codeVerifier == "" { - log.Printf("OAuth ERROR: Missing PKCE code_verifier (required for OAuth 2.1)") + logger.Errorf("OAuth ERROR: Missing PKCE code_verifier (required for OAuth 2.1)") em.writeErrorResponse(w, "invalid_request", "PKCE code_verifier is required", http.StatusBadRequest) return } // Validate client ID (accept both configured and dynamically registered clients) if !em.isValidClientID(clientID) { - log.Printf("OAuth ERROR: Invalid client_id: %s", clientID) + logger.Errorf("OAuth ERROR: Invalid client_id: %s", clientID) em.writeErrorResponse(w, "invalid_client", "Invalid client_id", http.StatusBadRequest) return } // Validate redirect URI for security if err := em.validateRedirectURI(redirectURI); err != nil { - log.Printf("OAuth ERROR: Redirect URI validation failed in token endpoint: %v", err) + logger.Errorf("OAuth ERROR: Redirect URI validation failed in token endpoint: %v", err) em.writeErrorResponse(w, "invalid_request", "Invalid redirect_uri", http.StatusBadRequest) return } @@ -939,12 +939,12 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { requestedScope = strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ") } - log.Printf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope) + logger.Debugf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope) // Exchange authorization code for access token with Azure AD tokenResponse, err := em.exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, requestedScope) if err != nil { - log.Printf("OAuth ERROR: Token exchange with Azure AD failed: %v", err) + logger.Errorf("OAuth ERROR: Token exchange with Azure AD failed: %v", err) em.writeErrorResponse(w, "invalid_grant", fmt.Sprintf("Authorization code exchange failed: %v", err), http.StatusBadRequest) return } @@ -955,7 +955,7 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { w.Header().Set("Pragma", "no-cache") if err := json.NewEncoder(w).Encode(tokenResponse); err != nil { - log.Printf("OAuth ERROR: Failed to encode token response: %v", err) + logger.Errorf("OAuth ERROR: Failed to encode token response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -983,15 +983,15 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer // Add PKCE code_verifier if present if codeVerifier != "" { data.Set("code_verifier", codeVerifier) - log.Printf("Including PKCE code_verifier in Azure AD token request") + logger.Debugf("Including PKCE code_verifier in Azure AD token request") } else { - log.Printf("No PKCE code_verifier provided - this may cause PKCE verification to fail") + logger.Warnf("No PKCE code_verifier provided - this may cause PKCE verification to fail") } // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests // It uses scope-based resource identification instead // For MCP compliance, we handle resource binding through audience validation - log.Printf("Azure AD token request with scope: %s", scope) + logger.Debugf("Azure AD token request with scope: %s", scope) // Make token exchange request to Azure AD resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above @@ -1000,7 +1000,7 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer } defer func() { if err := resp.Body.Close(); err != nil { - log.Printf("Failed to close response body: %v", err) + logger.Errorf("Failed to close response body: %v", err) } }() @@ -1015,7 +1015,7 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer return nil, fmt.Errorf("failed to parse token response: %w", err) } - log.Printf("Token exchange successful: access_token received (length: %d)", len(tokenResponse.AccessToken)) + logger.Infof("Token exchange successful: access_token received (length: %d)", len(tokenResponse.AccessToken)) return &tokenResponse, nil } diff --git a/internal/auth/oauth/middleware.go b/internal/auth/oauth/middleware.go index 5170385..46d25b9 100644 --- a/internal/auth/oauth/middleware.go +++ b/internal/auth/oauth/middleware.go @@ -4,11 +4,11 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "strings" "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/logger" ) // contextKey is a custom type for context keys to avoid collisions @@ -43,7 +43,7 @@ func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request) w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours w.Header().Set("Access-Control-Allow-Credentials", "false") } else if requestOrigin != "" { - log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + logger.Errorf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) } } @@ -61,7 +61,7 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { // Skip authentication for specific endpoints if m.shouldSkipAuth(r) { - log.Printf("Skipping auth for path: %s\n", r.URL.Path) + logger.Debugf("Skipping auth for path: %s", r.URL.Path) next.ServeHTTP(w, r) return } @@ -70,7 +70,7 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { authResult := m.authenticateRequest(r) if !authResult.Authenticated { - log.Printf("Authentication FAILED - handling error\n") + logger.Errorf("Authentication FAILED - handling error") m.handleAuthError(w, r, authResult) return } @@ -116,8 +116,8 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { authHeader := r.Header.Get("Authorization") if authHeader == "" { - log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path) - log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header) + logger.Debugf("OAuth DEBUG - Missing authorization header for %s %s", r.Method, r.URL.Path) + logger.Debugf("OAuth DEBUG - Request headers: %+v", r.Header) return &auth.AuthResult{ Authenticated: false, Error: "missing authorization header", @@ -128,7 +128,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { // Check for Bearer token format const bearerPrefix = "Bearer " if !strings.HasPrefix(authHeader, bearerPrefix) { - log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n") + logger.Errorf("FAILED - Invalid authorization header format (missing Bearer prefix)") return &auth.AuthResult{ Authenticated: false, Error: "invalid authorization header format", @@ -138,7 +138,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { token := strings.TrimPrefix(authHeader, bearerPrefix) if token == "" { - log.Printf("FAILED - Empty bearer token\n") + logger.Errorf("FAILED - Empty bearer token") return &auth.AuthResult{ Authenticated: false, Error: "empty bearer token", @@ -149,7 +149,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { // Basic JWT structure validation tokenParts := strings.Split(token, ".") if len(tokenParts) != 3 { - log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts)) + logger.Errorf("FAILED - JWT structure validation (has %d parts, expected 3)", len(tokenParts)) return &auth.AuthResult{ Authenticated: false, Error: "invalid JWT structure", @@ -160,7 +160,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { // Validate the token tokenInfo, err := m.provider.ValidateToken(r.Context(), token) if err != nil { - log.Printf("FAILED - Provider token validation failed: %v\n", err) + logger.Errorf("FAILED - Provider token validation failed: %v", err) return &auth.AuthResult{ Authenticated: false, Error: fmt.Sprintf("token validation failed: %v", err), @@ -170,7 +170,7 @@ func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { // Validate required scopes - strict enforcement for security if !m.validateScopes(tokenInfo.Scope) { - log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes) + logger.Errorf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes) return &auth.AuthResult{ Authenticated: false, Error: "insufficient scope", @@ -278,9 +278,9 @@ func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, } if err := json.NewEncoder(w).Encode(errorResponse); err != nil { - log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err) + logger.Errorf("MIDDLEWARE ERROR: Failed to encode error response: %v", err) } else { - log.Printf("MIDDLEWARE ERROR: Error response sent\n") + logger.Errorf("MIDDLEWARE ERROR: Error response sent") } } diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go index d6ec9bc..2f576b1 100644 --- a/internal/auth/oauth/provider.go +++ b/internal/auth/oauth/provider.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "io" - "log" "math/big" "net/http" "net/url" @@ -17,6 +16,7 @@ import ( "github.com/Azure/aks-mcp/internal/auth" internalConfig "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" "github.com/golang-jwt/jwt/v5" ) @@ -109,71 +109,71 @@ func (p *AzureOAuthProvider) GetProtectedResourceMetadata(serverURL string) (*Pr // GetAuthorizationServerMetadata returns OAuth 2.0 Authorization Server Metadata (RFC 8414) func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*AzureADMetadata, error) { metadataURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", p.config.TenantID) - log.Printf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL) + logger.Debugf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL) resp, err := p.httpClient.Get(metadataURL) if err != nil { - log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err) + logger.Errorf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err) return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err) } defer func() { if err := resp.Body.Close(); err != nil { - log.Printf("Failed to close response body: %v", err) + logger.Errorf("Failed to close response body: %v", err) } }() if resp.StatusCode == http.StatusNotFound { - log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID) + logger.Errorf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID) return nil, fmt.Errorf("tenant ID '%s' not found (HTTP 404). Please verify your Azure AD tenant ID is correct", p.config.TenantID) } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - log.Printf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) + logger.Errorf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) return nil, fmt.Errorf("metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) } body, err := io.ReadAll(resp.Body) if err != nil { - log.Printf("OAuth ERROR: Failed to read metadata response: %v", err) + logger.Errorf("OAuth ERROR: Failed to read metadata response: %v", err) return nil, fmt.Errorf("failed to read metadata response: %w", err) } var metadata AzureADMetadata if err := json.Unmarshal(body, &metadata); err != nil { - log.Printf("OAuth ERROR: Failed to parse metadata JSON: %v", err) + logger.Errorf("OAuth ERROR: Failed to parse metadata JSON: %v", err) return nil, fmt.Errorf("failed to parse metadata: %w", err) } - log.Printf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported) + logger.Debugf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported) // Ensure grant_types_supported is populated for MCP Inspector compatibility if len(metadata.GrantTypesSupported) == 0 { - log.Printf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)") + logger.Debugf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)") metadata.GrantTypesSupported = []string{"authorization_code", "refresh_token"} } // Ensure response_types_supported is populated for MCP Inspector compatibility if len(metadata.ResponseTypesSupported) == 0 { - log.Printf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)") + logger.Debugf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)") metadata.ResponseTypesSupported = []string{"code"} } // Ensure subject_types_supported is populated for MCP Inspector compatibility if len(metadata.SubjectTypesSupported) == 0 { - log.Printf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)") + logger.Debugf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)") metadata.SubjectTypesSupported = []string{"public"} } // Ensure token_endpoint_auth_methods_supported is populated for MCP Inspector compatibility if len(metadata.TokenEndpointAuthMethodsSupported) == 0 { - log.Printf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)") + logger.Debugf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)") metadata.TokenEndpointAuthMethodsSupported = []string{"none"} } // Add S256 code challenge method support (Azure AD supports this but may not advertise it) // MCP specification requires S256 support, so we always ensure it's present - log.Printf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)") + logger.Debugf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)") metadata.CodeChallengeMethodsSupported = []string{"S256"} // Azure AD v2.0 has limited support for RFC 8707 Resource Indicators @@ -197,7 +197,7 @@ func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (* metadata.RegistrationEndpoint = registrationURL } - log.Printf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v", + logger.Debugf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v", metadata.GrantTypesSupported, metadata.ResponseTypesSupported, metadata.CodeChallengeMethodsSupported) return &metadata, nil @@ -217,7 +217,7 @@ func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString stri // ValidateJWT should ALWAYS be true in production environments // This bypass creates a significant security vulnerability if enabled in production if !p.config.TokenValidation.ValidateJWT { - log.Printf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing") + logger.Warnf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing") return &auth.TokenInfo{ AccessToken: tokenString, TokenType: "Bearer", @@ -400,7 +400,7 @@ func (p *AzureOAuthProvider) getKeyFunc(token *jwt.Token) (interface{}, error) { // Get the public key for this key ID using the appropriate issuer key, err := p.getPublicKey(kid, issuer) if err != nil { - log.Printf("PUBLIC KEY RETRIEVAL FAILED: %s\n", err) + logger.Errorf("PUBLIC KEY RETRIEVAL FAILED: %s", err) return nil, fmt.Errorf("failed to get public key: %w", err) } @@ -432,7 +432,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi } defer func() { if err := resp.Body.Close(); err != nil { - log.Printf("Failed to close response body: %v", err) + logger.Errorf("Failed to close response body: %v", err) } }() @@ -458,7 +458,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi return nil, fmt.Errorf("failed to parse JWKS: %w", err) } - log.Printf("JWKS Contains %d keys, searching for kid=%s\n", len(jwks.Keys), kid) + logger.Debugf("JWKS Contains %d keys, searching for kid=%s", len(jwks.Keys), kid) // Parse keys and find the target key var targetKey *rsa.PublicKey @@ -470,7 +470,7 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi if key.Kty == "RSA" && key.Kid == kid { pubKey, err := parseRSAPublicKey(key.N, key.E) if err != nil { - log.Printf("JWKS Failed to parse RSA key %s: %v\n", key.Kid, err) + logger.Errorf("JWKS Failed to parse RSA key %s: %v", key.Kid, err) continue } targetKey = pubKey diff --git a/internal/azureclient/detector.go b/internal/azureclient/detector.go index 4fd77f8..1d97d45 100644 --- a/internal/azureclient/detector.go +++ b/internal/azureclient/detector.go @@ -5,10 +5,10 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" + "github.com/Azure/aks-mcp/internal/logger" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) @@ -66,7 +66,7 @@ func ParseAKSResourceID(resourceID string) (subscriptionID, resourceGroup, clust func HandleDetectorAPIResponse(resp *http.Response) ([]byte, error) { defer func() { if err := resp.Body.Close(); err != nil { - log.Printf("Warning: failed to close response body: %v", err) + logger.Warnf("Warning: failed to close response body: %v", err) } }() diff --git a/internal/components/advisor/aks_recommendations.go b/internal/components/advisor/aks_recommendations.go index 88d7e56..1fdd680 100644 --- a/internal/components/advisor/aks_recommendations.go +++ b/internal/components/advisor/aks_recommendations.go @@ -3,23 +3,23 @@ package advisor import ( "encoding/json" "fmt" - "log" "strings" "time" "github.com/Azure/aks-mcp/internal/azcli" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" ) // HandleAdvisorRecommendation is the main handler for Azure Advisor recommendation operations func HandleAdvisorRecommendation(params map[string]interface{}, cfg *config.ConfigData) (string, error) { operation, ok := params["operation"].(string) if !ok { - log.Println("[ADVISOR] Missing operation parameter") + logger.Errorf("[ADVISOR] Missing operation parameter") return "", fmt.Errorf("operation parameter is required") } - log.Printf("[ADVISOR] Handling operation: %s", operation) + logger.Debugf("[ADVISOR] Handling operation: %s", operation) switch operation { case "list": @@ -27,7 +27,7 @@ func HandleAdvisorRecommendation(params map[string]interface{}, cfg *config.Conf case "report": return handleAKSAdvisorRecommendationReport(params, cfg) default: - log.Printf("[ADVISOR] Invalid operation: %s", operation) + logger.Errorf("[ADVISOR] Invalid operation: %s", operation) return "", fmt.Errorf("invalid operation: %s. Allowed values: list, report", operation) } } @@ -36,7 +36,7 @@ func HandleAdvisorRecommendation(params map[string]interface{}, cfg *config.Conf func handleAKSAdvisorRecommendationList(params map[string]interface{}, cfg *config.ConfigData) (string, error) { subscriptionID, ok := params["subscription_id"].(string) if !ok { - log.Println("[ADVISOR] Missing subscription_id parameter") + logger.Errorf("[ADVISOR] Missing subscription_id parameter") return "", fmt.Errorf("subscription_id parameter is required") } @@ -45,7 +45,7 @@ func handleAKSAdvisorRecommendationList(params map[string]interface{}, cfg *conf category, _ := params["category"].(string) severity, _ := params["severity"].(string) - log.Printf("[ADVISOR] Listing recommendations for subscription: %s, resource_group: %s, category: %s, severity: %s", + logger.Debugf("[ADVISOR] Listing recommendations for subscription: %s, resource_group: %s, category: %s, severity: %s", subscriptionID, resourceGroup, category, severity) // Get cluster names filter if provided @@ -57,30 +57,30 @@ func handleAKSAdvisorRecommendationList(params map[string]interface{}, cfg *conf clusterNames = append(clusterNames, trimmedName) } } - log.Printf("[ADVISOR] Filtering by cluster names: %v", clusterNames) + logger.Debugf("[ADVISOR] Filtering by cluster names: %v", clusterNames) } // Execute Azure CLI command to get recommendations recommendations, err := listRecommendationsViaCLI(subscriptionID, resourceGroup, category, cfg) if err != nil { - log.Printf("[ADVISOR] Failed to list recommendations: %v", err) + logger.Errorf("[ADVISOR] Failed to list recommendations: %v", err) return "", fmt.Errorf("failed to list recommendations: %w", err) } - log.Printf("[ADVISOR] Found %d total recommendations", len(recommendations)) + logger.Infof("[ADVISOR] Found %d total recommendations", len(recommendations)) // Filter for AKS-related recommendations aksRecommendations := filterAKSRecommendationsFromCLI(recommendations) - log.Printf("[ADVISOR] Found %d AKS-related recommendations", len(aksRecommendations)) + logger.Infof("[ADVISOR] Found %d AKS-related recommendations", len(aksRecommendations)) // Apply additional filters if severity != "" { aksRecommendations = filterBySeverity(aksRecommendations, severity) - log.Printf("[ADVISOR] After severity filter: %d recommendations", len(aksRecommendations)) + logger.Debugf("[ADVISOR] After severity filter: %d recommendations", len(aksRecommendations)) } if len(clusterNames) > 0 { aksRecommendations = filterByClusterNames(aksRecommendations, clusterNames) - log.Printf("[ADVISOR] After cluster name filter: %d recommendations", len(aksRecommendations)) + logger.Debugf("[ADVISOR] After cluster name filter: %d recommendations", len(aksRecommendations)) } // Convert to AKS recommendation summaries @@ -89,11 +89,11 @@ func handleAKSAdvisorRecommendationList(params map[string]interface{}, cfg *conf // Return JSON response result, err := json.MarshalIndent(summaries, "", " ") if err != nil { - log.Printf("[ADVISOR] Failed to marshal recommendations: %v", err) + logger.Errorf("[ADVISOR] Failed to marshal recommendations: %v", err) return "", fmt.Errorf("failed to marshal recommendations: %w", err) } - log.Printf("[ADVISOR] Returning %d recommendation summaries", len(summaries)) + logger.Infof("[ADVISOR] Returning %d recommendation summaries", len(summaries)) return string(result), nil } @@ -149,25 +149,25 @@ func listRecommendationsViaCLI(subscriptionID, resourceGroup, category string, c "command": "az " + strings.Join(args, " "), } - log.Printf("[ADVISOR] Executing command: %s", cmdParams["command"]) + logger.Debugf("[ADVISOR] Executing command: %s", cmdParams["command"]) // Execute command output, err := executor.Execute(cmdParams, cfg) if err != nil { - log.Printf("[ADVISOR] Command execution failed: %v", err) + logger.Errorf("[ADVISOR] Command execution failed: %v", err) return nil, fmt.Errorf("failed to execute Azure CLI command: %w", err) } - log.Printf("[ADVISOR] Command output length: %d characters", len(output)) + logger.Debugf("[ADVISOR] Command output length: %d characters", len(output)) // Parse JSON output var recommendations []CLIRecommendation if err := json.Unmarshal([]byte(output), &recommendations); err != nil { - log.Printf("[ADVISOR] Failed to parse JSON output: %v", err) + logger.Errorf("[ADVISOR] Failed to parse JSON output: %v", err) return nil, fmt.Errorf("failed to parse recommendations JSON: %w", err) } - log.Printf("[ADVISOR] Successfully parsed %d recommendations from CLI output", len(recommendations)) + logger.Debugf("[ADVISOR] Successfully parsed %d recommendations from CLI output", len(recommendations)) return recommendations, nil } diff --git a/internal/components/inspektorgadget/handlers.go b/internal/components/inspektorgadget/handlers.go index 3107f9d..231f32f 100644 --- a/internal/components/inspektorgadget/handlers.go +++ b/internal/components/inspektorgadget/handlers.go @@ -250,7 +250,7 @@ func handleLifecycleAction(mgr GadgetManager, deployed bool, action string, acti fmt.Fprintf(os.Stderr, "Failed to get latest version: %v\n", err) } - hc, err := newHelmClient(cfg.Verbose) + hc, err := newHelmClient(cfg.LogLevel == "debug") if err != nil { return "", fmt.Errorf("creating helm client: %w", err) } diff --git a/internal/components/monitor/diagnostics/handlers.go b/internal/components/monitor/diagnostics/handlers.go index 3dac2dd..e76f0b5 100644 --- a/internal/components/monitor/diagnostics/handlers.go +++ b/internal/components/monitor/diagnostics/handlers.go @@ -4,12 +4,12 @@ import ( "context" "encoding/json" "fmt" - "log" "github.com/Azure/aks-mcp/internal/azcli" "github.com/Azure/aks-mcp/internal/azureclient" "github.com/Azure/aks-mcp/internal/components/common" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" "github.com/Azure/aks-mcp/internal/tools" ) @@ -107,7 +107,7 @@ func HandleControlPlaneLogs(params map[string]interface{}, azClient *azureclient workspaceGUID, kqlQuery, timespan) // Log the query command for debugging - log.Printf("Executing KQL query command: %s", cmd) + logger.Debugf("Executing KQL query command: %s", cmd) cmdParams := map[string]interface{}{ "command": cmd, diff --git a/internal/components/monitor/diagnostics/workspace.go b/internal/components/monitor/diagnostics/workspace.go index ba9d427..73276f1 100644 --- a/internal/components/monitor/diagnostics/workspace.go +++ b/internal/components/monitor/diagnostics/workspace.go @@ -3,12 +3,12 @@ package diagnostics import ( "context" "fmt" - "log" "strings" "github.com/Azure/aks-mcp/internal/azcli" "github.com/Azure/aks-mcp/internal/azureclient" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" ) // ExtractWorkspaceGUIDFromDiagnosticSettings extracts workspace GUID from diagnostic settings @@ -138,7 +138,7 @@ func FindDiagnosticSettingForCategory(subscriptionID, resourceGroup, clusterName destinationType = string(*setting.Properties.LogAnalyticsDestinationType) } - log.Printf("Using diagnostic setting '%s' for log category '%s' in cluster '%s': workspaceId=%s, destinationType=%s, isResourceSpecific=%t", + logger.Debugf("Using diagnostic setting '%s' for log category '%s' in cluster '%s': workspaceId=%s, destinationType=%s, isResourceSpecific=%t", settingName, logCategory, clusterName, workspaceResourceID, destinationType, isResourceSpecific) return workspaceResourceID, isResourceSpecific, nil diff --git a/internal/config/config.go b/internal/config/config.go index b60f068..75bea3c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,13 +3,13 @@ package config import ( "context" "fmt" - "log" "os" "regexp" "strings" "time" "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/logger" "github.com/Azure/aks-mcp/internal/security" "github.com/Azure/aks-mcp/internal/telemetry" "github.com/Azure/aks-mcp/internal/version" @@ -59,8 +59,8 @@ type ConfigData struct { // Comma-separated list of allowed Kubernetes namespaces AllowNamespaces string - // Verbose logging - Verbose bool + // Log level (debug, info, warn, error) + LogLevel string // OTLP endpoint for OpenTelemetry traces OTLPEndpoint string @@ -81,6 +81,7 @@ func NewConfig() *ConfigData { AccessLevel: "readonly", AdditionalTools: make(map[string]bool), AllowNamespaces: "", + LogLevel: "info", } } @@ -115,7 +116,7 @@ func (cfg *ConfigData) ParseFlags() { "Comma-separated list of allowed Kubernetes namespaces (empty means all namespaces)") // Logging settings - flag.BoolVarP(&cfg.Verbose, "verbose", "v", false, "Enable verbose logging") + flag.StringVar(&cfg.LogLevel, "log-level", "info", "Log level (debug, info, warn, error)") // OTLP settings flag.StringVar(&cfg.OTLPEndpoint, "otlp-endpoint", "", "OTLP endpoint for OpenTelemetry traces (e.g. localhost:4317)") @@ -180,22 +181,22 @@ func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigi if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" { cfg.OAuthConfig.TenantID = tenantID tenantIDSource = "environment variable AZURE_TENANT_ID" - log.Printf("OAuth Config: Using tenant ID from environment variable AZURE_TENANT_ID") + logger.Debugf("OAuth Config: Using tenant ID from environment variable AZURE_TENANT_ID") } } else { tenantIDSource = "command line flag --oauth-tenant-id" - log.Printf("OAuth Config: Using tenant ID from command line flag --oauth-tenant-id") + logger.Debugf("OAuth Config: Using tenant ID from command line flag --oauth-tenant-id") } if cfg.OAuthConfig.ClientID == "" { if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" { cfg.OAuthConfig.ClientID = clientID clientIDSource = "environment variable AZURE_CLIENT_ID" - log.Printf("OAuth Config: Using client ID from environment variable AZURE_CLIENT_ID") + logger.Debugf("OAuth Config: Using client ID from environment variable AZURE_CLIENT_ID") } } else { clientIDSource = "command line flag --oauth-client-id" - log.Printf("OAuth Config: Using client ID from command line flag --oauth-client-id") + logger.Debugf("OAuth Config: Using client ID from command line flag --oauth-client-id") } // Validate GUID formats for tenant ID and client ID @@ -232,7 +233,7 @@ func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigi // Parse allowed CORS origins for OAuth endpoints if allowedCORSOrigins != "" { - log.Printf("OAuth Config: Setting allowed CORS origins from command line flag --oauth-cors-origins") + logger.Debugf("OAuth Config: Setting allowed CORS origins from command line flag --oauth-cors-origins") origins := strings.Split(allowedCORSOrigins, ",") for _, origin := range origins { trimmedOrigin := strings.TrimSpace(origin) @@ -241,7 +242,7 @@ func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigi } } } else { - log.Printf("OAuth Config: No CORS origins configured - cross-origin requests will be blocked for security") + logger.Debugf("OAuth Config: No CORS origins configured - cross-origin requests will be blocked for security") } return nil @@ -270,7 +271,7 @@ func (cfg *ConfigData) InitializeTelemetry(ctx context.Context, serviceName, ser // Initialize telemetry service cfg.TelemetryService = telemetry.NewService(telemetryConfig) if err := cfg.TelemetryService.Initialize(ctx); err != nil { - log.Printf("Failed to initialize telemetry: %v", err) + logger.Errorf("Failed to initialize telemetry: %v", err) // Continue without telemetry - this is not a fatal error } diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..2389e0c --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,64 @@ +package logger + +import ( + "github.com/sirupsen/logrus" +) + +// Logger is the global logrus logger instance +var Logger = logrus.New() + +func init() { + // Set timestamp format to a shorter readable format + Logger.SetFormatter(&logrus.TextFormatter{ + TimestampFormat: "15:04:05", + FullTimestamp: true, + }) +} + +// SetLevel sets the log level +func SetLevel(level string) error { + logLevel, err := logrus.ParseLevel(level) + if err != nil { + return err + } + Logger.SetLevel(logLevel) + return nil +} + +// GetLevel gets the current log level +func GetLevel() string { + return Logger.GetLevel().String() +} + +// Convenience functions using the global Logger +func Debug(args ...any) { + Logger.Debug(args...) +} + +func Debugf(format string, args ...any) { + Logger.Debugf(format, args...) +} + +func Info(args ...any) { + Logger.Info(args...) +} + +func Infof(format string, args ...any) { + Logger.Infof(format, args...) +} + +func Warn(args ...any) { + Logger.Warn(args...) +} + +func Warnf(format string, args ...any) { + Logger.Warnf(format, args...) +} + +func Error(args ...any) { + Logger.Error(args...) +} + +func Errorf(format string, args ...any) { + Logger.Errorf(format, args...) +} diff --git a/internal/server/server.go b/internal/server/server.go index 39b4513..419cdae 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,7 +3,6 @@ package server import ( "encoding/json" "fmt" - "log" "net/http" "time" @@ -20,6 +19,7 @@ import ( "github.com/Azure/aks-mcp/internal/components/network" "github.com/Azure/aks-mcp/internal/config" "github.com/Azure/aks-mcp/internal/k8s" + "github.com/Azure/aks-mcp/internal/logger" "github.com/Azure/aks-mcp/internal/prompts" "github.com/Azure/aks-mcp/internal/tools" "github.com/Azure/aks-mcp/internal/version" @@ -63,7 +63,7 @@ func NewService(cfg *config.ConfigData, opts ...ServiceOption) *Service { // Initialize initializes the service func (s *Service) Initialize() error { - log.Println("Initializing AKS MCP service...") + logger.Infof("Initializing AKS MCP service...") // Phase 1: Initialize core infrastructure if err := s.initializeInfrastructure(); err != nil { @@ -73,7 +73,7 @@ func (s *Service) Initialize() error { // Phase 2: Register all component tools s.registerAllComponents() - log.Println("AKS MCP service initialization completed successfully") + logger.Infof("AKS MCP service initialization completed successfully") return nil } @@ -85,7 +85,7 @@ func (s *Service) initializeInfrastructure() error { return fmt.Errorf("failed to create Azure client: %w", err) } s.azClient = azClient - log.Println("Azure client initialized successfully") + logger.Infof("Azure client initialized successfully") // Initialize OAuth components if enabled and transport is not stdio // OAuth is not supported with stdio transport per MCP specification @@ -102,13 +102,13 @@ func (s *Service) initializeInfrastructure() error { if loginType, err := azcli.EnsureAzCliLoginWithProc(proc, s.cfg); err != nil { return fmt.Errorf("azure cli authentication failed: %w", err) } else { - log.Printf("Azure CLI initialized successfully (%s)", loginType) + logger.Infof("Azure CLI initialized successfully (%s)", loginType) } } else { if loginType, err := azcli.EnsureAzCliLogin(s.cfg); err != nil { return fmt.Errorf("azure cli authentication failed: %w", err) } else { - log.Printf("Azure CLI initialized successfully (%s)", loginType) + logger.Infof("Azure CLI initialized successfully (%s)", loginType) } } @@ -121,14 +121,14 @@ func (s *Service) initializeInfrastructure() error { server.WithLogging(), server.WithRecovery(), ) - log.Println("MCP server initialized successfully") + logger.Infof("MCP server initialized successfully") return nil } // initializeOAuth initializes OAuth authentication components func (s *Service) initializeOAuth() error { - log.Println("Initializing OAuth authentication...") + logger.Infof("Initializing OAuth authentication...") // Validate OAuth configuration if err := s.cfg.OAuthConfig.Validate(); err != nil { @@ -151,7 +151,7 @@ func (s *Service) initializeOAuth() error { // Create endpoint manager s.endpointManager = oauth.NewEndpointManager(provider, s.cfg) - log.Printf("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID) + logger.Infof("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID) return nil } @@ -169,12 +169,12 @@ func (s *Service) registerAllComponents() { // registerPrompts registers all available prompts func (s *Service) registerPrompts() { - log.Println("Registering Prompts...") + logger.Infof("Registering Prompts...") - log.Println("Registering config prompts (query_aks_cluster_metadata_from_kubeconfig)") + logger.Debugf("Registering config prompts (query_aks_cluster_metadata_from_kubeconfig)") prompts.RegisterQueryAKSMetadataFromKubeconfigPrompt(s.mcpServer, s.cfg) - log.Println("Registering health prompts (check_cluster_health)") + logger.Debugf("Registering health prompts (check_cluster_health)") prompts.RegisterHealthPrompts(s.mcpServer, s.cfg) } @@ -186,9 +186,9 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { // Register OAuth endpoints if OAuth is enabled if s.cfg.OAuthConfig.Enabled { if s.endpointManager == nil { - log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + logger.Errorf("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") } - log.Println("Registering OAuth endpoints...") + logger.Infof("Registering OAuth endpoints...") s.endpointManager.RegisterEndpoints(mux) } @@ -244,16 +244,16 @@ func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, // Register OAuth endpoints if OAuth is enabled if s.cfg.OAuthConfig.Enabled { if s.endpointManager == nil { - log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + logger.Errorf("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") } - log.Println("Registering OAuth endpoints for SSE server...") + logger.Infof("Registering OAuth endpoints for SSE server...") s.endpointManager.RegisterEndpoints(mux) } // Register SSE and Message handlers with authentication if enabled if s.cfg.OAuthConfig.Enabled { if s.authMiddleware == nil { - log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + logger.Errorf("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") } // Apply authentication middleware to SSE and Message endpoints mux.Handle("/sse", s.authMiddleware.Middleware(sseServer.SSEHandler())) @@ -314,12 +314,12 @@ func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, // Run starts the service with the specified transport func (s *Service) Run() error { - log.Println("AKS MCP version:", version.GetVersion()) + logger.Infof("AKS MCP version: %s", version.GetVersion()) // Start the server switch s.cfg.Transport { case "stdio": - log.Println("Listening for requests on STDIO...") + logger.Infof("Listening for requests on STDIO...") return server.ServeStdio(s.mcpServer) case "sse": addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) @@ -330,13 +330,13 @@ func (s *Service) Run() error { // Create custom HTTP server with helpful 404 responses customServer := s.createCustomSSEServerWithHelp404(sse, addr) - log.Printf("SSE server listening on %s", addr) - log.Printf("SSE endpoint available at: http://%s/sse", addr) - log.Printf("Message endpoint available at: http://%s/message", addr) - log.Printf("Connect to /sse for real-time events, send JSON-RPC to /message") + logger.Infof("SSE server listening on %s", addr) + logger.Infof("SSE endpoint available at: http://%s/sse", addr) + logger.Infof("Message endpoint available at: http://%s/message", addr) + logger.Infof("Connect to /sse for real-time events, send JSON-RPC to /message") if s.cfg.OAuthConfig.Enabled { - log.Printf("OAuth authentication enabled - Bearer token required for SSE and Message endpoints") - log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + logger.Infof("OAuth authentication enabled - Bearer token required for SSE and Message endpoints") + logger.Infof("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) } return customServer.ListenAndServe() @@ -356,7 +356,7 @@ func (s *Service) Run() error { if mux, ok := customServer.Handler.(*http.ServeMux); ok { if s.cfg.OAuthConfig.Enabled { if s.authMiddleware == nil { - log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + logger.Errorf("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") } // Apply authentication middleware to MCP endpoint mux.Handle("/mcp", s.authMiddleware.Middleware(streamableServer)) @@ -366,12 +366,12 @@ func (s *Service) Run() error { } } - log.Printf("Streamable HTTP server listening on %s", addr) - log.Printf("MCP endpoint available at: http://%s/mcp", addr) - log.Printf("Send POST requests to /mcp to initialize session and obtain Mcp-Session-Id") + logger.Infof("Streamable HTTP server listening on %s", addr) + logger.Infof("MCP endpoint available at: http://%s/mcp", addr) + logger.Infof("Send POST requests to /mcp to initialize session and obtain Mcp-Session-Id") if s.cfg.OAuthConfig.Enabled { - log.Printf("OAuth authentication enabled - Bearer token required for MCP endpoint") - log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + logger.Infof("OAuth authentication enabled - Bearer token required for MCP endpoint") + logger.Infof("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) } return customServer.ListenAndServe() @@ -382,7 +382,7 @@ func (s *Service) Run() error { // registerAzureComponents registers all Azure tools (AKS operations, monitoring, fleet, network, compute, detectors, advisor) func (s *Service) registerAzureComponents() { - log.Println("Registering Azure Components...") + logger.Infof("Registering Azure Components...") // AKS Operations Component s.registerAksOpsComponent() @@ -408,12 +408,12 @@ func (s *Service) registerAzureComponents() { // Register Inspektor Gadget tools for observability s.registerInspektorGadgetComponent() - log.Println("Azure Components registered successfully") + logger.Infof("Azure Components registered successfully") } // registerKubernetesComponents registers Kubernetes-related tools (kubectl, helm, cilium, observability) func (s *Service) registerKubernetesComponents() { - log.Println("Registering Kubernetes Components...") + logger.Infof("Registering Kubernetes Components...") // Core Kubernetes Component (kubectl) s.registerKubectlComponent() @@ -421,12 +421,12 @@ func (s *Service) registerKubernetesComponents() { // Optional Kubernetes Components (based on configuration) s.registerOptionalKubernetesComponents() - log.Println("Kubernetes Components registered successfully") + logger.Infof("Kubernetes Components registered successfully") } // registerKubectlComponent registers core kubectl commands based on access level func (s *Service) registerKubectlComponent() { - log.Println("Registering Core Kubernetes Component (kubectl)") + logger.Debugf("Registering Core Kubernetes Component (kubectl)") // Get kubectl tools filtered by access level kubectlTools := kubectl.RegisterKubectlTools(s.cfg.AccessLevel) @@ -439,7 +439,7 @@ func (s *Service) registerKubectlComponent() { // Register each kubectl tool for _, tool := range kubectlTools { - log.Printf("Registering kubectl tool: %s", tool.Name) + logger.Debugf("Registering kubectl tool: %s", tool.Name) // Create a handler that injects the tool name into params handler := k8stools.CreateToolHandlerWithName(kubectlExecutor, k8sCfg, tool.Name) s.mcpServer.AddTool(tool, handler) @@ -448,7 +448,7 @@ func (s *Service) registerKubectlComponent() { // registerOptionalKubernetesComponents registers optional Kubernetes tools based on configuration func (s *Service) registerOptionalKubernetesComponents() { - log.Println("Registering Optional Kubernetes Components") + logger.Debugf("Registering Optional Kubernetes Components") // Register helm if enabled s.registerHelmComponent() @@ -461,7 +461,7 @@ func (s *Service) registerOptionalKubernetesComponents() { // Log if no optional components are enabled if !s.cfg.AdditionalTools["helm"] && !s.cfg.AdditionalTools["cilium"] && !s.cfg.AdditionalTools["hubble"] { - log.Println("No optional Kubernetes components enabled") + logger.Infof("No optional Kubernetes components enabled") } } @@ -470,80 +470,80 @@ func (s *Service) registerInspektorGadgetComponent() { gadgetMgr := inspektorgadget.NewGadgetManager() // Register Inspektor Gadget tool - log.Println("Registering Inspektor Gadget Observability tool: inspektor_gadget_observability") + logger.Debugf("Registering Inspektor Gadget Observability tool: inspektor_gadget_observability") inspektorGadget := inspektorgadget.RegisterInspektorGadgetTool() s.mcpServer.AddTool(inspektorGadget, tools.CreateResourceHandler(inspektorgadget.InspektorGadgetHandler(gadgetMgr, s.cfg), s.cfg)) } // registerAksOpsComponent registers AKS operations tools func (s *Service) registerAksOpsComponent() { - log.Println("Registering AKS operations tool: az_aks_operations") + logger.Debugf("Registering AKS operations tool: az_aks_operations") aksOperationsTool := azaks.RegisterAzAksOperations(s.cfg) s.mcpServer.AddTool(aksOperationsTool, tools.CreateToolHandler(azaks.NewAksOperationsExecutor(), s.cfg)) } // registerMonitoringComponent registers Azure monitoring tools func (s *Service) registerMonitoringComponent() { - log.Println("Registering monitoring tool: az_monitoring") + logger.Debugf("Registering monitoring tool: az_monitoring") monitoringTool := monitor.RegisterAzMonitoring() s.mcpServer.AddTool(monitoringTool, tools.CreateResourceHandler(monitor.GetAzMonitoringHandler(s.azClient, s.cfg), s.cfg)) } // registerFleetComponent registers Azure fleet management tools func (s *Service) registerFleetComponent() { - log.Println("Registering fleet tool: az_fleet") + logger.Debugf("Registering fleet tool: az_fleet") fleetTool := fleet.RegisterFleet() s.mcpServer.AddTool(fleetTool, tools.CreateToolHandler(azcli.NewFleetExecutor(), s.cfg)) } // registerAdvisorComponent registers Azure advisor tools func (s *Service) registerAdvisorComponent() { - log.Println("Registering advisor tool: az_advisor_recommendation") + logger.Debugf("Registering advisor tool: az_advisor_recommendation") advisorTool := advisor.RegisterAdvisorRecommendationTool() s.mcpServer.AddTool(advisorTool, tools.CreateResourceHandler(advisor.GetAdvisorRecommendationHandler(s.cfg), s.cfg)) } // registerNetworkComponent registers network-related Azure resource tools func (s *Service) registerNetworkComponent() { - log.Println("Registering Network Resources Component") + logger.Debugf("Registering Network Resources Component") // Register network resources tool - log.Println("Registering network tool: az_network_resources") + logger.Debugf("Registering network tool: az_network_resources") networkTool := network.RegisterAzNetworkResources() s.mcpServer.AddTool(networkTool, tools.CreateResourceHandler(network.GetAzNetworkResourcesHandler(s.azClient, s.cfg), s.cfg)) } // registerComputeComponent registers compute-related Azure resource tools (VMSS/VM) func (s *Service) registerComputeComponent() { - log.Println("Registering Compute Resources Component") + logger.Debugf("Registering Compute Resources Component") // Register AKS VMSS info tool (supports both single node pool and all node pools) - log.Println("Registering compute tool: get_aks_vmss_info") + logger.Debugf("Registering compute tool: get_aks_vmss_info") vmssInfoTool := compute.RegisterAKSVMSSInfoTool() s.mcpServer.AddTool(vmssInfoTool, tools.CreateResourceHandler(compute.GetAKSVMSSInfoHandler(s.azClient, s.cfg), s.cfg)) // Register unified compute operations tool - log.Println("Registering compute tool: az_compute_operations") + logger.Debugf("Registering compute tool: az_compute_operations") computeOperationsTool := compute.RegisterAzComputeOperations(s.cfg) s.mcpServer.AddTool(computeOperationsTool, tools.CreateToolHandler(compute.NewComputeOperationsExecutor(), s.cfg)) } // registerDetectorComponent registers detector-related Azure resource tools func (s *Service) registerDetectorComponent() { - log.Println("Registering Detector Resources Component") + logger.Debugf("Registering Detector Resources Component") // Register list detectors tool - log.Println("Registering detector tool: list_detectors") + logger.Debugf("Registering detector tool: list_detectors") listTool := detectors.RegisterListDetectorsTool() s.mcpServer.AddTool(listTool, tools.CreateResourceHandler(detectors.GetListDetectorsHandler(s.azClient, s.cfg), s.cfg)) // Register run detector tool - log.Println("Registering detector tool: run_detector") + logger.Debugf("Registering detector tool: run_detector") runTool := detectors.RegisterRunDetectorTool() s.mcpServer.AddTool(runTool, tools.CreateResourceHandler(detectors.GetRunDetectorHandler(s.azClient, s.cfg), s.cfg)) // Register run detectors by category tool - log.Println("Registering detector tool: run_detectors_by_category") + logger.Debugf("Registering detector tool: run_detectors_by_category") categoryTool := detectors.RegisterRunDetectorsByCategoryTool() s.mcpServer.AddTool(categoryTool, tools.CreateResourceHandler(detectors.GetRunDetectorsByCategoryHandler(s.azClient, s.cfg), s.cfg)) } @@ -551,7 +551,7 @@ func (s *Service) registerDetectorComponent() { // registerHelmComponent registers helm tools if enabled func (s *Service) registerHelmComponent() { if s.cfg.AdditionalTools["helm"] { - log.Println("Registering Kubernetes tool: helm") + logger.Debugf("Registering Kubernetes tool: helm") helmTool := helm.RegisterHelm() helmExecutor := k8s.WrapK8sExecutor(helm.NewExecutor()) s.mcpServer.AddTool(helmTool, tools.CreateToolHandler(helmExecutor, s.cfg)) @@ -561,7 +561,7 @@ func (s *Service) registerHelmComponent() { // registerCiliumComponent registers cilium tools if enabled func (s *Service) registerCiliumComponent() { if s.cfg.AdditionalTools["cilium"] { - log.Println("Registering Kubernetes tool: cilium") + logger.Debugf("Registering Kubernetes tool: cilium") ciliumTool := cilium.RegisterCilium() ciliumExecutor := k8s.WrapK8sExecutor(cilium.NewExecutor()) s.mcpServer.AddTool(ciliumTool, tools.CreateToolHandler(ciliumExecutor, s.cfg)) @@ -571,7 +571,7 @@ func (s *Service) registerCiliumComponent() { // registerHubbleComponent registers hubble tools if enabled func (s *Service) registerHubbleComponent() { if s.cfg.AdditionalTools["hubble"] { - log.Println("Registering Kubernetes tool: hubble") + logger.Debugf("Registering Kubernetes tool: hubble") hubbleTool := hubble.RegisterHubble() hubbleExecutor := k8s.WrapK8sExecutor(hubble.NewExecutor()) s.mcpServer.AddTool(hubbleTool, tools.CreateToolHandler(hubbleExecutor, s.cfg)) diff --git a/internal/telemetry/service.go b/internal/telemetry/service.go index d2280e1..64b9acd 100644 --- a/internal/telemetry/service.go +++ b/internal/telemetry/service.go @@ -3,9 +3,9 @@ package telemetry import ( "context" "fmt" - "log" "time" + "github.com/Azure/aks-mcp/internal/logger" "github.com/microsoft/ApplicationInsights-Go/appinsights" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -74,7 +74,7 @@ func (s *Service) initializeTracing(ctx context.Context) error { otlptracegrpc.WithInsecure(), ) if err != nil { - log.Printf("Failed to create OTLP gRPC exporter: %v", err) + logger.Errorf("Failed to create OTLP gRPC exporter: %v", err) } else { exporters = append(exporters, otlpExporter) } diff --git a/internal/tools/handler.go b/internal/tools/handler.go index c66f216..8b00912 100644 --- a/internal/tools/handler.go +++ b/internal/tools/handler.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "fmt" - "log" "github.com/Azure/aks-mcp/internal/config" + "github.com/Azure/aks-mcp/internal/logger" "github.com/mark3labs/mcp-go/mcp" ) @@ -14,29 +14,27 @@ import ( func logToolCall(toolName string, arguments interface{}) { // Try to format as JSON for better readability if jsonBytes, err := json.Marshal(arguments); err == nil { - log.Printf("\n>>> [%s] %s", toolName, string(jsonBytes)) + logger.Debugf("\n>>> [%s] %s", toolName, string(jsonBytes)) } else { - log.Printf("\n>>> [%s] %v", toolName, arguments) + logger.Debugf("\n>>> [%s] %v", toolName, arguments) } } // logToolResult logs the result or error of a tool call func logToolResult(toolName string, result string, err error) { if err != nil { - log.Printf("\n<<< [%s] ERROR: %v", toolName, err) + logger.Debugf("\n<<< [%s] ERROR: %v", toolName, err) } else if len(result) > 500 { - log.Printf("\n<<< [%s] Result: %d bytes (truncated): %.500s...", toolName, len(result), result) + logger.Debugf("\n<<< [%s] Result: %d bytes (truncated): %.500s...", toolName, len(result), result) } else { - log.Printf("\n<<< [%s] Result: %s", toolName, result) + logger.Debugf("\n<<< [%s] Result: %s", toolName, result) } } // CreateToolHandler creates an adapter that converts CommandExecutor to the format expected by MCP server func CreateToolHandler(executor CommandExecutor, cfg *config.ConfigData) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - if cfg.Verbose { - logToolCall(req.Params.Name, req.Params.Arguments) - } + logToolCall(req.Params.Name, req.Params.Arguments) args, ok := req.Params.Arguments.(map[string]interface{}) if !ok { @@ -54,9 +52,7 @@ func CreateToolHandler(executor CommandExecutor, cfg *config.ConfigData) func(ct cfg.TelemetryService.TrackToolInvocation(ctx, req.Params.Name, operation, err == nil) } - if cfg.Verbose { - logToolResult(req.Params.Name, result, err) - } + logToolResult(req.Params.Name, result, err) if err != nil { // Include command output (often stderr) in the error for context @@ -73,9 +69,7 @@ func CreateToolHandler(executor CommandExecutor, cfg *config.ConfigData) func(ct // CreateResourceHandler creates an adapter that converts ResourceHandler to the format expected by MCP server func CreateResourceHandler(handler ResourceHandler, cfg *config.ConfigData) func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - if cfg.Verbose { - logToolCall(req.Params.Name, req.Params.Arguments) - } + logToolCall(req.Params.Name, req.Params.Arguments) args, ok := req.Params.Arguments.(map[string]interface{}) if !ok { @@ -95,9 +89,7 @@ func CreateResourceHandler(handler ResourceHandler, cfg *config.ConfigData) func cfg.TelemetryService.TrackToolInvocation(ctx, req.Params.Name, operation, err == nil) } - if cfg.Verbose { - logToolResult(req.Params.Name, result, err) - } + logToolResult(req.Params.Name, result, err) if err != nil { // Include handler output in the error message for better diagnostics diff --git a/internal/tools/handler_test.go b/internal/tools/handler_test.go index f629273..a534148 100644 --- a/internal/tools/handler_test.go +++ b/internal/tools/handler_test.go @@ -143,7 +143,7 @@ func TestCreateResourceHandler_ErrorWithoutOutput(t *testing.T) { func TestCreateToolHandler_Success_Verbose_Telemetry_LongResult(t *testing.T) { cfg := config.NewConfig() - cfg.Verbose = true // exercise logToolCall + logToolResult + cfg.LogLevel = "debug" // exercise logToolCall + logToolResult // Provide non-nil telemetry to exercise TrackToolInvocation path cfg.TelemetryService = telemetry.NewService(telemetry.NewConfig("svc", "1.0")) @@ -177,7 +177,7 @@ func TestCreateToolHandler_Success_Verbose_Telemetry_LongResult(t *testing.T) { func TestCreateToolHandler_InvalidArguments_Verbose_LogsFallback_TracksTelemetry(t *testing.T) { cfg := config.NewConfig() - cfg.Verbose = true + cfg.LogLevel = "debug" cfg.TelemetryService = telemetry.NewService(telemetry.NewConfig("svc", "1.0")) exec := CommandExecutorFunc(func(params map[string]interface{}, _ *config.ConfigData) (string, error) { @@ -210,7 +210,7 @@ func TestCreateToolHandler_InvalidArguments_Verbose_LogsFallback_TracksTelemetry func TestCreateResourceHandler_ShortSuccess_Verbose_Telemetry(t *testing.T) { cfg := config.NewConfig() - cfg.Verbose = true + cfg.LogLevel = "debug" cfg.TelemetryService = telemetry.NewService(telemetry.NewConfig("svc", "1.0")) rh := ResourceHandlerFunc(func(params map[string]interface{}, _ *config.ConfigData) (string, error) { @@ -240,7 +240,7 @@ func TestCreateResourceHandler_ShortSuccess_Verbose_Telemetry(t *testing.T) { func TestCreateResourceHandler_InvalidArguments_Verbose_LogsFallback_TracksTelemetry(t *testing.T) { cfg := config.NewConfig() - cfg.Verbose = true + cfg.LogLevel = "debug" cfg.TelemetryService = telemetry.NewService(telemetry.NewConfig("svc", "1.0")) rh := ResourceHandlerFunc(func(params map[string]interface{}, _ *config.ConfigData) (string, error) { @@ -272,7 +272,7 @@ func TestCreateResourceHandler_InvalidArguments_Verbose_LogsFallback_TracksTelem func TestCreateToolHandler_Error_Verbose_LogErrorBranch(t *testing.T) { cfg := config.NewConfig() - cfg.Verbose = true + cfg.LogLevel = "debug" exec := CommandExecutorFunc(func(params map[string]interface{}, _ *config.ConfigData) (string, error) { return "", errors.New("boom")