Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 91 additions & 25 deletions end-to-end-test/end_to_end_test/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ def cog_server_port_dir():
server_proc.kill()


def test_build_show_list_download_infer(cog_server_port_dir, tmpdir_factory):
cog_port, cog_dir = cog_server_port_dir

user = "".join(random.choice(string.ascii_lowercase) for i in range(10))
model_name = "".join(random.choice(string.ascii_lowercase) for i in range(10))
model = f"http://localhost:{cog_port}/{user}/{model_name}"

project_dir = tmpdir_factory.mktemp("project")
with open(project_dir / "infer.py", "w") as f:
@pytest.fixture
def project_dir(tmpdir_factory):
tmpdir = tmpdir_factory.mktemp("project")
with open(tmpdir / "infer.py", "w") as f:
f.write(
"""
import time
Expand All @@ -61,7 +56,7 @@ def run(self, text, path):
return self.foo + text + f.read()
"""
)
with open(project_dir / "cog.yaml", "w") as f:
with open(tmpdir / "cog.yaml", "w") as f:
cog_yaml = """
name: andreas/hello-world
model: infer.py:Model
Expand All @@ -83,8 +78,23 @@ def run(self, text, path):
"""
f.write(cog_yaml)

return tmpdir


def test_build_show_list_download_infer(
cog_server_port_dir, project_dir, tmpdir_factory
):
cog_port, cog_dir = cog_server_port_dir

user = random_string(10)
model_name = random_string(10)
model_url = f"http://localhost:{cog_port}/{user}/{model_name}"

with open(os.path.join(project_dir, "cog.yaml")) as f:
cog_yaml = f.read()

out, _ = subprocess.Popen(
["cog", "model", "set", f"http://localhost:{cog_port}/{user}/{model_name}"],
["cog", "model", "set", model_url],
stdout=subprocess.PIPE,
cwd=project_dir,
).communicate()
Expand All @@ -108,22 +118,18 @@ def run(self, text, path):
version_id = out.decode().strip().split("Successfully uploaded version ")[1]

out, _ = subprocess.Popen(
["cog", "--model", model, "show", version_id], stdout=subprocess.PIPE
["cog", "--model", model_url, "show", version_id], stdout=subprocess.PIPE
).communicate()
lines = out.decode().splitlines()
assert lines[0] == f"ID: {version_id}"
assert lines[1] == f"Model: {user}/{model_name}"

def show_version():
out, _ = subprocess.Popen(
["cog", "--model", model, "show", "--json", version_id], stdout=subprocess.PIPE
).communicate()
return json.loads(out)

out = show_version()
subprocess.Popen(["cog", "--model", model, "build", "log", "-f", out["build_ids"]["cpu"]]).communicate()
out = show_version(model_url, version_id)
subprocess.Popen(
["cog", "--model", model_url, "build", "log", "-f", out["build_ids"]["cpu"]]
).communicate()

out = show_version()
out = show_version(model_url, version_id)
assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"

# show without --model
Expand All @@ -137,14 +143,22 @@ def show_version():
assert lines[1] == f"Model: {user}/{model_name}"

out, _ = subprocess.Popen(
["cog", "--model", model, "ls"], stdout=subprocess.PIPE
["cog", "--model", model_url, "ls"], stdout=subprocess.PIPE
).communicate()
lines = out.decode().splitlines()
assert lines[1].startswith(f"{version_id} ")

download_dir = tmpdir_factory.mktemp("download") / "my-dir"
subprocess.Popen(
["cog", "--model", model, "download", "--output-dir", download_dir, version_id],
[
"cog",
"--model",
model_url,
"download",
"--output-dir",
download_dir,
version_id,
],
stdout=subprocess.PIPE,
).communicate()
paths = sorted(glob(str(download_dir / "*.*")))
Expand All @@ -161,14 +175,17 @@ def show_version():

files_endpoint = f"http://localhost:{cog_port}/v1/models/{user}/{model_name}/versions/{version_id}/files"
assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml
assert requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text == "fooquxbaz"
assert (
requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text
== "fooquxbaz"
)

out_path = output_dir / "out.txt"
subprocess.Popen(
[
"cog",
"--model",
model,
model_url,
"infer",
"-o",
out_path,
Expand All @@ -184,8 +201,57 @@ def show_version():
assert f.read() == "foobazinput"


def test_push_log(cog_server_port_dir, project_dir):
cog_port, cog_dir = cog_server_port_dir

user = random_string(10)
model_name = random_string(10)
model_url = f"http://localhost:{cog_port}/{user}/{model_name}"

out, _ = subprocess.Popen(
["cog", "model", "set", model_url],
stdout=subprocess.PIPE,
cwd=project_dir,
).communicate()
assert (
out.decode()
== f"Updated model: http://localhost:{cog_port}/{user}/{model_name}\n"
)

with open(project_dir / "myfile.txt", "w") as f:
f.write("baz")

out, _ = subprocess.Popen(
["cog", "push", "--log"],
cwd=project_dir,
stdout=subprocess.PIPE,
).communicate()

assert out.decode().startswith("Successfully uploaded version "), (
out.decode() + " doesn't start with 'Successfully uploaded version'"
)
version_id = out.decode().strip().split("Successfully uploaded version ")[1]

out = show_version(model_url, version_id)
assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"
assert out["images"][0]["arch"] == "cpu"
assert out["images"][0]["run_arguments"]["text"]["type"] == "str"


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def random_string(length):
return "".join(random.choice(string.ascii_lowercase) for i in range(length))


def show_version(model_url, version_id):
out, _ = subprocess.Popen(
["cog", "--model", model_url, "show", "--json", version_id],
stdout=subprocess.PIPE,
).communicate()
return json.loads(out)
32 changes: 18 additions & 14 deletions pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,25 @@ func showBuildLogs(cmd *cobra.Command, args []string) error {
return err
}
for entry := range logChan {
switch entry.Level {
case logger.LevelFatal:
console.Fatal(entry.Line)
case logger.LevelError:
console.Error(entry.Line)
case logger.LevelWarn:
console.Warn(entry.Line)
case logger.LevelStatus: // TODO(andreas): handle status differently or remove
console.Info(entry.Line)
case logger.LevelInfo:
console.Info(entry.Line)
case logger.LevelDebug:
console.Debug(entry.Line)
}
outputLogEntry(entry, "")
}

return nil
}

func outputLogEntry(entry *client.LogEntry, prefix string) {
switch entry.Level {
case logger.LevelFatal:
console.Fatal(prefix + entry.Line)
case logger.LevelError:
console.Error(prefix + entry.Line)
case logger.LevelWarn:
console.Warn(prefix + entry.Line)
case logger.LevelStatus: // TODO(andreas): handle status differently or remove
console.Info(prefix + entry.Line)
case logger.LevelInfo:
console.Info(prefix + entry.Line)
case logger.LevelDebug:
console.Debug(prefix + entry.Line)
}
}
63 changes: 63 additions & 0 deletions pkg/cli/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@ import (
"fmt"
"os"
"path"
"sync"

"github.com/spf13/cobra"

"github.com/replicate/cog/pkg/client"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/model"
"github.com/replicate/cog/pkg/util/console"
)

type archLogEntry struct {
entry *client.LogEntry
arch string
}

func newPushCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "push",
Expand All @@ -22,10 +29,17 @@ func newPushCommand() *cobra.Command {
addModelFlag(cmd)
addProjectDirFlag(cmd)

cmd.Flags().Bool("log", false, "Follow image build logs after successful push")

return cmd
}

func push(cmd *cobra.Command, args []string) error {
log, err := cmd.Flags().GetBool("log")
if err != nil {
return err
}

model, err := getModel()
if err != nil {
return err
Expand All @@ -49,5 +63,54 @@ func push(cmd *cobra.Command, args []string) error {
}

fmt.Printf("Successfully uploaded version %s\n", version.ID)

if log {
return pushLog(model, version)
}

return nil
}

func pushLog(model *model.Model, version *model.Version) error {
c := client.NewClient()

logChans := map[string]chan *client.LogEntry{}
for _, arch := range version.Config.Environment.Architectures {
logChan, err := c.GetBuildLogs(model, version.BuildIDs[arch], true)
if err != nil {
return err
}
logChans[arch] = logChan
}

for archEntry := range mergeLogs(logChans) {
prefix := ""
if len(logChans) > 1 {
prefix = fmt.Sprintf("[%s] ", archEntry.arch)
}
outputLogEntry(archEntry.entry, prefix)
}
return nil
}

func mergeLogs(channelMap map[string]chan *client.LogEntry) <-chan *archLogEntry {
out := make(chan *archLogEntry)
var wg sync.WaitGroup
wg.Add(len(channelMap))
for arch, c := range channelMap {
go func(arch string, c <-chan *client.LogEntry) {
for entry := range c {
out <- &archLogEntry{
arch: arch,
entry: entry,
}
}
wg.Done()
}(arch, c)
}
go func() {
wg.Wait()
close(out)
}()
return out
}
4 changes: 4 additions & 0 deletions pkg/client/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"bufio"
"encoding/json"
"fmt"
"net/http"

"github.com/replicate/cog/pkg/logger"
Expand Down Expand Up @@ -30,6 +31,9 @@ func (c *Client) GetBuildLogs(mod *model.Model, buildID string, follow bool) (ch
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Build logs endpoint returned error %d", resp.StatusCode)
}
logChan := make(chan *LogEntry)
go func() {
scanner := bufio.NewScanner(resp.Body)
Expand Down
3 changes: 3 additions & 0 deletions pkg/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ func (s *Server) buildImage(buildID, dir, user, name, id string, version *model.
}
}()

logWriter.Debug("Submitting build")

// TODO(andreas): make it possible to cancel the build
result, err := s.buildQueue.Build(context.Background(), dir, name, id, arch, version.Config, logWriter)
if err != nil {
Expand Down Expand Up @@ -309,6 +311,7 @@ func (s *Server) SendBuildLogs(w http.ResponseWriter, r *http.Request) {
if err != nil {
console.Error(err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
encoder := json.NewEncoder(w)
for entry := range logChan {
Expand Down