diff --git a/end-to-end-test/end_to_end_test/test_server.py b/end-to-end-test/end_to_end_test/test_server.py index 2c57d2a183..102b322b0d 100644 --- a/end-to-end-test/end_to_end_test/test_server.py +++ b/end-to-end-test/end_to_end_test/test_server.py @@ -59,8 +59,7 @@ def run(self, text, path): """ ) with open(project_dir / "cog.yaml", "w") as f: - f.write( - """ + cog_yaml = """ name: andreas/hello-world model: infer.py:Model examples: @@ -79,7 +78,7 @@ def run(self, text, path): architectures: - cpu """ - ) + f.write(cog_yaml) out, _ = subprocess.Popen( ["cog", "repo", "set", f"http://localhost:{cog_port}/{user}/{repo_name}"], @@ -116,9 +115,7 @@ def run(self, text, path): ["cog", "-r", repo, "show", "--json", model_id], stdout=subprocess.PIPE ).communicate() out = json.loads(out) - assert ( - out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt" - ) + assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt" # show without -r out, _ = subprocess.Popen( @@ -153,6 +150,10 @@ def run(self, text, path): with input_path.open("w") as f: f.write("input") + files_endpoint = f"http://localhost:{cog_port}/v1/repos/{user}/{repo_name}/models/{model_id}/files" + assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml + assert requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text == "fooquxbaz" + out_path = output_dir / "out.txt" subprocess.Popen( [ diff --git a/pkg/server/download.go b/pkg/server/download.go index 812cf25073..e2d7d165f7 100644 --- a/pkg/server/download.go +++ b/pkg/server/download.go @@ -3,32 +3,50 @@ package server import ( "bytes" "net/http" + "path/filepath" "time" + "github.com/gorilla/mux" + "github.com/replicate/cog/pkg/console" + "github.com/replicate/cog/pkg/storage" ) func (s *Server) DownloadModel(w http.ResponseWriter, r *http.Request) { user, name, id := getRepoVars(r) modTime := time.Now() // TODO - mod, err := s.db.GetModel(user, name, id) + content, err := s.store.Download(user, name, id) if err != nil { - console.Error(err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - if mod == nil { - w.WriteHeader(http.StatusNotFound) + if err == storage.NotFound { + w.WriteHeader(http.StatusNotFound) + } else { + console.Error(err.Error()) + w.WriteHeader(http.StatusInternalServerError) + } return } + console.Infof("Downloaded %d bytes", len(content)) + http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content)) +} - content, err := s.store.Download(user, name, id) +func (s *Server) DownloadFile(w http.ResponseWriter, r *http.Request) { + user, name, id := getRepoVars(r) + vars := mux.Vars(r) + path := vars["path"] + modTime := time.Now() // TODO + + content, err := s.store.DownloadFile(user, name, id, path) if err != nil { - console.Error(err.Error()) - w.WriteHeader(http.StatusInternalServerError) + if err == storage.NotFound { + w.WriteHeader(http.StatusNotFound) + } else { + console.Error(err.Error()) + w.WriteHeader(http.StatusInternalServerError) + } return } + filename := filepath.Base(path) console.Infof("Downloaded %d bytes", len(content)) - http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content)) + http.ServeContent(w, r, filename, modTime, bytes.NewReader(content)) } diff --git a/pkg/server/server.go b/pkg/server/server.go index b26693e059..1468f511a8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -68,6 +68,9 @@ func (s *Server) Start() error { router.Path("/v1/repos/{user}/{name}/models/{id}.zip"). Methods(http.MethodGet). HandlerFunc(s.checkReadAccess(s.DownloadModel)) + router.Path("/v1/repos/{user}/{name}/models/{id}/files/{path:.+}"). + Methods(http.MethodGet). + HandlerFunc(s.checkReadAccess(s.DownloadFile)) router.Path("/v1/repos/{user}/{name}/models/"). Methods(http.MethodPut). HandlerFunc(s.checkWriteAccess(s.ReceiveFile)) diff --git a/pkg/storage/local.go b/pkg/storage/local.go index 80d40698e8..a5f4108eff 100644 --- a/pkg/storage/local.go +++ b/pkg/storage/local.go @@ -1,6 +1,7 @@ package storage import ( + "archive/zip" "fmt" "io" "os" @@ -63,14 +64,44 @@ func (s *LocalStorage) Download(user string, name string, id string) ([]byte, er path := s.pathForID(user, name, id) contents, err := os.ReadFile(path) if err != nil { + if os.IsNotExist(err) { + return nil, NotFound + } return nil, fmt.Errorf("Failed to read %s: %w", path, err) } return contents, nil } +func (s *LocalStorage) DownloadFile(user string, name string, id string, path string) ([]byte, error) { + zipPath := s.pathForID(user, name, id) + reader, err := zip.OpenReader(zipPath) + if err != nil { + if os.IsNotExist(err) { + return nil, NotFound + } + } + for _, file := range reader.File { + if file.Name == path { + r, err := file.Open() + if err != nil { + return nil, fmt.Errorf("Failed to open %s in zip file: %w", path, err) + } + contents, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("Failed to read %s in zip file: %w", path, err) + } + return contents, nil + } + } + return nil, NotFound +} + func (s *LocalStorage) Delete(user string, name string, id string) error { path := s.pathForID(user, name, id) if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + return NotFound + } return fmt.Errorf("Failed to delete %s: %w", path, err) } return nil diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 36778d41a1..ce6977d385 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -1,11 +1,15 @@ package storage import ( + "errors" "io" ) type Storage interface { Upload(user string, name string, id string, reader io.Reader) error - Download(user string, name string, id string) ([]byte, error) + Download(user string, name string, id string) ([]byte, error) // TODO(andreas): return reader + DownloadFile(user string, name string, id string, path string) ([]byte, error) Delete(user string, name string, id string) error } + +var NotFound = errors.New("Not found")