Skip to content

Commit 50ef087

Browse files
Merge pull request #53 from replicate/andreas/download-files
Download individual files from model
2 parents bae2c3b + 0f4014c commit 50ef087

5 files changed

Lines changed: 75 additions & 18 deletions

File tree

end-to-end-test/end_to_end_test/test_server.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def run(self, text, path):
5959
"""
6060
)
6161
with open(project_dir / "cog.yaml", "w") as f:
62-
f.write(
63-
"""
62+
cog_yaml = """
6463
name: andreas/hello-world
6564
model: infer.py:Model
6665
examples:
@@ -79,7 +78,7 @@ def run(self, text, path):
7978
architectures:
8079
- cpu
8180
"""
82-
)
81+
f.write(cog_yaml)
8382

8483
out, _ = subprocess.Popen(
8584
["cog", "repo", "set", f"http://localhost:{cog_port}/{user}/{repo_name}"],
@@ -116,9 +115,7 @@ def run(self, text, path):
116115
["cog", "-r", repo, "show", "--json", model_id], stdout=subprocess.PIPE
117116
).communicate()
118117
out = json.loads(out)
119-
assert (
120-
out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"
121-
)
118+
assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt"
122119

123120
# show without -r
124121
out, _ = subprocess.Popen(
@@ -153,6 +150,10 @@ def run(self, text, path):
153150
with input_path.open("w") as f:
154151
f.write("input")
155152

153+
files_endpoint = f"http://localhost:{cog_port}/v1/repos/{user}/{repo_name}/models/{model_id}/files"
154+
assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml
155+
assert requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text == "fooquxbaz"
156+
156157
out_path = output_dir / "out.txt"
157158
subprocess.Popen(
158159
[

pkg/server/download.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,50 @@ package server
33
import (
44
"bytes"
55
"net/http"
6+
"path/filepath"
67
"time"
78

9+
"github.com/gorilla/mux"
10+
811
"github.com/replicate/cog/pkg/console"
12+
"github.com/replicate/cog/pkg/storage"
913
)
1014

1115
func (s *Server) DownloadModel(w http.ResponseWriter, r *http.Request) {
1216
user, name, id := getRepoVars(r)
1317
modTime := time.Now() // TODO
1418

15-
mod, err := s.db.GetModel(user, name, id)
19+
content, err := s.store.Download(user, name, id)
1620
if err != nil {
17-
console.Error(err.Error())
18-
w.WriteHeader(http.StatusInternalServerError)
19-
return
20-
}
21-
if mod == nil {
22-
w.WriteHeader(http.StatusNotFound)
21+
if err == storage.NotFound {
22+
w.WriteHeader(http.StatusNotFound)
23+
} else {
24+
console.Error(err.Error())
25+
w.WriteHeader(http.StatusInternalServerError)
26+
}
2327
return
2428
}
29+
console.Infof("Downloaded %d bytes", len(content))
30+
http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content))
31+
}
2532

26-
content, err := s.store.Download(user, name, id)
33+
func (s *Server) DownloadFile(w http.ResponseWriter, r *http.Request) {
34+
user, name, id := getRepoVars(r)
35+
vars := mux.Vars(r)
36+
path := vars["path"]
37+
modTime := time.Now() // TODO
38+
39+
content, err := s.store.DownloadFile(user, name, id, path)
2740
if err != nil {
28-
console.Error(err.Error())
29-
w.WriteHeader(http.StatusInternalServerError)
41+
if err == storage.NotFound {
42+
w.WriteHeader(http.StatusNotFound)
43+
} else {
44+
console.Error(err.Error())
45+
w.WriteHeader(http.StatusInternalServerError)
46+
}
3047
return
3148
}
49+
filename := filepath.Base(path)
3250
console.Infof("Downloaded %d bytes", len(content))
33-
http.ServeContent(w, r, id+".zip", modTime, bytes.NewReader(content))
51+
http.ServeContent(w, r, filename, modTime, bytes.NewReader(content))
3452
}

pkg/server/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ func (s *Server) Start() error {
6868
router.Path("/v1/repos/{user}/{name}/models/{id}.zip").
6969
Methods(http.MethodGet).
7070
HandlerFunc(s.checkReadAccess(s.DownloadModel))
71+
router.Path("/v1/repos/{user}/{name}/models/{id}/files/{path:.+}").
72+
Methods(http.MethodGet).
73+
HandlerFunc(s.checkReadAccess(s.DownloadFile))
7174
router.Path("/v1/repos/{user}/{name}/models/").
7275
Methods(http.MethodPut).
7376
HandlerFunc(s.checkWriteAccess(s.ReceiveFile))

pkg/storage/local.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package storage
22

33
import (
4+
"archive/zip"
45
"fmt"
56
"io"
67
"os"
@@ -63,14 +64,44 @@ func (s *LocalStorage) Download(user string, name string, id string) ([]byte, er
6364
path := s.pathForID(user, name, id)
6465
contents, err := os.ReadFile(path)
6566
if err != nil {
67+
if os.IsNotExist(err) {
68+
return nil, NotFound
69+
}
6670
return nil, fmt.Errorf("Failed to read %s: %w", path, err)
6771
}
6872
return contents, nil
6973
}
7074

75+
func (s *LocalStorage) DownloadFile(user string, name string, id string, path string) ([]byte, error) {
76+
zipPath := s.pathForID(user, name, id)
77+
reader, err := zip.OpenReader(zipPath)
78+
if err != nil {
79+
if os.IsNotExist(err) {
80+
return nil, NotFound
81+
}
82+
}
83+
for _, file := range reader.File {
84+
if file.Name == path {
85+
r, err := file.Open()
86+
if err != nil {
87+
return nil, fmt.Errorf("Failed to open %s in zip file: %w", path, err)
88+
}
89+
contents, err := io.ReadAll(r)
90+
if err != nil {
91+
return nil, fmt.Errorf("Failed to read %s in zip file: %w", path, err)
92+
}
93+
return contents, nil
94+
}
95+
}
96+
return nil, NotFound
97+
}
98+
7199
func (s *LocalStorage) Delete(user string, name string, id string) error {
72100
path := s.pathForID(user, name, id)
73101
if err := os.Remove(path); err != nil {
102+
if os.IsNotExist(err) {
103+
return NotFound
104+
}
74105
return fmt.Errorf("Failed to delete %s: %w", path, err)
75106
}
76107
return nil

pkg/storage/storage.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package storage
22

33
import (
4+
"errors"
45
"io"
56
)
67

78
type Storage interface {
89
Upload(user string, name string, id string, reader io.Reader) error
9-
Download(user string, name string, id string) ([]byte, error)
10+
Download(user string, name string, id string) ([]byte, error) // TODO(andreas): return reader
11+
DownloadFile(user string, name string, id string, path string) ([]byte, error)
1012
Delete(user string, name string, id string) error
1113
}
14+
15+
var NotFound = errors.New("Not found")

0 commit comments

Comments
 (0)