diff --git a/end-to-end-test/end_to_end_test/test_server.py b/end-to-end-test/end_to_end_test/test_server.py index a26a6ca417..73dbecea9c 100644 --- a/end-to-end-test/end_to_end_test/test_server.py +++ b/end-to-end-test/end_to_end_test/test_server.py @@ -34,15 +34,10 @@ def cog_server_port_dir(): server_proc.kill() -def test_build_show_list_download_infer(cog_server_port_dir, tmpdir_factory): - cog_port, cog_dir = cog_server_port_dir - - user = "".join(random.choice(string.ascii_lowercase) for i in range(10)) - model_name = "".join(random.choice(string.ascii_lowercase) for i in range(10)) - model = f"http://localhost:{cog_port}/{user}/{model_name}" - - project_dir = tmpdir_factory.mktemp("project") - with open(project_dir / "infer.py", "w") as f: +@pytest.fixture +def project_dir(tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("project") + with open(tmpdir / "infer.py", "w") as f: f.write( """ import time @@ -61,7 +56,7 @@ def run(self, text, path): return self.foo + text + f.read() """ ) - with open(project_dir / "cog.yaml", "w") as f: + with open(tmpdir / "cog.yaml", "w") as f: cog_yaml = """ name: andreas/hello-world model: infer.py:Model @@ -83,8 +78,23 @@ def run(self, text, path): """ f.write(cog_yaml) + return tmpdir + + +def test_build_show_list_download_infer( + cog_server_port_dir, project_dir, tmpdir_factory +): + cog_port, cog_dir = cog_server_port_dir + + user = random_string(10) + model_name = random_string(10) + model_url = f"http://localhost:{cog_port}/{user}/{model_name}" + + with open(os.path.join(project_dir, "cog.yaml")) as f: + cog_yaml = f.read() + out, _ = subprocess.Popen( - ["cog", "model", "set", f"http://localhost:{cog_port}/{user}/{model_name}"], + ["cog", "model", "set", model_url], stdout=subprocess.PIPE, cwd=project_dir, ).communicate() @@ -108,22 +118,18 @@ def run(self, text, path): version_id = out.decode().strip().split("Successfully uploaded version ")[1] out, _ = subprocess.Popen( - ["cog", "--model", model, "show", version_id], stdout=subprocess.PIPE + ["cog", "--model", model_url, "show", version_id], stdout=subprocess.PIPE ).communicate() lines = out.decode().splitlines() assert lines[0] == f"ID: {version_id}" assert lines[1] == f"Model: {user}/{model_name}" - def show_version(): - out, _ = subprocess.Popen( - ["cog", "--model", model, "show", "--json", version_id], stdout=subprocess.PIPE - ).communicate() - return json.loads(out) - - out = show_version() - subprocess.Popen(["cog", "--model", model, "build", "log", "-f", out["build_ids"]["cpu"]]).communicate() + out = show_version(model_url, version_id) + subprocess.Popen( + ["cog", "--model", model_url, "build", "log", "-f", out["build_ids"]["cpu"]] + ).communicate() - out = show_version() + out = show_version(model_url, version_id) assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt" # show without --model @@ -137,14 +143,22 @@ def show_version(): assert lines[1] == f"Model: {user}/{model_name}" out, _ = subprocess.Popen( - ["cog", "--model", model, "ls"], stdout=subprocess.PIPE + ["cog", "--model", model_url, "ls"], stdout=subprocess.PIPE ).communicate() lines = out.decode().splitlines() assert lines[1].startswith(f"{version_id} ") download_dir = tmpdir_factory.mktemp("download") / "my-dir" subprocess.Popen( - ["cog", "--model", model, "download", "--output-dir", download_dir, version_id], + [ + "cog", + "--model", + model_url, + "download", + "--output-dir", + download_dir, + version_id, + ], stdout=subprocess.PIPE, ).communicate() paths = sorted(glob(str(download_dir / "*.*"))) @@ -161,14 +175,17 @@ def show_version(): files_endpoint = f"http://localhost:{cog_port}/v1/models/{user}/{model_name}/versions/{version_id}/files" assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml - assert requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text == "fooquxbaz" + assert ( + requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text + == "fooquxbaz" + ) out_path = output_dir / "out.txt" subprocess.Popen( [ "cog", "--model", - model, + model_url, "infer", "-o", out_path, @@ -184,8 +201,57 @@ def show_version(): assert f.read() == "foobazinput" +def test_push_log(cog_server_port_dir, project_dir): + cog_port, cog_dir = cog_server_port_dir + + user = random_string(10) + model_name = random_string(10) + model_url = f"http://localhost:{cog_port}/{user}/{model_name}" + + out, _ = subprocess.Popen( + ["cog", "model", "set", model_url], + stdout=subprocess.PIPE, + cwd=project_dir, + ).communicate() + assert ( + out.decode() + == f"Updated model: http://localhost:{cog_port}/{user}/{model_name}\n" + ) + + with open(project_dir / "myfile.txt", "w") as f: + f.write("baz") + + out, _ = subprocess.Popen( + ["cog", "push", "--log"], + cwd=project_dir, + stdout=subprocess.PIPE, + ).communicate() + + assert out.decode().startswith("Successfully uploaded version "), ( + out.decode() + " doesn't start with 'Successfully uploaded version'" + ) + version_id = out.decode().strip().split("Successfully uploaded version ")[1] + + out = show_version(model_url, version_id) + assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt" + assert out["images"][0]["arch"] == "cpu" + assert out["images"][0]["run_arguments"]["text"]["type"] == "str" + + def find_free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] + + +def random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + + +def show_version(model_url, version_id): + out, _ = subprocess.Popen( + ["cog", "--model", model_url, "show", "--json", version_id], + stdout=subprocess.PIPE, + ).communicate() + return json.loads(out) diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 2ad3593dd7..a9e70bb0c3 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -54,21 +54,25 @@ func showBuildLogs(cmd *cobra.Command, args []string) error { return err } for entry := range logChan { - switch entry.Level { - case logger.LevelFatal: - console.Fatal(entry.Line) - case logger.LevelError: - console.Error(entry.Line) - case logger.LevelWarn: - console.Warn(entry.Line) - case logger.LevelStatus: // TODO(andreas): handle status differently or remove - console.Info(entry.Line) - case logger.LevelInfo: - console.Info(entry.Line) - case logger.LevelDebug: - console.Debug(entry.Line) - } + outputLogEntry(entry, "") } return nil } + +func outputLogEntry(entry *client.LogEntry, prefix string) { + switch entry.Level { + case logger.LevelFatal: + console.Fatal(prefix + entry.Line) + case logger.LevelError: + console.Error(prefix + entry.Line) + case logger.LevelWarn: + console.Warn(prefix + entry.Line) + case logger.LevelStatus: // TODO(andreas): handle status differently or remove + console.Info(prefix + entry.Line) + case logger.LevelInfo: + console.Info(prefix + entry.Line) + case logger.LevelDebug: + console.Debug(prefix + entry.Line) + } +} diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 88fdbe9dd8..33a7780b40 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -4,14 +4,21 @@ import ( "fmt" "os" "path" + "sync" "github.com/spf13/cobra" "github.com/replicate/cog/pkg/client" "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/util/console" ) +type archLogEntry struct { + entry *client.LogEntry + arch string +} + func newPushCommand() *cobra.Command { cmd := &cobra.Command{ Use: "push", @@ -22,10 +29,17 @@ func newPushCommand() *cobra.Command { addModelFlag(cmd) addProjectDirFlag(cmd) + cmd.Flags().Bool("log", false, "Follow image build logs after successful push") + return cmd } func push(cmd *cobra.Command, args []string) error { + log, err := cmd.Flags().GetBool("log") + if err != nil { + return err + } + model, err := getModel() if err != nil { return err @@ -49,5 +63,54 @@ func push(cmd *cobra.Command, args []string) error { } fmt.Printf("Successfully uploaded version %s\n", version.ID) + + if log { + return pushLog(model, version) + } + return nil } + +func pushLog(model *model.Model, version *model.Version) error { + c := client.NewClient() + + logChans := map[string]chan *client.LogEntry{} + for _, arch := range version.Config.Environment.Architectures { + logChan, err := c.GetBuildLogs(model, version.BuildIDs[arch], true) + if err != nil { + return err + } + logChans[arch] = logChan + } + + for archEntry := range mergeLogs(logChans) { + prefix := "" + if len(logChans) > 1 { + prefix = fmt.Sprintf("[%s] ", archEntry.arch) + } + outputLogEntry(archEntry.entry, prefix) + } + return nil +} + +func mergeLogs(channelMap map[string]chan *client.LogEntry) <-chan *archLogEntry { + out := make(chan *archLogEntry) + var wg sync.WaitGroup + wg.Add(len(channelMap)) + for arch, c := range channelMap { + go func(arch string, c <-chan *client.LogEntry) { + for entry := range c { + out <- &archLogEntry{ + arch: arch, + entry: entry, + } + } + wg.Done() + }(arch, c) + } + go func() { + wg.Wait() + close(out) + }() + return out +} diff --git a/pkg/client/build.go b/pkg/client/build.go index b79c53c491..c7009b2aa7 100644 --- a/pkg/client/build.go +++ b/pkg/client/build.go @@ -3,6 +3,7 @@ package client import ( "bufio" "encoding/json" + "fmt" "net/http" "github.com/replicate/cog/pkg/logger" @@ -30,6 +31,9 @@ func (c *Client) GetBuildLogs(mod *model.Model, buildID string, follow bool) (ch if err != nil { return nil, err } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Build logs endpoint returned error %d", resp.StatusCode) + } logChan := make(chan *LogEntry) go func() { scanner := bufio.NewScanner(resp.Body) diff --git a/pkg/server/build.go b/pkg/server/build.go index dfbeb8fc93..7932ab2d21 100644 --- a/pkg/server/build.go +++ b/pkg/server/build.go @@ -107,6 +107,8 @@ func (s *Server) buildImage(buildID, dir, user, name, id string, version *model. } }() + logWriter.Debug("Submitting build") + // TODO(andreas): make it possible to cancel the build result, err := s.buildQueue.Build(context.Background(), dir, name, id, arch, version.Config, logWriter) if err != nil { @@ -309,6 +311,7 @@ func (s *Server) SendBuildLogs(w http.ResponseWriter, r *http.Request) { if err != nil { console.Error(err.Error()) w.WriteHeader(http.StatusInternalServerError) + return } encoder := json.NewEncoder(w) for entry := range logChan {