Skip to content

Commit 30b9446

Browse files
bring back list
1 parent dfdaee7 commit 30b9446

File tree

8 files changed

+161
-8
lines changed

8 files changed

+161
-8
lines changed

README.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ model: "model.py:JazzSoloComposerModel"
4040
3. Build and push the model:
4141
4242
```
43-
$ cog repo set http://10.1.2.3:8000
43+
$ cog repo set http://10.1.2.3:8000/andreas/my-model
4444
$ cog build
4545
...
4646
--> Built and pushed b6a2f8a2d2ff
@@ -86,6 +86,10 @@ You can see more details about the package:
8686

8787
cog show b31f9f72d8f14f0eacc5452e85b05c957b9a8ed9
8888

89+
You can also list the packages for this repo:
90+
91+
cog list
92+
8993
In this output is the Docker image. You can run this anywhere a Docker image runs to deploy your model.
9094

9195

@@ -116,7 +120,7 @@ This does the following:
116120
* Tests that the model works by running the Docker image locally and performing an inference
117121
* Inserts model metadata into database (local files)
118122

119-
### GET `/v1/packages/<user>/<name>`
123+
### GET `/v1/packages/<user>/<name>/<id>`
120124

121125
Fetch package metadata.
122126

@@ -149,6 +153,24 @@ $ curl localhost:8080/v1/packages/andreas/my-model/c43b98b37776656e6b3dac3ea3270
149153
}
150154
```
151155

156+
### GET `/v1/packages/<user>/<name>`
157+
158+
List all packages' metadata.
159+
160+
Example:
161+
162+
```
163+
$ curl localhost:8080/v1/packages/andreas/my-model/ | jq .
164+
[
165+
{
166+
"ID": "c43b98b37776656e6b3dac3ea3270660ffc21ca7",
167+
"Artifacts": [
168+
{
169+
"Target": "docker-cpu",
170+
[...]
171+
]
172+
```
173+
152174
### GET `/v1/packages/<user>/<name>/<id>.zip`
153175

154176
Download the package.

end-to-end-test/end_to_end_test/test_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def run(self, text, path):
115115
assert lines[0] == f"ID: {package_id}"
116116
assert lines[1] == f"Repo: {user}/{repo_name}"
117117

118+
out, _ = subprocess.Popen(["cog", "-r", repo, "ls"], stdout=subprocess.PIPE).communicate()
119+
lines = out.decode().splitlines()
120+
assert lines[1].startswith(f"{package_id} ")
121+
118122
download_dir = tmpdir_factory.mktemp("download") / "my-dir"
119123
subprocess.Popen(
120124
["cog", "-r", repo, "download", "--output-dir", download_dir, package_id],

pkg/cli/list.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package cli
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"text/tabwriter"
7+
8+
"github.com/spf13/cobra"
9+
"github.com/xeonx/timeago"
10+
11+
"github.com/replicate/cog/pkg/client"
12+
)
13+
14+
func newListCommand() *cobra.Command {
15+
cmd := &cobra.Command{
16+
Use: "list",
17+
Short: "List Cog packages",
18+
RunE: listPackages,
19+
Args: cobra.NoArgs,
20+
Aliases: []string{"ls"},
21+
}
22+
addRepoFlag(cmd)
23+
24+
return cmd
25+
}
26+
27+
func listPackages(cmd *cobra.Command, args []string) error {
28+
repo, err := getRepo()
29+
if err != nil {
30+
return err
31+
}
32+
33+
cli := client.NewClient()
34+
models, err := cli.ListPackages(repo)
35+
if err != nil {
36+
return err
37+
}
38+
39+
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
40+
fmt.Fprintln(w, "ID\tCREATED")
41+
for _, mod := range models {
42+
fmt.Fprintf(w, "%s\t%s\n", mod.ID, timeago.English.Format(mod.Created))
43+
}
44+
w.Flush()
45+
46+
return nil
47+
}

pkg/cli/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ func NewRootCommand() (*cobra.Command, error) {
3535
newShowCommand(),
3636
newRepoCommand(),
3737
newDownloadCommand(),
38+
newListCommand(),
3839
)
3940

4041
log.SetLevel(log.DebugLevel)

pkg/client/list.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package client
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"net/http"
7+
8+
"github.com/replicate/cog/pkg/model"
9+
)
10+
11+
func (c *Client) ListPackages(repo *model.Repo) ([]*model.Model, error) {
12+
url := fmt.Sprintf("http://%s/v1/repos/%s/%s/packages/", repo.Host, repo.User, repo.Name)
13+
resp, err := http.Get(url)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
if resp.StatusCode != http.StatusOK {
19+
return nil, fmt.Errorf("List endpoint returned status %d", resp.StatusCode)
20+
}
21+
22+
models := []*model.Model{}
23+
if err := json.NewDecoder(resp.Body).Decode(&models); err != nil {
24+
return nil, fmt.Errorf("Failed to decode response: %w", err)
25+
}
26+
27+
return models, nil
28+
}

pkg/database/database.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ import (
77
type Database interface {
88
InsertModel(user string, name string, id string, mod *model.Model) error
99
GetModel(user string, name string, id string) (*model.Model, error)
10+
ListModels(user string, name string) ([]*model.Model, error)
1011
DeleteModel(user string, name string, id string) error
1112
}

pkg/database/local.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"os"
77
"path/filepath"
88

9+
"strings"
10+
911
"github.com/replicate/cog/pkg/files"
1012
"github.com/replicate/cog/pkg/model"
1113
)
@@ -82,6 +84,27 @@ func (db *LocalFileDatabase) DeleteModel(user string, name string, id string) er
8284
return nil
8385
}
8486

87+
func (db *LocalFileDatabase) ListModels(user string, name string) ([]*model.Model, error) {
88+
repoDir := db.repoDir(user, name)
89+
entries, err := os.ReadDir(repoDir)
90+
if err != nil {
91+
return nil, fmt.Errorf("Failed to scan %s: %w", db.rootDir, err)
92+
}
93+
models := []*model.Model{}
94+
for _, entry := range entries {
95+
filename := entry.Name()
96+
if strings.HasSuffix(filename, ".json") {
97+
path := filepath.Join(repoDir, filename)
98+
mod, err := db.readModel(path)
99+
if err != nil {
100+
return nil, err
101+
}
102+
models = append(models, mod)
103+
}
104+
}
105+
return models, nil
106+
}
107+
85108
func (db *LocalFileDatabase) readModel(path string) (*model.Model, error) {
86109
contents, err := os.ReadFile(path)
87110
if err != nil {
@@ -95,5 +118,9 @@ func (db *LocalFileDatabase) readModel(path string) (*model.Model, error) {
95118
}
96119

97120
func (db *LocalFileDatabase) packagePath(user string, name string, id string) string {
98-
return filepath.Join(db.rootDir, user, name, id+".json")
121+
return filepath.Join(db.repoDir(user, name), id+".json")
122+
}
123+
124+
func (db *LocalFileDatabase) repoDir(user string, name string) string {
125+
return filepath.Join(db.rootDir, user, name)
99126
}

pkg/server/server.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,25 @@ func NewServer(port int, db database.Database, dockerImageBuilder docker.ImageBu
5151
func (s *Server) Start() error {
5252
router := mux.NewRouter()
5353
router.Path("/ping").
54-
Methods("GET").
54+
Methods(http.MethodGet).
5555
HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5656
log.Info("Received ping request")
5757
w.Write([]byte("pong"))
5858
})
5959
router.Path("/v1/repos/{user}/{name}/packages/{id}.zip").
60-
Methods("GET").
60+
Methods(http.MethodGet).
6161
HandlerFunc(s.SendModelPackage)
6262
router.Path("/v1/repos/{user}/{name}/packages/").
63-
Methods("PUT").
63+
Methods(http.MethodPut).
6464
HandlerFunc(s.ReceiveFile)
65+
router.Path("/v1/repos/{user}/{name}/packages/").
66+
Methods(http.MethodGet).
67+
HandlerFunc(s.ListPackages)
6568
router.Path("/v1/repos/{user}/{name}/packages/{id}").
66-
Methods("GET").
69+
Methods(http.MethodGet).
6770
HandlerFunc(s.SendModelMetadata)
6871
router.Path("/v1/repos/{user}/{name}/packages/{id}").
69-
Methods("DELETE").
72+
Methods(http.MethodDelete).
7073
HandlerFunc(s.DeletePackage)
7174
fmt.Println("Starting")
7275
return http.ListenAndServe(fmt.Sprintf(":%d", s.port), router)
@@ -136,6 +139,26 @@ func (s *Server) SendModelMetadata(w http.ResponseWriter, r *http.Request) {
136139
}
137140
}
138141

142+
func (s *Server) ListPackages(w http.ResponseWriter, r *http.Request) {
143+
user, name, _ := getRepoVars(r)
144+
log.Infof("Received list request for %s%s", user, name)
145+
146+
models, err := s.db.ListModels(user, name)
147+
if err != nil {
148+
log.Error(err)
149+
w.WriteHeader(http.StatusInternalServerError)
150+
return
151+
}
152+
w.WriteHeader(http.StatusOK)
153+
w.Header().Set("Content-Type", "application/json")
154+
155+
if err := json.NewEncoder(w).Encode(models); err != nil {
156+
log.Error(err)
157+
w.WriteHeader(http.StatusInternalServerError)
158+
return
159+
}
160+
}
161+
139162
func (s *Server) DeletePackage(w http.ResponseWriter, r *http.Request) {
140163
user, name, id := getRepoVars(r)
141164
log.Infof("Received delete request for %s/%s/%s", user, name, id)

0 commit comments

Comments
 (0)