diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 79f96b1619..36fff825ee 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -41,7 +41,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error { configPath := path.Join(projectDir, global.ConfigFilename) - exists, err := files.FileExists(configPath) + exists, err := files.Exists(configPath) if err != nil { return err } diff --git a/pkg/cli/download.go b/pkg/cli/download.go index 3470ed87e5..66a96704b1 100644 --- a/pkg/cli/download.go +++ b/pkg/cli/download.go @@ -45,7 +45,7 @@ func downloadModel(cmd *cobra.Command, args []string) (err error) { } // TODO(andreas): allow to checkout to existing directories, with warning prompt - exists, err := files.FileExists(downloadOutputDir) + exists, err := files.Exists(downloadOutputDir) if err != nil { return err } diff --git a/pkg/cli/repo.go b/pkg/cli/repo.go index 8559e60521..129b5bbb1b 100644 --- a/pkg/cli/repo.go +++ b/pkg/cli/repo.go @@ -64,7 +64,7 @@ func setRepo(cmd *cobra.Command, args []string) error { if err != nil { return err } - exists, err := files.FileExists(filepath.Join(cwd, global.ConfigFilename)) + exists, err := files.Exists(filepath.Join(cwd, global.ConfigFilename)) if !exists { console.Warnf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, cwd) } diff --git a/pkg/cli/server.go b/pkg/cli/server.go index 00ee4adb17..f6575f2158 100644 --- a/pkg/cli/server.go +++ b/pkg/cli/server.go @@ -20,6 +20,7 @@ import ( var ( port int dockerRegistry string + buildWebHooks []string ) func newServerCommand() *cobra.Command { @@ -32,6 +33,7 @@ func newServerCommand() *cobra.Command { cmd.Flags().IntVar(&port, "port", 0, "Server port") cmd.Flags().StringVar(&dockerRegistry, "docker-registry", "", "Docker registry to push images to") + cmd.Flags().StringArrayVar(&buildWebHooks, "web-hook", []string{}, "Web hooks that are posted to after build. Format: @") return cmd } @@ -77,6 +79,9 @@ func startServer(cmd *cobra.Command, args []string) error { if err != nil { return err } - s := server.NewServer(port, db, dockerImageBuilder, servingPlatform, store) + s, err := server.NewServer(port, buildWebHooks, db, dockerImageBuilder, servingPlatform, store) + if err != nil { + return err + } return s.Start() } diff --git a/pkg/database/local.go b/pkg/database/local.go index 03a55bb05f..349cb1d266 100644 --- a/pkg/database/local.go +++ b/pkg/database/local.go @@ -17,7 +17,7 @@ type LocalFileDatabase struct { } func NewLocalFileDatabase(rootDir string) (*LocalFileDatabase, error) { - exists, err := files.FileExists(rootDir) + exists, err := files.Exists(rootDir) if err != nil { return nil, err } @@ -44,7 +44,7 @@ func (db *LocalFileDatabase) InsertModel(user string, name string, id string, mo } path := db.modelPath(user, name, id) dir := filepath.Dir(path) - exists, err := files.FileExists(path) + exists, err := files.Exists(path) if err != nil { return err } @@ -62,7 +62,7 @@ func (db *LocalFileDatabase) InsertModel(user string, name string, id string, mo // GetModel returns a model or nil if the model doesn't exist func (db *LocalFileDatabase) GetModel(user string, name string, id string) (*model.Model, error) { path := db.modelPath(user, name, id) - exists, err := files.FileExists(path) + exists, err := files.Exists(path) if err != nil { return nil, fmt.Errorf("Failed to determine if %s exists: %w", path, err) } diff --git a/pkg/files/files.go b/pkg/files/files.go index a8aece8a1c..6f1e0bf582 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -3,22 +3,28 @@ package files import ( "fmt" "os" + + "golang.org/x/sys/unix" ) -func FileExists(filePath string) (bool, error) { - if _, err := os.Stat(filePath); err == nil { +func Exists(path string) (bool, error) { + if _, err := os.Stat(path); err == nil { return true, nil } else if os.IsNotExist(err) { return false, nil } else { - return false, fmt.Errorf("Failed to determine if %s exists: %w", filePath, err) + return false, fmt.Errorf("Failed to determine if %s exists: %w", path, err) } } -func IsDir(dirPath string) (bool, error) { - file, err := os.Stat(dirPath) +func IsDir(path string) (bool, error) { + file, err := os.Stat(path) if err != nil { return false, err } return file.Mode().IsDir(), nil } + +func IsExecutable(path string) bool { + return unix.Access(path, unix.X_OK) == nil +} diff --git a/pkg/files/files_test.go b/pkg/files/files_test.go new file mode 100644 index 0000000000..fd47f2b823 --- /dev/null +++ b/pkg/files/files_test.go @@ -0,0 +1,21 @@ +package files + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsExecutable(t *testing.T) { + dir, err := os.MkdirTemp("/tmp", "test-files") + require.NoError(t, err) + path := filepath.Join(dir, "test-file") + err = os.WriteFile(path, []byte{}, 0644) + require.NoError(t, err) + + require.False(t, IsExecutable(path)) + os.Chmod(path, 0744) + require.True(t, IsExecutable(path)) +} diff --git a/pkg/server/build.go b/pkg/server/build.go index c44833dd2b..c0d8670b1f 100644 --- a/pkg/server/build.go +++ b/pkg/server/build.go @@ -34,7 +34,7 @@ func (s *Server) ReceiveFile(w http.ResponseWriter, r *http.Request) { streamLogger.WriteModel(mod) } -func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger, user string, name string) (*model.Model, error) { +func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user string, name string) (*model.Model, error) { // max 5GB models if err := r.ParseMultipartForm(5 << 30); err != nil { return nil, fmt.Errorf("Failed to parse request: %w", err) @@ -45,7 +45,7 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger } defer file.Close() - streamLogger.WriteStatus("Received model") + logWriter.WriteStatus("Received model") hasher := sha1.New() if _, err := io.Copy(hasher, file); err != nil { @@ -86,7 +86,7 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger return nil, fmt.Errorf("Failed to upload to storage: %w", err) } - artifacts, err := s.buildDockerImages(dir, config, name, streamLogger) + artifacts, err := s.buildDockerImages(dir, config, name, logWriter) if err != nil { return nil, err } @@ -97,18 +97,22 @@ func (s *Server) ReceiveModel(r *http.Request, streamLogger *logger.StreamLogger Created: time.Now(), } - runArgs, err := s.testModel(mod, dir, streamLogger) + runArgs, err := s.testModel(mod, dir, logWriter) if err != nil { // TODO(andreas): return other response than 500 if validation fails return nil, err } mod.RunArguments = runArgs - streamLogger.WriteStatus("Inserting into database") + logWriter.WriteStatus("Inserting into database") if err := s.db.InsertModel(user, name, id, mod); err != nil { return nil, fmt.Errorf("Failed to insert into database: %w", err) } + if err := s.runWebHooks(user, name, mod, dir, logWriter); err != nil { + return nil, err + } + return mod, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index 0f7b81256f..a604cc99ff 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -20,20 +20,30 @@ const topLevelSourceDir = "source" type Server struct { port int + webHooks []*WebHook db database.Database dockerImageBuilder docker.ImageBuilder servingPlatform serving.Platform store storage.Storage } -func NewServer(port int, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) *Server { +func NewServer(port int, rawWebHooks []string, db database.Database, dockerImageBuilder docker.ImageBuilder, servingPlatform serving.Platform, store storage.Storage) (*Server, error) { + webHooks := []*WebHook{} + for _, rawWebHook := range rawWebHooks { + webHook, err := newWebHook(rawWebHook) + if err != nil { + return nil, err + } + webHooks = append(webHooks, webHook) + } return &Server{ port: port, + webHooks: webHooks, db: db, dockerImageBuilder: dockerImageBuilder, servingPlatform: servingPlatform, store: store, - } + }, nil } func (s *Server) Start() error { diff --git a/pkg/server/web_hook.go b/pkg/server/web_hook.go new file mode 100644 index 0000000000..4c2ad0ba4f --- /dev/null +++ b/pkg/server/web_hook.go @@ -0,0 +1,69 @@ +package server + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/replicate/cog/pkg/logger" + "github.com/replicate/cog/pkg/model" +) + +type WebHook struct { + url *url.URL + secret string +} + +func newWebHook(urlWithSecret string) (*WebHook, error) { + splitIndex := strings.LastIndex(urlWithSecret, "@") + if splitIndex == -1 { + return nil, fmt.Errorf("Web hooks must be in the format @") + } + rawURL := urlWithSecret[:splitIndex] + secret := urlWithSecret[splitIndex+1:] + + hookURL, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + return &WebHook{url: hookURL, secret: secret}, nil +} + +func (wh *WebHook) run(user string, name string, mod *model.Model, dir string, logWriter logger.Logger) error { + modelJSON, err := json.Marshal(mod) + if err != nil { + return err + } + modelJSONBase64 := base64.StdEncoding.EncodeToString(modelJSON) + modelPath := fmt.Sprintf("/v1/repos/%s/%s/models/%s", user, name, mod.ID) + + logWriter.Infof("Posting model to %s", wh.url.Host) + + req, err := http.PostForm(wh.url.String(), url.Values{ + "model_id": {mod.ID}, + "model_path": {modelPath}, + "model_json_base64": {modelJSONBase64}, + "user": {user}, + "repo_name": {name}, + "secret": {wh.secret}, + }) + if err != nil { + return fmt.Errorf("Model post failed: %w", err) + } + if req.StatusCode != http.StatusOK { + return fmt.Errorf("Model post failed with HTTP status %d", req.StatusCode) + } + return nil +} + +func (s *Server) runWebHooks(user string, name string, mod *model.Model, dir string, logWriter logger.Logger) error { + for _, hook := range s.webHooks { + if err := hook.run(user, name, mod, dir, logWriter); err != nil { + return err + } + } + return nil +} diff --git a/pkg/serving/local.go b/pkg/serving/local.go index 069ac28e54..02f3f9a8e6 100644 --- a/pkg/serving/local.go +++ b/pkg/serving/local.go @@ -156,7 +156,8 @@ func (p *LocalDockerPlatform) waitForContainerReady(hostPort int, containerID st } func (d *LocalDockerDeployment) Undeploy() error { - if err := d.client.ContainerStop(context.Background(), d.containerID, nil); err != nil { + timeout := time.Duration(100 * time.Millisecond) + if err := d.client.ContainerStop(context.Background(), d.containerID, &timeout); err != nil { return fmt.Errorf("Failed to stop Docker container %s: %w", d.containerID, err) } return nil diff --git a/pkg/settings/project.go b/pkg/settings/project.go index 3a8703544c..9474b7e3e5 100644 --- a/pkg/settings/project.go +++ b/pkg/settings/project.go @@ -23,7 +23,7 @@ func LoadProjectSettings(projectRoot string) (*ProjectSettings, error) { } settingsPath := projectSettingsPath(projectRoot) - exists, err := files.FileExists(settingsPath) + exists, err := files.Exists(settingsPath) if err != nil { return nil, err } diff --git a/pkg/storage/local.go b/pkg/storage/local.go index 7c1046b3b5..80d40698e8 100644 --- a/pkg/storage/local.go +++ b/pkg/storage/local.go @@ -16,7 +16,7 @@ type LocalStorage struct { } func NewLocalStorage(rootDir string) (*LocalStorage, error) { - exists, err := files.FileExists(rootDir) + exists, err := files.Exists(rootDir) if err != nil { return nil, err } @@ -39,7 +39,7 @@ func NewLocalStorage(rootDir string) (*LocalStorage, error) { func (s *LocalStorage) Upload(user string, name string, id string, reader io.Reader) error { path := s.pathForID(user, name, id) dir := filepath.Dir(path) - exists, err := files.FileExists(path) + exists, err := files.Exists(path) if err != nil { return err }