diff --git a/end-to-end-test/end_to_end_test/test_server.py b/end-to-end-test/end_to_end_test/test_server.py index f677b35c44..70fab47861 100644 --- a/end-to-end-test/end_to_end_test/test_server.py +++ b/end-to-end-test/end_to_end_test/test_server.py @@ -37,7 +37,7 @@ def test_build_show_list_download_infer(cog_server_port_dir, tmpdir_factory): user = "".join(random.choice(string.ascii_lowercase) for i in range(10)) repo_name = "".join(random.choice(string.ascii_lowercase) for i in range(10)) - repo = f"localhost:{cog_port}/{user}/{repo_name}" + repo = f"http://localhost:{cog_port}/{user}/{repo_name}" project_dir = tmpdir_factory.mktemp("project") with open(project_dir / "infer.py", "w") as f: @@ -78,11 +78,11 @@ def run(self, text, path): ) out, _ = subprocess.Popen( - ["cog", "repo", "set", f"localhost:{cog_port}/{user}/{repo_name}"], + ["cog", "repo", "set", f"http://localhost:{cog_port}/{user}/{repo_name}"], stdout=subprocess.PIPE, cwd=project_dir, ).communicate() - assert out.decode() == f"Updated repo: localhost:{cog_port}/{user}/{repo_name}\n" + assert out.decode() == f"Updated repo: http://localhost:{cog_port}/{user}/{repo_name}\n" with open(project_dir / "myfile.txt", "w") as f: f.write("baz") diff --git a/pkg/cli/list.go b/pkg/cli/list.go index 2be9fe14e3..508dcb7e7f 100644 --- a/pkg/cli/list.go +++ b/pkg/cli/list.go @@ -22,6 +22,8 @@ func newListCommand() *cobra.Command { } addRepoFlag(cmd) + cmd.Flags().BoolP("quiet", "q", false, "Quite output, only display IDs") + return cmd } @@ -30,6 +32,10 @@ func listModels(cmd *cobra.Command, args []string) error { if err != nil { return err } + quiet, err := cmd.Flags().GetBool("quiet") + if err != nil { + return err + } cli := client.NewClient() models, err := cli.ListModels(repo) @@ -41,12 +47,18 @@ func listModels(cmd *cobra.Command, args []string) error { return models[i].Created.After(models[j].Created) }) - w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) - fmt.Fprintln(w, "ID\tCREATED") - for _, mod := range models { - fmt.Fprintf(w, "%s\t%s\n", mod.ID, timeago.English.Format(mod.Created)) + if quiet { + for _, mod := range models { + fmt.Println(mod.ID) + } + } else { + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tCREATED") + for _, mod := range models { + fmt.Fprintf(w, "%s\t%s\n", mod.ID, timeago.English.Format(mod.Created)) + } + w.Flush() } - w.Flush() return nil } diff --git a/pkg/cli/login.go b/pkg/cli/login.go new file mode 100644 index 0000000000..d3c5b49755 --- /dev/null +++ b/pkg/cli/login.go @@ -0,0 +1,84 @@ +package cli + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/client" + "github.com/replicate/cog/pkg/console" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/settings" +) + +type VerifyResponse struct { + Username string `json:"username"` +} + +func newLoginCommand() *cobra.Command { + var cmd = &cobra.Command{ + Use: "login [COG_SERVER_ADDRESS]", + SuggestFor: []string{"auth", "authenticate", "authorize"}, + Short: "Authorize the replicate CLI to a Cog server", + RunE: login, + Args: cobra.MaximumNArgs(1), + } + + return cmd +} + +func login(cmd *cobra.Command, args []string) error { + address := global.CogServerAddress + if len(args) == 1 { + address = args[0] + } + + c := client.NewClient() + url, err := c.GetDisplayTokenURL(address) + if err != nil { + return err + } + if url == "" { + return fmt.Errorf("This server does not support authentication") + } + fmt.Println("Please visit " + url + " in a web browser") + fmt.Println("and copy the authorization token.") + maybeOpenBrowser(url) + + fmt.Print("\nPaste the token here: ") + token, err := bufio.NewReader(os.Stdin).ReadString('\n') + token = strings.TrimSpace(token) + if err != nil { + return err + } + + username, err := c.VerifyToken(address, token) + if err != nil { + return err + } + + err = settings.SaveAuthToken(address, username, token) + if err != nil { + return err + } + + console.Infof("Successfully authenticated as %s", username) + + return nil +} + +func maybeOpenBrowser(url string) { + switch runtime.GOOS { + case "linux": + exec.Command("xdg-open", url).Start() + case "windows": + exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + exec.Command("open", url).Start() + } +} diff --git a/pkg/cli/repo.go b/pkg/cli/repo.go index ab6039dd08..6b55ca6de5 100644 --- a/pkg/cli/repo.go +++ b/pkg/cli/repo.go @@ -70,7 +70,7 @@ func setRepo(cmd *cobra.Command, args []string) error { } cli := client.NewClient() - if err := cli.Ping(repo); err != nil { + if err := cli.CheckRead(repo); err != nil { return err } diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 8c3a94c401..6add614a19 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -16,7 +16,7 @@ import ( var repoFlag string var projectDirFlag string -var repoRegex = regexp.MustCompile("^(?:([^/]*)/)?(?:([-_a-zA-Z0-9]+)/)([-_a-zA-Z0-9]+)$") +var repoRegex = regexp.MustCompile("^(?:(https?://[^/]*)/)?(?:([-_a-zA-Z0-9]+)/)([-_a-zA-Z0-9]+)$") func NewRootCommand() (*cobra.Command, error) { rootCmd := cobra.Command{ @@ -46,6 +46,7 @@ func NewRootCommand() (*cobra.Command, error) { newListCommand(), newBenchmarkCommand(), newDeleteCommand(), + newLoginCommand(), ) return &rootCmd, nil @@ -83,7 +84,7 @@ func getRepo() (*model.Repo, error) { func parseRepo(repoString string) (*model.Repo, error) { matches := repoRegex.FindStringSubmatch(repoString) if len(matches) == 0 { - return nil, fmt.Errorf("Repo '%s' doesn't match //", repoString) + return nil, fmt.Errorf("Repo '%s' doesn't match [http[s]:///]/", repoString) } return &model.Repo{ Host: matches[1], diff --git a/pkg/cli/server.go b/pkg/cli/server.go index 89af8c2de2..a6e27df3ab 100644 --- a/pkg/cli/server.go +++ b/pkg/cli/server.go @@ -20,6 +20,7 @@ var ( port int dockerRegistry string buildWebHooks []string + authDelegate string ) func newServerCommand() *cobra.Command { @@ -33,6 +34,7 @@ func newServerCommand() *cobra.Command { cmd.Flags().IntVar(&port, "port", 0, "Server port") cmd.Flags().StringVar(&dockerRegistry, "docker-registry", "", "Docker registry to push images to") cmd.Flags().StringArrayVar(&buildWebHooks, "web-hook", []string{}, "Web hooks that are posted to after build. Format: @") + cmd.Flags().StringVar(&authDelegate, "auth-delegate", "", "Address to service that handles authentication logic") return cmd } @@ -78,7 +80,7 @@ func startServer(cmd *cobra.Command, args []string) error { if err != nil { return err } - s, err := server.NewServer(port, buildWebHooks, db, dockerImageBuilder, servingPlatform, store) + s, err := server.NewServer(port, buildWebHooks, authDelegate, db, dockerImageBuilder, servingPlatform, store) if err != nil { return err } diff --git a/pkg/cli/show.go b/pkg/cli/show.go index f80051e3e1..a4b62c28b9 100644 --- a/pkg/cli/show.go +++ b/pkg/cli/show.go @@ -1,6 +1,7 @@ package cli import ( + "encoding/json" "fmt" "os" "text/tabwriter" @@ -21,10 +22,16 @@ func newShowCommand() *cobra.Command { } addRepoFlag(cmd) + cmd.Flags().Bool("json", false, "JSON output") + return cmd } func showModel(cmd *cobra.Command, args []string) error { + jsonOutput, err := cmd.Flags().GetBool("json") + if err != nil { + return err + } repo, err := getRepo() if err != nil { return err @@ -38,6 +45,15 @@ func showModel(cmd *cobra.Command, args []string) error { return err } + if jsonOutput { + data, err := json.MarshalIndent(mod, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) fmt.Fprintln(w, "ID:\t"+mod.ID) fmt.Fprintf(w, "Repo:\t%s/%s\n", repo.User, repo.Name) diff --git a/pkg/client/auth.go b/pkg/client/auth.go new file mode 100644 index 0000000000..1bbb1addeb --- /dev/null +++ b/pkg/client/auth.go @@ -0,0 +1,77 @@ +package client + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/replicate/cog/pkg/model" +) + +func (c *Client) GetDisplayTokenURL(address string) (url string, err error) { + resp, err := http.Get(address + "/v1/auth/display-token-url") + if err != nil { + return "", fmt.Errorf("Failed to get login URL: %w", err) + } + if resp.StatusCode == http.StatusNotFound { + return "", fmt.Errorf("Login page does not exist on %s. Is it the correct URL?", address) + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Login returned status %d", resp.StatusCode) + } + + body := &struct { + URL string `json:"url"` + }{} + if err := json.NewDecoder(resp.Body).Decode(body); err != nil { + return "", err + } + return body.URL, nil +} + +func (c *Client) VerifyToken(address string, token string) (username string, err error) { + resp, err := http.PostForm(address+"/v1/auth/verify-token", url.Values{ + "token": []string{token}, + }) + if err != nil { + return "", fmt.Errorf("Failed to verify token: %w", err) + } + if resp.StatusCode == http.StatusNotFound { + return "", fmt.Errorf("User does not exist") + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Failed to verify token, got status %d", resp.StatusCode) + } + body := &struct { + Username string `json:"username"` + }{} + if err := json.NewDecoder(resp.Body).Decode(body); err != nil { + return "", err + } + return body.Username, nil +} + +func (c *Client) CheckRead(repo *model.Repo) error { + url := newURL(repo, "v1/repos/%s/%s/check-read", repo.User, repo.Name) + req, err := c.newRequest(http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("Failed to read response body: %w", err) + } + body := string(bodyBytes) + if resp.StatusCode != http.StatusOK { + return errors.New(body) + } + return nil +} diff --git a/pkg/client/client.go b/pkg/client/client.go index e4c592947f..bf5b836ef5 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -1,12 +1,38 @@ package client import ( + "encoding/base64" "fmt" - "os" + "io" + "net/http" + "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/settings" ) +type cogURL struct { + repo *model.Repo + path string + args []interface{} +} + +func newURL(repo *model.Repo, path string, args ...interface{}) *cogURL { + u := &cogURL{ + repo: repo, + path: path, + } + if len(args) > 0 { + u.path = fmt.Sprintf(u.path, args...) + } + return u +} + +func (u *cogURL) String() string { + host := hostOrDefault(u.repo) + return fmt.Sprintf("%s/%s", host, u.path) +} + type Client struct { } @@ -14,18 +40,31 @@ func NewClient() *Client { return &Client{} } -func (c *Client) getURL(repo *model.Repo, path string, args ...interface{}) (string, error) { - if len(args) > 0 { - path = fmt.Sprintf(path, args...) +func (c *Client) newRequest(method string, url *cogURL, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, url.String(), body) + if err != nil { + return nil, err + } + c.addAuthHeader(req, url.repo) + return req, nil +} + +func (c *Client) addAuthHeader(req *http.Request, repo *model.Repo) error { + host := hostOrDefault(repo) + token, err := settings.LoadAuthToken(host) + if err != nil { + return err + } + if token != "" { + tokenBase64 := base64.StdEncoding.EncodeToString([]byte(token)) + req.Header.Add("Authorization", "Bearer "+tokenBase64) } - var host string + return nil +} + +func hostOrDefault(repo *model.Repo) string { if repo.Host != "" { - host = repo.Host - } else { - host = os.Getenv("COG_INTERNAL_DEFAULT_SERVER") - if host == "" { - return "", fmt.Errorf("Repo is missing host. It should be in the format 'host/user/repository'") - } + return repo.Host } - return fmt.Sprintf("http://%s/%s", host, path), nil + return global.CogServerAddress } diff --git a/pkg/client/delete.go b/pkg/client/delete.go index 521f5a565b..ba7f9ace45 100644 --- a/pkg/client/delete.go +++ b/pkg/client/delete.go @@ -8,11 +8,8 @@ import ( ) func (c *Client) DeleteModel(repo *model.Repo, id string) error { - url, err := c.getURL(repo, "v1/repos/%s/%s/models/%s", repo.User, repo.Name, id) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodDelete, url, nil) + url := newURL(repo, "v1/repos/%s/%s/models/%s", repo.User, repo.Name, id) + req, err := c.newRequest(http.MethodDelete, url, nil) if err != nil { return err } diff --git a/pkg/client/download.go b/pkg/client/download.go index c1cc35fceb..cb3a41f9b4 100644 --- a/pkg/client/download.go +++ b/pkg/client/download.go @@ -13,14 +13,14 @@ import ( ) func (c *Client) DownloadModel(repo *model.Repo, id string, outputDir string) error { - url, err := c.getURL(repo, "v1/repos/%s/%s/models/%s.zip", repo.User, repo.Name, id) - if err != nil { - return err - } - req, err := http.NewRequest("GET", url, nil) + url := newURL(repo, "v1/repos/%s/%s/models/%s.zip", repo.User, repo.Name, id) + req, err := c.newRequest("GET", url, nil) if err != nil { return fmt.Errorf("Failed to create HTTP request: %w", err) } + if err := c.addAuthHeader(req, repo); err != nil { + return err + } resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("Failed to perform HTTP request: %w", err) diff --git a/pkg/client/get.go b/pkg/client/get.go index a395bf324d..20e74fc583 100644 --- a/pkg/client/get.go +++ b/pkg/client/get.go @@ -10,11 +10,12 @@ import ( ) func (c *Client) GetModel(repo *model.Repo, id string) (*model.Model, error) { - url, err := c.getURL(repo, "v1/repos/%s/%s/models/%s", repo.User, repo.Name, id) + url := newURL(repo, "v1/repos/%s/%s/models/%s", repo.User, repo.Name, id) + req, err := c.newRequest(http.MethodGet, url, nil) if err != nil { return nil, err } - resp, err := http.Get(url) + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/pkg/client/list.go b/pkg/client/list.go index 0c847259e9..649ce6a812 100644 --- a/pkg/client/list.go +++ b/pkg/client/list.go @@ -9,15 +9,15 @@ import ( ) func (c *Client) ListModels(repo *model.Repo) ([]*model.Model, error) { - url, err := c.getURL(repo, "v1/repos/%s/%s/models/", repo.User, repo.Name) + url := newURL(repo, "v1/repos/%s/%s/models/", repo.User, repo.Name) + req, err := c.newRequest(http.MethodGet, url, nil) if err != nil { return nil, err } - resp, err := http.Get(url) + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } - if resp.StatusCode == http.StatusNotFound { return nil, fmt.Errorf("Repository not found: %s", repo.String()) } diff --git a/pkg/client/ping.go b/pkg/client/ping.go index 5e2252e97d..45288e5cb2 100644 --- a/pkg/client/ping.go +++ b/pkg/client/ping.go @@ -8,11 +8,12 @@ import ( ) func (c *Client) Ping(repo *model.Repo) error { - url, err := c.getURL(repo, "ping") + url := newURL(repo, "ping") + req, err := c.newRequest(http.MethodGet, url, nil) if err != nil { return err } - resp, err := http.Get(url) + resp, err := http.DefaultClient.Do(req) if err != nil { return err } diff --git a/pkg/client/upload.go b/pkg/client/upload.go index 5ac6452171..ac5d990777 100644 --- a/pkg/client/upload.go +++ b/pkg/client/upload.go @@ -29,14 +29,12 @@ func (c *Client) UploadModel(repo *model.Repo, projectDir string) (*model.Model, // we need to disable keepalive. there's a bug i (andreas) haven't // been able to get to the bottom of, where keep-alive requests // are missing content-type + // TODO(andreas): this still breaks from time to time DisableKeepAlives: true, }, } - url, err := c.getURL(repo, "v1/repos/%s/%s/models/", repo.User, repo.Name) - if err != nil { - return nil, err - } - req, err := http.NewRequest(http.MethodPut, url, bodyReader) + url := newURL(repo, "v1/repos/%s/%s/models/", repo.User, repo.Name) + req, err := c.newRequest(http.MethodPut, url, bodyReader) if err != nil { return nil, err } @@ -117,11 +115,12 @@ func (c *Client) UploadModel(repo *model.Repo, projectDir string) (*model.Model, } func (c *Client) getRepoCacheHashes(repo *model.Repo) ([]string, error) { - url, err := c.getURL(repo, "v1/repos/%s/%s/cache-hashes/", repo.User, repo.Name) + url := newURL(repo, "v1/repos/%s/%s/cache-hashes/", repo.User, repo.Name) + req, err := c.newRequest(http.MethodGet, url, nil) if err != nil { return nil, err } - resp, err := http.Get(url) + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/pkg/global/global.go b/pkg/global/global.go index c92fce6e46..1219c2bca2 100644 --- a/pkg/global/global.go +++ b/pkg/global/global.go @@ -4,11 +4,11 @@ import ( "time" ) -var Version = "0.0.1" -var BuildTime = "none" - var ( - Verbose = false - StartupTimeout = 5 * time.Minute - ConfigFilename = "cog.yaml" + Version = "0.0.1" + BuildTime = "none" + Verbose = false + StartupTimeout = 5 * time.Minute + ConfigFilename = "cog.yaml" + CogServerAddress = "http://cog.replicate.ai" // TODO(andreas): https ) diff --git a/pkg/model/model.go b/pkg/model/model.go index 42534377be..0b5cbd70a7 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -21,6 +21,7 @@ type Stats struct { SetupTime float64 `json:"setup_time"` RunTime float64 `json:"run_time"` MemoryUsage uint64 `json:"memory_usage"` + CPUUsage float64 `json:"cpu_usage"` } type Artifact struct { diff --git a/pkg/server/auth.go b/pkg/server/auth.go new file mode 100644 index 0000000000..a5cb8c4a2e --- /dev/null +++ b/pkg/server/auth.go @@ -0,0 +1,125 @@ +package server + +import ( + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + + "github.com/replicate/cog/pkg/console" +) + +func (s *Server) GetDisplayTokenURL(w http.ResponseWriter, r *http.Request) { + resp := &struct { + URL string `json:"url"` + }{} + if s.authDelegate != "" { + resp.URL = s.authDelegate + "/display-token" + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + console.Errorf("Failed to decode response json: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } +} + +func (s *Server) VerifyToken(w http.ResponseWriter, r *http.Request) { + if s.authDelegate == "" { + console.Error("Attempted to verify auth token but server has no auth delegate") + w.WriteHeader(http.StatusBadRequest) + return + } + if err := r.ParseForm(); err != nil { + console.Errorf("Failed to parse form: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + token := r.FormValue("token") + resp, err := http.PostForm(s.authDelegate+"/verify-token", url.Values{ + "token": []string{token}, + }) + if err != nil { + console.Errorf("Failed to verify token: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + body := &struct { + Username string `json:"username"` + }{} + if err := json.NewDecoder(resp.Body).Decode(body); err != nil { + console.Errorf("Failed decode json: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(body); err != nil { + console.Errorf("Failed encode json: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (s *Server) checkReadAccess(handler http.HandlerFunc) http.HandlerFunc { + return s.checkAccess(handler, "read") +} + +func (s *Server) checkWriteAccess(handler http.HandlerFunc) http.HandlerFunc { + return s.checkAccess(handler, "write") +} + +func (s *Server) checkAccess(handler http.HandlerFunc, mode string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if handler == nil { + handler = func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + return + } + } + + if s.authDelegate == "" { + handler(w, r) + return + } + user, repo, modelID := getRepoVars(r) + + token := "" + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + tokenBase64 := strings.Split(authHeader, "Bearer ")[1] + tokenBytes, err := base64.StdEncoding.DecodeString(tokenBase64) + if err != nil { + console.Errorf("Failed to decode token: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } + token = string(tokenBytes) + } + + values := url.Values{ + "mode": []string{mode}, + "user": []string{user}, + "repo": []string{repo}, + } + if modelID != "" { + values["model_id"] = []string{modelID} + } + if token != "" { + values["token"] = []string{token} + } + resp, err := http.PostForm(s.authDelegate+"/check-access", values) + if err != nil { + console.Errorf("Auth request failed: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if resp.StatusCode == http.StatusOK { + handler(w, r) + return + } + console.Warnf("Not authorized to %s %s/%s:%s", mode, user, repo, modelID) + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + console.Errorf("Failed to copy body: %v", err) + } + return + } +} diff --git a/pkg/server/download.go b/pkg/server/download.go index d6ba07beb0..812cf25073 100644 --- a/pkg/server/download.go +++ b/pkg/server/download.go @@ -10,7 +10,6 @@ import ( func (s *Server) DownloadModel(w http.ResponseWriter, r *http.Request) { user, name, id := getRepoVars(r) - console.Infof("Received download request for %s/%s/%s", user, name, id) modTime := time.Now() // TODO mod, err := s.db.GetModel(user, name, id) diff --git a/pkg/server/server.go b/pkg/server/server.go index bf828bf005..b2a28ca707 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -23,13 +23,14 @@ const topLevelSourceDir = "source" type Server struct { port int webHooks []*WebHook + authDelegate string db database.Database dockerImageBuilder docker.ImageBuilder servingPlatform serving.Platform store storage.Storage } -func NewServer(port int, rawWebHooks []string, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) (*Server, error) { +func NewServer(port int, rawWebHooks []string, authDelegate string, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) (*Server, error) { webHooks := []*WebHook{} for _, rawWebHook := range rawWebHooks { webHook, err := newWebHook(rawWebHook) @@ -41,6 +42,7 @@ func NewServer(port int, rawWebHooks []string, db database.Database, dockerImage return &Server{ port: port, webHooks: webHooks, + authDelegate: authDelegate, db: db, dockerImageBuilder: dockerImageBuilder, servingPlatform: servingPlatform, @@ -50,6 +52,7 @@ func NewServer(port int, rawWebHooks []string, db database.Database, dockerImage func (s *Server) Start() error { router := mux.NewRouter() + router.Path("/"). Methods(http.MethodGet). HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -62,22 +65,31 @@ func (s *Server) Start() error { }) router.Path("/v1/repos/{user}/{name}/models/{id}.zip"). Methods(http.MethodGet). - HandlerFunc(s.DownloadModel) + HandlerFunc(s.checkReadAccess(s.DownloadModel)) router.Path("/v1/repos/{user}/{name}/models/"). Methods(http.MethodPut). - HandlerFunc(s.ReceiveFile) + HandlerFunc(s.checkWriteAccess(s.ReceiveFile)) router.Path("/v1/repos/{user}/{name}/models/"). Methods(http.MethodGet). - HandlerFunc(s.ListModels) + HandlerFunc(s.checkReadAccess(s.ListModels)) router.Path("/v1/repos/{user}/{name}/models/{id}"). Methods(http.MethodGet). - HandlerFunc(s.SendModelMetadata) + HandlerFunc(s.checkReadAccess(s.SendModelMetadata)) router.Path("/v1/repos/{user}/{name}/models/{id}"). Methods(http.MethodDelete). - HandlerFunc(s.DeleteModel) + HandlerFunc(s.checkWriteAccess(s.DeleteModel)) router.Path("/v1/repos/{user}/{name}/cache-hashes/"). Methods(http.MethodGet). - HandlerFunc(s.GetCacheHashes) + HandlerFunc(s.checkReadAccess(s.GetCacheHashes)) + router.Path("/v1/auth/display-token-url"). + Methods(http.MethodGet). + HandlerFunc(s.GetDisplayTokenURL) + router.Path("/v1/auth/verify-token"). + Methods(http.MethodPost). + HandlerFunc(s.VerifyToken) + router.Path("/v1/repos/{user}/{name}/check-read"). + Methods(http.MethodGet). + HandlerFunc(s.checkReadAccess(nil)) console.Infof("Server running on 0.0.0.0:%d", s.port) loggedRouter := handlers.LoggingHandler(os.Stdout, router) diff --git a/pkg/server/web_hook.go b/pkg/server/web_hook.go index a6be25ddbe..b6fc1c0071 100644 --- a/pkg/server/web_hook.go +++ b/pkg/server/web_hook.go @@ -58,6 +58,7 @@ func (wh *WebHook) run(user string, name string, mod *model.Model, dir string, l "docker_image_cpu": {dockerImageCPU}, "docker_image_gpu": {dockerImageGPU}, "memory_usage": {strconv.FormatUint(mod.Stats.MemoryUsage, 10)}, + "cpu_usage": {fmt.Sprintf("%.2f", mod.Stats.CPUUsage)}, "user": {user}, "repo_name": {name}, "secret": {wh.secret}, diff --git a/pkg/serving/local.go b/pkg/serving/local.go index b975fd2014..10601b5ffb 100644 --- a/pkg/serving/local.go +++ b/pkg/serving/local.go @@ -205,7 +205,7 @@ func (d *LocalDockerDeployment) RunInference(input *Example, logWriter logger.Lo } defer resp.Body.Close() - memoryUsage, err := d.getMemoryUsage() + usedMemoryBytes, usedCPUSecs, err := d.getResourceUsage() if err != nil { return nil, err } @@ -261,28 +261,31 @@ func (d *LocalDockerDeployment) RunInference(input *Example, logWriter logger.Lo MimeType: mimeType, }, }, - SetupTime: setupTime, - RunTime: runTime, - MemoryUsage: memoryUsage, + SetupTime: setupTime, + RunTime: runTime, + UsedMemoryBytes: usedMemoryBytes, + UsedCPUSecs: usedCPUSecs, } return result, nil } -func (d *LocalDockerDeployment) getMemoryUsage() (uint64, error) { - +func (d *LocalDockerDeployment) getResourceUsage() (memoryBytes uint64, cpuSecs float64, err error) { statsReader, err := d.client.ContainerStatsOneShot(context.Background(), d.containerID) if err != nil { - return 0, fmt.Errorf("Failed to get container stats: %w", err) + return 0, 0, fmt.Errorf("Failed to get container stats: %w", err) } statsBody, err := io.ReadAll(statsReader.Body) if err != nil { - return 0, fmt.Errorf("Failed to read container stats: %w", err) + return 0, 0, fmt.Errorf("Failed to read container stats: %w", err) } stats := new(types.Stats) if err := json.Unmarshal(statsBody, stats); err != nil { - return 0, err + return 0, 0, err } - return stats.MemoryStats.MaxUsage, nil + cpuNanos := stats.CPUStats.CPUUsage.TotalUsage + cpuSecs = float64(cpuNanos) / 1e9 + + return stats.MemoryStats.MaxUsage, cpuSecs, nil } func (d *LocalDockerDeployment) Help(logWriter logger.Logger) (*HelpResponse, error) { diff --git a/pkg/serving/platform.go b/pkg/serving/platform.go index da1f64cfce..3235b3d532 100644 --- a/pkg/serving/platform.go +++ b/pkg/serving/platform.go @@ -72,10 +72,11 @@ type ResultValue struct { } type Result struct { - Values map[string]ResultValue - SetupTime float64 - RunTime float64 - MemoryUsage uint64 + Values map[string]ResultValue + SetupTime float64 + RunTime float64 + UsedMemoryBytes uint64 + UsedCPUSecs float64 } type HelpResponse struct { diff --git a/pkg/serving/test.go b/pkg/serving/test.go index 25a205c91e..c82425c733 100644 --- a/pkg/serving/test.go +++ b/pkg/serving/test.go @@ -36,6 +36,7 @@ func TestModel(servingPlatform Platform, imageTag string, config *model.Config, setupTimes := []float64{} runTimes := []float64{} memoryUsages := []float64{} + cpuUsages := []float64{} for _, example := range config.Examples { if err := validateServingExampleInput(help, example.Input); err != nil { return nil, nil, fmt.Errorf("Example input doesn't match run arguments: %w", err) @@ -60,6 +61,9 @@ func TestModel(servingPlatform Platform, imageTag string, config *model.Config, if err != nil { return nil, nil, err } + logWriter.Debugf("Memory usage (bytes): %d", result.UsedMemoryBytes) + logWriter.Debugf("CPU usage (seconds): %.1f", result.UsedCPUSecs) + output := result.Values["output"] outputBytes, err := io.ReadAll(output.Buffer) if err != nil { @@ -77,7 +81,8 @@ func TestModel(servingPlatform Platform, imageTag string, config *model.Config, setupTimes = append(setupTimes, result.SetupTime) runTimes = append(runTimes, result.RunTime) - memoryUsages = append(memoryUsages, float64(result.MemoryUsage)) + memoryUsages = append(memoryUsages, float64(result.UsedMemoryBytes)) + cpuUsages = append(cpuUsages, result.UsedCPUSecs) } if len(setupTimes) > 0 { @@ -94,10 +99,16 @@ func TestModel(servingPlatform Platform, imageTag string, config *model.Config, return nil, nil, err } modelStats.MemoryUsage = uint64(memoryUsage) + cpuUsage, err := stats.Max(cpuUsages) + if err != nil { + return nil, nil, err + } + modelStats.CPUUsage = cpuUsage } else { modelStats.SetupTime = 0 modelStats.RunTime = 0 modelStats.MemoryUsage = 0 + modelStats.CPUUsage = 0 } return help.Arguments, modelStats, nil diff --git a/pkg/settings/user.go b/pkg/settings/user.go new file mode 100644 index 0000000000..6911d809a0 --- /dev/null +++ b/pkg/settings/user.go @@ -0,0 +1,143 @@ +package settings + +import ( + "encoding/json" + "fmt" + "os" + "os/user" + "path" + "runtime" + + "github.com/replicate/cog/pkg/files" +) + +type AuthInfo struct { + Token string `json:"token"` + Username string `json:"username"` +} + +type UserSettings struct { + Auth map[string]AuthInfo `json:"auth"` +} + +func SaveAuthToken(address string, username string, token string) error { + var err error + + settingsPath, err := getUserSettingsPath() + if err != nil { + return err + } + + var settings *UserSettings + exists, err := files.Exists(settingsPath) + if err != nil { + return err + } + if exists { + settings, err = LoadUserSettings() + if err != nil { + return err + } + + } else { + settings = &UserSettings{} + } + if settings.Auth == nil { + settings.Auth = map[string]AuthInfo{} + } + settings.Auth[address] = AuthInfo{ + Token: token, + Username: username, + } + + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + return os.WriteFile(settingsPath, data, 0600) +} + +func LoadUserSettings() (*UserSettings, error) { + settingsPath, err := getUserSettingsPath() + if err != nil { + return nil, fmt.Errorf("Failed to determine settings path") + } + + exists, err := files.Exists(settingsPath) + if err != nil { + return nil, err + } + if !exists { + return new(UserSettings), nil + } + + text, err := os.ReadFile(settingsPath) + if err != nil { + return nil, fmt.Errorf("Failed to load settings. Did you run cog login?") + } + + settings := UserSettings{} + err = json.Unmarshal(text, &settings) + if err != nil { + return nil, fmt.Errorf("%s is corrupted. Please re-run cog login", settingsPath) + } + + return &settings, nil +} + +func LoadAuthToken(address string) (string, error) { + s, err := LoadUserSettings() + if err != nil { + return "", err + } + return s.Token(address), nil +} + +func (s *UserSettings) Token(address string) string { + authToken, ok := s.Auth[address] + if !ok { + return "" + } + return authToken.Token +} + +func (s *UserSettings) Username(address string) (token string, err error) { + authToken, ok := s.Auth[address] + if !ok { + return "", fmt.Errorf("You are not logged in! Run \"cog login\" to get started") + } + return authToken.Username, nil +} + +func getUserSettingsPath() (string, error) { + configDir, err := userConfigDir() + if err != nil { + return "", err + } + + folder := path.Join(configDir, "cog") + err = os.MkdirAll(folder, os.ModePerm) + if err != nil { + return "", err + } + settingsPath := path.Join(folder, "settings.json") + + return settingsPath, nil +} + +func userConfigDir() (string, error) { + switch runtime.GOOS { + case "linux": + return os.UserConfigDir() + case "windows": + return os.UserConfigDir() + case "darwin": + usr, err := user.Current() + if err != nil { + return os.UserConfigDir() + } + return path.Join(usr.HomeDir, ".config"), nil + default: + return os.UserConfigDir() + } +}