Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions end-to-end-test/end_to_end_test/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_build_show_list_download_infer(cog_server_port_dir, tmpdir_factory):

user = "".join(random.choice(string.ascii_lowercase) for i in range(10))
repo_name = "".join(random.choice(string.ascii_lowercase) for i in range(10))
repo = f"localhost:{cog_port}/{user}/{repo_name}"
repo = f"http://localhost:{cog_port}/{user}/{repo_name}"

project_dir = tmpdir_factory.mktemp("project")
with open(project_dir / "infer.py", "w") as f:
Expand Down Expand Up @@ -78,11 +78,11 @@ def run(self, text, path):
)

out, _ = subprocess.Popen(
["cog", "repo", "set", f"localhost:{cog_port}/{user}/{repo_name}"],
["cog", "repo", "set", f"http://localhost:{cog_port}/{user}/{repo_name}"],
stdout=subprocess.PIPE,
cwd=project_dir,
).communicate()
assert out.decode() == f"Updated repo: localhost:{cog_port}/{user}/{repo_name}\n"
assert out.decode() == f"Updated repo: http://localhost:{cog_port}/{user}/{repo_name}\n"

with open(project_dir / "myfile.txt", "w") as f:
f.write("baz")
Expand Down
22 changes: 17 additions & 5 deletions pkg/cli/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ func newListCommand() *cobra.Command {
}
addRepoFlag(cmd)

cmd.Flags().BoolP("quiet", "q", false, "Quite output, only display IDs")

return cmd
}

Expand All @@ -30,6 +32,10 @@ func listModels(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
quiet, err := cmd.Flags().GetBool("quiet")
if err != nil {
return err
}

cli := client.NewClient()
models, err := cli.ListModels(repo)
Expand All @@ -41,12 +47,18 @@ func listModels(cmd *cobra.Command, args []string) error {
return models[i].Created.After(models[j].Created)
})

w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tCREATED")
for _, mod := range models {
fmt.Fprintf(w, "%s\t%s\n", mod.ID, timeago.English.Format(mod.Created))
if quiet {
for _, mod := range models {
fmt.Println(mod.ID)
}
} else {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tCREATED")
for _, mod := range models {
fmt.Fprintf(w, "%s\t%s\n", mod.ID, timeago.English.Format(mod.Created))
}
w.Flush()
}
w.Flush()

return nil
}
84 changes: 84 additions & 0 deletions pkg/cli/login.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package cli

import (
"bufio"
"fmt"
"os"
"os/exec"
"runtime"
"strings"

"github.com/spf13/cobra"

"github.com/replicate/cog/pkg/client"
"github.com/replicate/cog/pkg/console"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/settings"
)

type VerifyResponse struct {
Username string `json:"username"`
}

func newLoginCommand() *cobra.Command {
var cmd = &cobra.Command{
Use: "login [COG_SERVER_ADDRESS]",
SuggestFor: []string{"auth", "authenticate", "authorize"},
Short: "Authorize the replicate CLI to a Cog server",
RunE: login,
Args: cobra.MaximumNArgs(1),
}

return cmd
}

func login(cmd *cobra.Command, args []string) error {
address := global.CogServerAddress
if len(args) == 1 {
address = args[0]
}

c := client.NewClient()
url, err := c.GetDisplayTokenURL(address)
if err != nil {
return err
}
if url == "" {
return fmt.Errorf("This server does not support authentication")
}
fmt.Println("Please visit " + url + " in a web browser")
fmt.Println("and copy the authorization token.")
maybeOpenBrowser(url)

fmt.Print("\nPaste the token here: ")
token, err := bufio.NewReader(os.Stdin).ReadString('\n')
token = strings.TrimSpace(token)
if err != nil {
return err
}

username, err := c.VerifyToken(address, token)
if err != nil {
return err
}

err = settings.SaveAuthToken(address, username, token)
if err != nil {
return err
}

console.Infof("Successfully authenticated as %s", username)

return nil
}

func maybeOpenBrowser(url string) {
switch runtime.GOOS {
case "linux":
exec.Command("xdg-open", url).Start()
case "windows":
exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
exec.Command("open", url).Start()
}
}
2 changes: 1 addition & 1 deletion pkg/cli/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func setRepo(cmd *cobra.Command, args []string) error {
}

cli := client.NewClient()
if err := cli.Ping(repo); err != nil {
if err := cli.CheckRead(repo); err != nil {
return err
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
var repoFlag string
var projectDirFlag string

var repoRegex = regexp.MustCompile("^(?:([^/]*)/)?(?:([-_a-zA-Z0-9]+)/)([-_a-zA-Z0-9]+)$")
var repoRegex = regexp.MustCompile("^(?:(https?://[^/]*)/)?(?:([-_a-zA-Z0-9]+)/)([-_a-zA-Z0-9]+)$")

func NewRootCommand() (*cobra.Command, error) {
rootCmd := cobra.Command{
Expand Down Expand Up @@ -46,6 +46,7 @@ func NewRootCommand() (*cobra.Command, error) {
newListCommand(),
newBenchmarkCommand(),
newDeleteCommand(),
newLoginCommand(),
)

return &rootCmd, nil
Expand Down Expand Up @@ -83,7 +84,7 @@ func getRepo() (*model.Repo, error) {
func parseRepo(repoString string) (*model.Repo, error) {
matches := repoRegex.FindStringSubmatch(repoString)
if len(matches) == 0 {
return nil, fmt.Errorf("Repo '%s' doesn't match <host>/<user>/<name>", repoString)
return nil, fmt.Errorf("Repo '%s' doesn't match [http[s]://<host>/]<user>/<name>", repoString)
}
return &model.Repo{
Host: matches[1],
Expand Down
4 changes: 3 additions & 1 deletion pkg/cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
port int
dockerRegistry string
buildWebHooks []string
authDelegate string
)

func newServerCommand() *cobra.Command {
Expand All @@ -33,6 +34,7 @@ func newServerCommand() *cobra.Command {
cmd.Flags().IntVar(&port, "port", 0, "Server port")
cmd.Flags().StringVar(&dockerRegistry, "docker-registry", "", "Docker registry to push images to")
cmd.Flags().StringArrayVar(&buildWebHooks, "web-hook", []string{}, "Web hooks that are posted to after build. Format: <url>@<secret>")
cmd.Flags().StringVar(&authDelegate, "auth-delegate", "", "Address to service that handles authentication logic")
return cmd
}

Expand Down Expand Up @@ -78,7 +80,7 @@ func startServer(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
s, err := server.NewServer(port, buildWebHooks, db, dockerImageBuilder, servingPlatform, store)
s, err := server.NewServer(port, buildWebHooks, authDelegate, db, dockerImageBuilder, servingPlatform, store)
if err != nil {
return err
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/cli/show.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"encoding/json"
"fmt"
"os"
"text/tabwriter"
Expand All @@ -21,10 +22,16 @@ func newShowCommand() *cobra.Command {
}
addRepoFlag(cmd)

cmd.Flags().Bool("json", false, "JSON output")

return cmd
}

func showModel(cmd *cobra.Command, args []string) error {
jsonOutput, err := cmd.Flags().GetBool("json")
if err != nil {
return err
}
repo, err := getRepo()
if err != nil {
return err
Expand All @@ -38,6 +45,15 @@ func showModel(cmd *cobra.Command, args []string) error {
return err
}

if jsonOutput {
data, err := json.MarshalIndent(mod, "", " ")
if err != nil {
return err
}
fmt.Println(string(data))
return nil
}

w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID:\t"+mod.ID)
fmt.Fprintf(w, "Repo:\t%s/%s\n", repo.User, repo.Name)
Expand Down
77 changes: 77 additions & 0 deletions pkg/client/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package client

import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"

"github.com/replicate/cog/pkg/model"
)

func (c *Client) GetDisplayTokenURL(address string) (url string, err error) {
resp, err := http.Get(address + "/v1/auth/display-token-url")
if err != nil {
return "", fmt.Errorf("Failed to get login URL: %w", err)
}
if resp.StatusCode == http.StatusNotFound {
return "", fmt.Errorf("Login page does not exist on %s. Is it the correct URL?", address)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Login returned status %d", resp.StatusCode)
}

body := &struct {
URL string `json:"url"`
}{}
if err := json.NewDecoder(resp.Body).Decode(body); err != nil {
return "", err
}
return body.URL, nil
}

func (c *Client) VerifyToken(address string, token string) (username string, err error) {
resp, err := http.PostForm(address+"/v1/auth/verify-token", url.Values{
"token": []string{token},
})
if err != nil {
return "", fmt.Errorf("Failed to verify token: %w", err)
}
if resp.StatusCode == http.StatusNotFound {
return "", fmt.Errorf("User does not exist")
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Failed to verify token, got status %d", resp.StatusCode)
}
body := &struct {
Username string `json:"username"`
}{}
if err := json.NewDecoder(resp.Body).Decode(body); err != nil {
return "", err
}
return body.Username, nil
}

func (c *Client) CheckRead(repo *model.Repo) error {
url := newURL(repo, "v1/repos/%s/%s/check-read", repo.User, repo.Name)
req, err := c.newRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("Failed to read response body: %w", err)
}
body := string(bodyBytes)
if resp.StatusCode != http.StatusOK {
return errors.New(body)
}
return nil
}
Loading