diff --git a/examples/http/logging_middleware.go b/examples/http/logging_middleware.go index 4266012c..97d8ad3e 100644 --- a/examples/http/logging_middleware.go +++ b/examples/http/logging_middleware.go @@ -5,47 +5,49 @@ package main import ( + "context" "log" - "net/http" "time" -) - -// responseWriter wraps http.ResponseWriter to capture the status code. -type responseWriter struct { - http.ResponseWriter - statusCode int -} -func (rw *responseWriter) WriteHeader(code int) { - rw.statusCode = code - rw.ResponseWriter.WriteHeader(code) -} + "github.com/modelcontextprotocol/go-sdk/mcp" +) -func loggingHandler(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - // Create a response writer wrapper to capture status code. - wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} - - // Log request details. - log.Printf("[REQUEST] %s | %s | %s %s", - start.Format(time.RFC3339), - r.RemoteAddr, - r.Method, - r.URL.Path) - - // Call the actual handler. - handler.ServeHTTP(wrapped, r) - - // Log response details. - duration := time.Since(start) - log.Printf("[RESPONSE] %s | %s | %s %s | Status: %d | Duration: %v", - time.Now().Format(time.RFC3339), - r.RemoteAddr, - r.Method, - r.URL.Path, - wrapped.statusCode, - duration) - }) +// createLoggingMiddleware creates an MCP middleware that logs method calls. +func createLoggingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func( + ctx context.Context, + method string, + req mcp.Request, + ) (mcp.Result, error) { + start := time.Now() + sessionID := req.GetSession().ID() + + // Log request details. + log.Printf("[REQUEST] Session: %s | Method: %s", + sessionID, + method) + + // Call the actual handler. + result, err := next(ctx, method, req) + + // Log response details. + duration := time.Since(start) + + if err != nil { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: ERROR | Duration: %v | Error: %v", + sessionID, + method, + duration, + err) + } else { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: OK | Duration: %v", + sessionID, + method, + duration) + } + + return result, err + } + } } diff --git a/examples/http/main.go b/examples/http/main.go index 188674ae..ced4871e 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -117,6 +117,9 @@ func runServer(url string) { Version: "1.0.0", }, nil) + // Add MCP-level logging middleware. + server.AddReceivingMiddleware(createLoggingMiddleware()) + // Add the cityTime tool. mcp.AddTool(server, &mcp.Tool{ Name: "cityTime", @@ -128,13 +131,11 @@ func runServer(url string) { return server }, nil) - handlerWithLogging := loggingHandler(handler) - log.Printf("MCP server listening on %s", url) log.Printf("Available tool: cityTime (cities: nyc, sf, boston)") - // Start the HTTP server with logging handler. - if err := http.ListenAndServe(url, handlerWithLogging); err != nil { + // Start the HTTP server. + if err := http.ListenAndServe(url, handler); err != nil { log.Fatalf("Server failed: %v", err) } }