Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error {

configPath := path.Join(projectDir, global.ConfigFilename)

exists, err := files.FileExists(configPath)
exists, err := files.Exists(configPath)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func downloadModel(cmd *cobra.Command, args []string) (err error) {
}

// TODO(andreas): allow to checkout to existing directories, with warning prompt
exists, err := files.FileExists(downloadOutputDir)
exists, err := files.Exists(downloadOutputDir)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func setRepo(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
exists, err := files.FileExists(filepath.Join(cwd, global.ConfigFilename))
exists, err := files.Exists(filepath.Join(cwd, global.ConfigFilename))
if !exists {
console.Warnf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, cwd)
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
var (
port int
dockerRegistry string
buildWebHooks []string
)

func newServerCommand() *cobra.Command {
Expand All @@ -32,6 +33,7 @@ func newServerCommand() *cobra.Command {

cmd.Flags().IntVar(&port, "port", 0, "Server port")
cmd.Flags().StringVar(&dockerRegistry, "docker-registry", "", "Docker registry to push images to")
cmd.Flags().StringArrayVar(&buildWebHooks, "web-hook", []string{}, "Web hooks that are posted to after build. Format: <url>@<secret>")
return cmd
}

Expand Down Expand Up @@ -77,6 +79,9 @@ func startServer(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
s := server.NewServer(port, db, dockerImageBuilder, servingPlatform, store)
s, err := server.NewServer(port, buildWebHooks, db, dockerImageBuilder, servingPlatform, store)
if err != nil {
return err
}
return s.Start()
}
6 changes: 3 additions & 3 deletions pkg/database/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type LocalFileDatabase struct {
}

func NewLocalFileDatabase(rootDir string) (*LocalFileDatabase, error) {
exists, err := files.FileExists(rootDir)
exists, err := files.Exists(rootDir)
if err != nil {
return nil, err
}
Expand All @@ -44,7 +44,7 @@ func (db *LocalFileDatabase) InsertModel(user string, name string, id string, mo
}
path := db.modelPath(user, name, id)
dir := filepath.Dir(path)
exists, err := files.FileExists(path)
exists, err := files.Exists(path)
if err != nil {
return err
}
Expand All @@ -62,7 +62,7 @@ func (db *LocalFileDatabase) InsertModel(user string, name string, id string, mo
// GetModel returns a model or nil if the model doesn't exist
func (db *LocalFileDatabase) GetModel(user string, name string, id string) (*model.Model, error) {
path := db.modelPath(user, name, id)
exists, err := files.FileExists(path)
exists, err := files.Exists(path)
if err != nil {
return nil, fmt.Errorf("Failed to determine if %s exists: %w", path, err)
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/files/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,28 @@ package files
import (
"fmt"
"os"

"golang.org/x/sys/unix"
)

func FileExists(filePath string) (bool, error) {
if _, err := os.Stat(filePath); err == nil {
func Exists(path string) (bool, error) {
if _, err := os.Stat(path); err == nil {
return true, nil
} else if os.IsNotExist(err) {
return false, nil
} else {
return false, fmt.Errorf("Failed to determine if %s exists: %w", filePath, err)
return false, fmt.Errorf("Failed to determine if %s exists: %w", path, err)
}
}

func IsDir(dirPath string) (bool, error) {
file, err := os.Stat(dirPath)
func IsDir(path string) (bool, error) {
file, err := os.Stat(path)
if err != nil {
return false, err
}
return file.Mode().IsDir(), nil
}

func IsExecutable(path string) bool {
return unix.Access(path, unix.X_OK) == nil
}
21 changes: 21 additions & 0 deletions pkg/files/files_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package files

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestIsExecutable(t *testing.T) {
dir, err := os.MkdirTemp("/tmp", "test-files")
require.NoError(t, err)
path := filepath.Join(dir, "test-file")
err = os.WriteFile(path, []byte{}, 0644)
require.NoError(t, err)

require.False(t, IsExecutable(path))
os.Chmod(path, 0744)
require.True(t, IsExecutable(path))
}
14 changes: 9 additions & 5 deletions pkg/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (s *Server) ReceiveFile(w http.ResponseWriter, r *http.Request) {
streamLogger.WriteModel(mod)
}

func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger, user string, name string) (*model.Model, error) {
func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user string, name string) (*model.Model, error) {
// max 5GB models
if err := r.ParseMultipartForm(5 << 30); err != nil {
return nil, fmt.Errorf("Failed to parse request: %w", err)
Expand All @@ -45,7 +45,7 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger
}
defer file.Close()

streamLogger.WriteStatus("Received model")
logWriter.WriteStatus("Received model")

hasher := sha1.New()
if _, err := io.Copy(hasher, file); err != nil {
Expand Down Expand Up @@ -86,7 +86,7 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger
return nil, fmt.Errorf("Failed to upload to storage: %w", err)
}

artifacts, err := s.buildDockerImages(dir, config, name, streamLogger)
artifacts, err := s.buildDockerImages(dir, config, name, logWriter)
if err != nil {
return nil, err
}
Expand All @@ -97,18 +97,22 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger
Created: time.Now(),
}

runArgs, err := s.testModel(mod, dir, streamLogger)
runArgs, err := s.testModel(mod, dir, logWriter)
if err != nil {
// TODO(andreas): return other response than 500 if validation fails
return nil, err
}
mod.RunArguments = runArgs

streamLogger.WriteStatus("Inserting into database")
logWriter.WriteStatus("Inserting into database")
if err := s.db.InsertModel(user, name, id, mod); err != nil {
return nil, fmt.Errorf("Failed to insert into database: %w", err)
}

if err := s.runWebHooks(user, name, mod, dir, logWriter); err != nil {
return nil, err
}

return mod, nil
}

Expand Down
14 changes: 12 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,30 @@ const topLevelSourceDir = "source"

type Server struct {
port int
webHooks []*WebHook
db database.Database
dockerImageBuilder docker.ImageBuilder
servingPlatform serving.Platform
store storage.Storage
}

func NewServer(port int, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) *Server {
func NewServer(port int, rawWebHooks []string, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) (*Server, error) {
webHooks := []*WebHook{}
for _, rawWebHook := range rawWebHooks {
webHook, err := newWebHook(rawWebHook)
if err != nil {
return nil, err
}
webHooks = append(webHooks, webHook)
}
return &Server{
port: port,
webHooks: webHooks,
db: db,
dockerImageBuilder: dockerImageBuilder,
servingPlatform: servingPlatform,
store: store,
}
}, nil
}

func (s *Server) Start() error {
Expand Down
69 changes: 69 additions & 0 deletions pkg/server/web_hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package server

import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"

"github.com/replicate/cog/pkg/logger"
"github.com/replicate/cog/pkg/model"
)

type WebHook struct {
url *url.URL
secret string
}

func newWebHook(urlWithSecret string) (*WebHook, error) {
splitIndex := strings.LastIndex(urlWithSecret, "@")
if splitIndex == -1 {
return nil, fmt.Errorf("Web hooks must be in the format <url>@<secret>")
}
rawURL := urlWithSecret[:splitIndex]
secret := urlWithSecret[splitIndex+1:]

hookURL, err := url.Parse(rawURL)
if err != nil {
return nil, err
}
return &WebHook{url: hookURL, secret: secret}, nil
}

func (wh *WebHook) run(user string, name string, mod *model.Model, dir string, logWriter logger.Logger) error {
modelJSON, err := json.Marshal(mod)
if err != nil {
return err
}
modelJSONBase64 := base64.StdEncoding.EncodeToString(modelJSON)
modelPath := fmt.Sprintf("/v1/repos/%s/%s/models/%s", user, name, mod.ID)

logWriter.Infof("Posting model to %s", wh.url.Host)

req, err := http.PostForm(wh.url.String(), url.Values{
"model_id": {mod.ID},
"model_path": {modelPath},
"model_json_base64": {modelJSONBase64},
"user": {user},
"repo_name": {name},
"secret": {wh.secret},
})
if err != nil {
return fmt.Errorf("Model post failed: %w", err)
}
if req.StatusCode != http.StatusOK {
return fmt.Errorf("Model post failed with HTTP status %d", req.StatusCode)
}
return nil
}

func (s *Server) runWebHooks(user string, name string, mod *model.Model, dir string, logWriter logger.Logger) error {
for _, hook := range s.webHooks {
if err := hook.run(user, name, mod, dir, logWriter); err != nil {
return err
}
}
return nil
}
3 changes: 2 additions & 1 deletion pkg/serving/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ func (p *LocalDockerPlatform) waitForContainerReady(hostPort int, containerID st
}

func (d *LocalDockerDeployment) Undeploy() error {
if err := d.client.ContainerStop(context.Background(), d.containerID, nil); err != nil {
timeout := time.Duration(100 * time.Millisecond)
if err := d.client.ContainerStop(context.Background(), d.containerID, &timeout); err != nil {
return fmt.Errorf("Failed to stop Docker container %s: %w", d.containerID, err)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/settings/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func LoadProjectSettings(projectRoot string) (*ProjectSettings, error) {
}

settingsPath := projectSettingsPath(projectRoot)
exists, err := files.FileExists(settingsPath)
exists, err := files.Exists(settingsPath)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/storage/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type LocalStorage struct {
}

func NewLocalStorage(rootDir string) (*LocalStorage, error) {
exists, err := files.FileExists(rootDir)
exists, err := files.Exists(rootDir)
if err != nil {
return nil, err
}
Expand All @@ -39,7 +39,7 @@ func NewLocalStorage(rootDir string) (*LocalStorage, error) {
func (s *LocalStorage) Upload(user string, name string, id string, reader io.Reader) error {
path := s.pathForID(user, name, id)
dir := filepath.Dir(path)
exists, err := files.FileExists(path)
exists, err := files.Exists(path)
if err != nil {
return err
}
Expand Down