Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions end-to-end-test/end_to_end_test/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def run(self, text, path):
"""
)
with open(project_dir / "cog.yaml", "w") as f:
f.write(
"""
cog_yaml = """
name: andreas/hello-world
model: infer.py:Model
examples:
Expand All @@ -79,7 +78,7 @@ def run(self, text, path):
architectures:
- cpu
"""
)
f.write(cog_yaml)

out, _ = subprocess.Popen(
["cog", "repo", "set", f"http://localhost:{cog_port}/{user}/{repo_name}"],
Expand Down Expand Up @@ -116,9 +115,7 @@ def run(self, text, path):
["cog", "-r", repo, "show", "--json", model_id], stdout=subprocess.PIPE
).communicate()
out = json.loads(out)
assert (
out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"
)
assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"

# show without -r
out, _ = subprocess.Popen(
Expand Down Expand Up @@ -153,6 +150,10 @@ def run(self, text, path):
with input_path.open("w") as f:
f.write("input")

files_endpoint = f"http://localhost:{cog_port}/v1/repos/{user}/{repo_name}/models/{model_id}/files"
assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml
assert requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text == "fooquxbaz"

out_path = output_dir / "out.txt"
subprocess.Popen(
[
Expand Down
40 changes: 29 additions & 11 deletions pkg/server/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,50 @@ package server
import (
"bytes"
"net/http"
"path/filepath"
"time"

"github.com/gorilla/mux"

"github.com/replicate/cog/pkg/console"
"github.com/replicate/cog/pkg/storage"
)

func (s *Server) DownloadModel(w http.ResponseWriter, r *http.Request) {
user, name, id := getRepoVars(r)
modTime := time.Now() // TODO

mod, err := s.db.GetModel(user, name, id)
content, err := s.store.Download(user, name, id)
if err != nil {
console.Error(err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
if mod == nil {
w.WriteHeader(http.StatusNotFound)
if err == storage.NotFound {
w.WriteHeader(http.StatusNotFound)
} else {
console.Error(err.Error())
w.WriteHeader(http.StatusInternalServerError)
}
return
}
console.Infof("Downloaded %d bytes", len(content))
http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content))
}

content, err := s.store.Download(user, name, id)
func (s *Server) DownloadFile(w http.ResponseWriter, r *http.Request) {
user, name, id := getRepoVars(r)
vars := mux.Vars(r)
path := vars["path"]
modTime := time.Now() // TODO

content, err := s.store.DownloadFile(user, name, id, path)
if err != nil {
console.Error(err.Error())
w.WriteHeader(http.StatusInternalServerError)
if err == storage.NotFound {
w.WriteHeader(http.StatusNotFound)
} else {
console.Error(err.Error())
w.WriteHeader(http.StatusInternalServerError)
}
return
}
filename := filepath.Base(path)
console.Infof("Downloaded %d bytes", len(content))
http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content))
http.ServeContent(w, r, filename, modTime, bytes.NewReader(content))
}
3 changes: 3 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ func (s *Server) Start() error {
router.Path("/v1/repos/{user}/{name}/models/{id}.zip").
Methods(http.MethodGet).
HandlerFunc(s.checkReadAccess(s.DownloadModel))
router.Path("/v1/repos/{user}/{name}/models/{id}/files/{path:.+}").
Methods(http.MethodGet).
HandlerFunc(s.checkReadAccess(s.DownloadFile))
router.Path("/v1/repos/{user}/{name}/models/").
Methods(http.MethodPut).
HandlerFunc(s.checkWriteAccess(s.ReceiveFile))
Expand Down
31 changes: 31 additions & 0 deletions pkg/storage/local.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"archive/zip"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -63,14 +64,44 @@ func (s *LocalStorage) Download(user string, name string, id string) ([]byte, er
path := s.pathForID(user, name, id)
contents, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, NotFound
}
return nil, fmt.Errorf("Failed to read %s: %w", path, err)
}
return contents, nil
}

func (s *LocalStorage) DownloadFile(user string, name string, id string, path string) ([]byte, error) {
zipPath := s.pathForID(user, name, id)
reader, err := zip.OpenReader(zipPath)
if err != nil {
if os.IsNotExist(err) {
return nil, NotFound
}
}
for _, file := range reader.File {
if file.Name == path {
r, err := file.Open()
if err != nil {
return nil, fmt.Errorf("Failed to open %s in zip file: %w", path, err)
}
contents, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("Failed to read %s in zip file: %w", path, err)
}
return contents, nil
}
}
return nil, NotFound
}

func (s *LocalStorage) Delete(user string, name string, id string) error {
path := s.pathForID(user, name, id)
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
return NotFound
}
return fmt.Errorf("Failed to delete %s: %w", path, err)
}
return nil
Expand Down
6 changes: 5 additions & 1 deletion pkg/storage/storage.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package storage

import (
"errors"
"io"
)

type Storage interface {
Upload(user string, name string, id string, reader io.Reader) error
Download(user string, name string, id string) ([]byte, error)
Download(user string, name string, id string) ([]byte, error) // TODO(andreas): return reader
DownloadFile(user string, name string, id string, path string) ([]byte, error)
Delete(user string, name string, id string) error
}

var NotFound = errors.New("Not found")