diff --git a/cmd/cog/cog.go b/cmd/cog/cog.go index 5b55fb3aae..f872df4d92 100644 --- a/cmd/cog/cog.go +++ b/cmd/cog/cog.go @@ -3,7 +3,6 @@ package main import ( "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/console" - "github.com/replicate/cog/pkg/global" ) func main() { @@ -11,11 +10,6 @@ func main() { if err != nil { console.Fatalf("%f", err) } - defer func() { - if global.Profiler != nil { - global.Profiler.Stop() - } - }() if err = cmd.Execute(); err != nil { console.Fatalf("%s", err) diff --git a/pkg/cli/root.go b/pkg/cli/root.go index e984acdd88..b9a05664a3 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -5,7 +5,6 @@ import ( "os" "regexp" - "github.com/pkg/profile" "github.com/spf13/cobra" "github.com/replicate/cog/pkg/console" @@ -29,9 +28,6 @@ func NewRootCommand() (*cobra.Command, error) { if global.Verbose { console.SetLevel(console.DebugLevel) } - if global.ProfilingEnabled { - global.Profiler = profile.Start(profile.MemProfile) - } cmd.SilenceUsage = true }, SilenceErrors: true, diff --git a/pkg/global/global.go b/pkg/global/global.go index c168bf5076..cddd927f32 100644 --- a/pkg/global/global.go +++ b/pkg/global/global.go @@ -5,12 +5,11 @@ import ( ) var ( - Version = "0.0.1" - BuildTime = "none" - Verbose = false - ProfilingEnabled = false - Profiler interface{ Stop() } = nil - StartupTimeout = 5 * time.Minute - ConfigFilename = "cog.yaml" - CogServerAddress = "http://cog.replicate.ai" // TODO(andreas): https + Version = "0.0.1" + BuildTime = "none" + Verbose = false + ProfilingEnabled = false + StartupTimeout = 5 * time.Minute + ConfigFilename = "cog.yaml" + CogServerAddress = "http://cog.replicate.ai" // TODO(andreas): https ) diff --git a/pkg/server/profiling.go b/pkg/server/profiling.go new file mode 100644 index 0000000000..c01707aec5 --- /dev/null +++ b/pkg/server/profiling.go @@ -0,0 +1,73 @@ +package server + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/pkg/profile" +) + +// Profile responds with the pprof-formatted memory profile. +// Profiling lasts for duration specified in seconds GET parameter, or for 30 seconds if not specified. +// The package initialization registers it as /debug/pprof/profile. +func profileMemory(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64) + if sec <= 0 || err != nil { + sec = 30 + } + + if durationExceedsWriteTimeout(r, float64(sec)) { + pprofServeError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout") + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", `attachment; filename="profile"`) + profileDir, err := os.MkdirTemp("", "cog-mem-profile") + if err != nil { + pprofServeError(w, http.StatusInternalServerError, + fmt.Sprintf("Failed to create temp directory: %s", err)) + return + } + profiler := profile.Start(profile.ProfilePath(profileDir), profile.MemProfile) + + pprofSleep(r, time.Duration(sec)*time.Second) + profiler.Stop() + + file, err := os.Open(filepath.Join(profileDir, "mem.pprof")) + if err != nil { + pprofServeError(w, http.StatusInternalServerError, + fmt.Sprintf("Failed to create temp directory: %s", err)) + } + defer file.Close() + if _, err := io.Copy(w, file); err != nil { + pprofServeError(w, http.StatusInternalServerError, + fmt.Sprintf("Failed to copy memory profile contents: %s", err)) + } +} + +func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool { + srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server) + return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds() +} + +func pprofServeError(w http.ResponseWriter, status int, txt string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.Header().Del("Content-Disposition") + w.WriteHeader(status) + fmt.Fprintln(w, txt) +} + +func pprofSleep(r *http.Request, d time.Duration) { + select { + case <-time.After(d): + case <-r.Context().Done(): + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index b2a28ca707..b26693e059 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,6 +3,7 @@ package server import ( "fmt" "net/http" + "net/http/pprof" "os" "github.com/gorilla/handlers" @@ -11,6 +12,7 @@ import ( "github.com/replicate/cog/pkg/console" "github.com/replicate/cog/pkg/database" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/serving" "github.com/replicate/cog/pkg/storage" ) @@ -90,6 +92,18 @@ func (s *Server) Start() error { router.Path("/v1/repos/{user}/{name}/check-read"). Methods(http.MethodGet). HandlerFunc(s.checkReadAccess(nil)) + + if global.ProfilingEnabled { + router.HandleFunc("/debug/pprof/", pprof.Index) + router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + router.HandleFunc("/debug/pprof/profile", pprof.Profile) + router.HandleFunc("/debug/pprof/profile-mem", profileMemory) + router.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + router.HandleFunc("/debug/pprof/trace", pprof.Trace) + router.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) + router.Handle("/debug/pprof/heap", pprof.Handler("heap")) + } + console.Infof("Server running on 0.0.0.0:%d", s.port) loggedRouter := handlers.LoggingHandler(os.Stdout, router)