diff --git a/pkg/cli/push.go b/pkg/cli/push.go index abfee5ef08..087c8199c9 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -18,6 +18,8 @@ import ( "github.com/replicate/cog/pkg/util/console" ) +var pushPipeline bool + func newPushCommand() *cobra.Command { cmd := &cobra.Command{ Use: "push [IMAGE]", @@ -40,6 +42,7 @@ func newPushCommand() *cobra.Command { addFastFlag(cmd) addLocalImage(cmd) addConfigFlag(cmd) + addPipelineImage(cmd) return cmd } @@ -111,6 +114,7 @@ func push(cmd *cobra.Command, args []string) error { err = docker.Push(ctx, imageName, buildFast, projectDir, dockerClient, docker.BuildInfo{ BuildTime: buildDuration, BuildID: buildID.String(), + Pipeline: pushPipeline, }, client) if err != nil { if strings.Contains(err.Error(), "404") { @@ -140,3 +144,9 @@ func push(cmd *cobra.Command, args []string) error { return nil } + +func addPipelineImage(cmd *cobra.Command) { + const pipeline = "x-pipeline" + cmd.Flags().BoolVar(&pushPipeline, pipeline, false, "Whether to use the experimental pipeline push feature") + _ = cmd.Flags().MarkHidden(pipeline) +} diff --git a/pkg/docker/pipeline_push.go b/pkg/docker/pipeline_push.go new file mode 100644 index 0000000000..20ac20ae03 --- /dev/null +++ b/pkg/docker/pipeline_push.go @@ -0,0 +1,78 @@ +package docker + +import ( + "archive/tar" + "bytes" + "context" + "io" + "os" + "path/filepath" + + "github.com/replicate/cog/pkg/dockerignore" + "github.com/replicate/cog/pkg/web" +) + +func PipelinePush(ctx context.Context, image string, projectDir string, webClient *web.Client) error { + tarball, err := createTarball(projectDir) + if err != nil { + return err + } + return webClient.PostNewPipeline(ctx, image, tarball) +} + +func createTarball(folder string) (*bytes.Buffer, error) { + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + matcher, err := dockerignore.CreateMatcher(folder) + if err != nil { + return nil, err + } + + err = dockerignore.Walk(folder, matcher, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(folder, path) + if err != nil { + return err + } + + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + header, err := tar.FileInfoHeader(info, info.Name()) + if err != nil { + return err + } + header.Name = relPath + + err = tw.WriteHeader(header) + if err != nil { + return err + } + + _, err = io.Copy(tw, file) + if err != nil { + return err + } + return nil + }) + if err != nil { + return nil, err + } + + if err := tw.Close(); err != nil { + return nil, err + } + + return &buf, nil +} diff --git a/pkg/docker/pipeline_push_test.go b/pkg/docker/pipeline_push_test.go new file mode 100644 index 0000000000..da097a39ba --- /dev/null +++ b/pkg/docker/pipeline_push_test.go @@ -0,0 +1,49 @@ +package docker + +import ( + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/docker/dockertest" + "github.com/replicate/cog/pkg/env" + cogHttp "github.com/replicate/cog/pkg/http" + "github.com/replicate/cog/pkg/web" +) + +func TestPipelinePush(t *testing.T) { + // Setup mock http server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + output := "{\"version\":\"user/test:53c740f17ce88a61c3da5b0c20e48fd48e2da537c3a1276dec63ab11fbad6bcb\"}" + w.WriteHeader(http.StatusCreated) + w.Write([]byte(output)) + })) + defer server.Close() + url, err := url.Parse(server.URL) + require.NoError(t, err) + t.Setenv(env.SchemeEnvVarName, url.Scheme) + t.Setenv(web.WebHostEnvVarName, url.Host) + + dir := t.TempDir() + + // Create mock predict + predictPyPath := filepath.Join(dir, "predict.py") + handle, err := os.Create(predictPyPath) + require.NoError(t, err) + handle.WriteString("import cog") + dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" + + // Setup mock command + command := dockertest.NewMockCommand() + client, err := cogHttp.ProvideHTTPClient(t.Context(), command) + require.NoError(t, err) + webClient := web.NewClient(command, client) + + err = PipelinePush(t.Context(), "r8.im/username/modelname", dir, webClient) + require.NoError(t, err) +} diff --git a/pkg/docker/push.go b/pkg/docker/push.go index 0cfee0bc86..82abab9f66 100644 --- a/pkg/docker/push.go +++ b/pkg/docker/push.go @@ -14,11 +14,16 @@ import ( type BuildInfo struct { BuildTime time.Duration BuildID string + Pipeline bool } func Push(ctx context.Context, image string, fast bool, projectDir string, command command.Command, buildInfo BuildInfo, client *http.Client) error { webClient := web.NewClient(command, client) + if buildInfo.Pipeline { + return PipelinePush(ctx, image, projectDir, webClient) + } + if err := webClient.PostPushStart(ctx, buildInfo.BuildID, buildInfo.BuildTime); err != nil { console.Warnf("Failed to send build timings to server: %v", err) } diff --git a/pkg/docker/push_test.go b/pkg/docker/push_test.go index 5dfcb6ae6f..485f4fd483 100644 --- a/pkg/docker/push_test.go +++ b/pkg/docker/push_test.go @@ -152,3 +152,36 @@ func TestPushWithWeight(t *testing.T) { err = Push(t.Context(), "r8.im/username/modelname", true, dir, command, BuildInfo{}, client) require.NoError(t, err) } + +func TestPushPipeline(t *testing.T) { + // Setup mock http server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + output := "{\"version\":\"user/test:53c740f17ce88a61c3da5b0c20e48fd48e2da537c3a1276dec63ab11fbad6bcb\"}" + w.WriteHeader(http.StatusCreated) + w.Write([]byte(output)) + })) + defer server.Close() + url, err := url.Parse(server.URL) + require.NoError(t, err) + t.Setenv(env.SchemeEnvVarName, url.Scheme) + t.Setenv(web.WebHostEnvVarName, url.Host) + + dir := t.TempDir() + + // Create mock predict + predictPyPath := filepath.Join(dir, "predict.py") + handle, err := os.Create(predictPyPath) + require.NoError(t, err) + handle.WriteString("import cog") + dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" + + // Setup mock command + command := dockertest.NewMockCommand() + client, err := cogHttp.ProvideHTTPClient(t.Context(), command) + require.NoError(t, err) + + err = Push(t.Context(), "r8.im/username/modelname", false, dir, command, BuildInfo{ + Pipeline: true, + }, client) + require.NoError(t, err) +} diff --git a/pkg/dockerignore/dockerignore.go b/pkg/dockerignore/dockerignore.go index 237b9235d5..987f199c03 100644 --- a/pkg/dockerignore/dockerignore.go +++ b/pkg/dockerignore/dockerignore.go @@ -10,8 +10,10 @@ import ( "github.com/replicate/cog/pkg/util/files" ) +const DockerIgnoreFilename = ".dockerignore" + func CreateMatcher(dir string) (*ignore.GitIgnore, error) { - dockerIgnorePath := filepath.Join(dir, ".dockerignore") + dockerIgnorePath := filepath.Join(dir, DockerIgnoreFilename) dockerIgnoreExists, err := files.Exists(dockerIgnorePath) if err != nil { return nil, err @@ -27,6 +29,32 @@ func CreateMatcher(dir string) (*ignore.GitIgnore, error) { return ignore.CompileIgnoreLines(patterns...), nil } +func Walk(root string, ignoreMatcher *ignore.GitIgnore, fn filepath.WalkFunc) error { + return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // We ignore files ignored by .dockerignore + if ignoreMatcher != nil && ignoreMatcher.MatchesPath(path) { + if info.IsDir() { + return filepath.SkipDir + } + return nil + } + + if info.IsDir() && info.Name() == ".cog" { + return filepath.SkipDir + } + + if info.Name() == DockerIgnoreFilename { + return nil + } + + return fn(path, info, err) + }) +} + func readDockerIgnore(dockerIgnorePath string) ([]string, error) { var patterns []string file, err := os.Open(dockerIgnorePath) diff --git a/pkg/dockerignore/dockerignore_test.go b/pkg/dockerignore/dockerignore_test.go new file mode 100644 index 0000000000..64ef1b9578 --- /dev/null +++ b/pkg/dockerignore/dockerignore_test.go @@ -0,0 +1,56 @@ +package dockerignore + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWalk(t *testing.T) { + dir := t.TempDir() + + predictOtherPyFilename := "predict_other.py" + predictOtherPyFilepath := filepath.Join(dir, predictOtherPyFilename) + predictOtherPyHandle, err := os.Create(predictOtherPyFilepath) + require.NoError(t, err) + predictOtherPyHandle.WriteString("import cog") + + dockerIgnorePath := filepath.Join(dir, ".dockerignore") + dockerIgnoreHandle, err := os.Create(dockerIgnorePath) + require.NoError(t, err) + dockerIgnoreHandle.WriteString(predictOtherPyFilename) + + predictPyFilename := "predict.py" + predictPyFilepath := filepath.Join(dir, predictPyFilename) + predictPyHandle, err := os.Create(predictPyFilepath) + require.NoError(t, err) + predictPyHandle.WriteString("import cog") + + matcher, err := CreateMatcher(dir) + require.NoError(t, err) + + foundFiles := []string{} + err = Walk(dir, matcher, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(dir, path) + if err != nil { + return err + } + + foundFiles = append(foundFiles, relPath) + + return nil + }) + require.NoError(t, err) + + require.Equal(t, []string{predictPyFilename}, foundFiles) +} diff --git a/pkg/web/client.go b/pkg/web/client.go index bc10f6a0d4..7ebd9883fb 100644 --- a/pkg/web/client.go +++ b/pkg/web/client.go @@ -6,6 +6,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "mime/multipart" "net/http" "net/url" "os" @@ -13,6 +15,7 @@ import ( "strings" "time" + "github.com/docker/docker/api/types/image" "github.com/replicate/go/types" "golang.org/x/sync/errgroup" @@ -204,16 +207,66 @@ func (c *Client) PostNewVersion(ctx context.Context, image string, weights []Fil return nil } +func (c *Client) PostNewPipeline(ctx context.Context, image string, tarball *bytes.Buffer) error { + // Fetch manifest + manifest, err := c.dockerCommand.Inspect(ctx, image) + if err != nil { + return util.WrapError(err, "failed to inspect docker image") + } + + // Create form data body + body := new(bytes.Buffer) + mp := multipart.NewWriter(body) + defer mp.Close() + err = mp.WriteField("openapi_schema", manifest.Config.Labels[command.CogOpenAPISchemaLabelKey]) + if err != nil { + return err + } + part, err := mp.CreateFormFile("source_archive", "source.tar") + if err != nil { + return err + } + _, err = io.Copy(part, bytes.NewReader(tarball.Bytes())) + if err != nil { + return err + } + + // Create version URL + versionUrl, err := newVersionURL(image) + if err != nil { + return err + } + + // Create a new request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, versionUrl.String(), bytes.NewReader(body.Bytes())) + if err != nil { + return err + } + req.Header.Set("Content-Type", mp.FormDataContentType()) + + // Make the request + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode)) + } + + return nil +} + func (c *Client) versionFromManifest(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) (*Version, error) { manifest, err := c.dockerCommand.Inspect(ctx, image) if err != nil { return nil, util.WrapError(err, "failed to inspect docker image") } - var cogConfig config.Config - err = json.Unmarshal([]byte(manifest.Config.Labels[command.CogConfigLabelKey]), &cogConfig) + cogConfig, err := readCogConfig(manifest) if err != nil { - return nil, util.WrapError(err, "failed to get cog config from docker image") + return nil, err } var openAPISchema map[string]any @@ -311,7 +364,7 @@ func (c *Client) versionFromManifest(ctx context.Context, image string, weights version := Version{ Annotations: manifest.Config.Labels, - CogConfig: cogConfig, + CogConfig: *cogConfig, CogVersion: manifest.Config.Labels[command.CogVersionLabelKey], OpenAPISchema: openAPISchema, RuntimeConfig: runtimeConfig, @@ -326,61 +379,6 @@ func (c *Client) versionFromManifest(ctx context.Context, image string, weights return &version, nil } -func newVersionURL(image string) (url.URL, error) { - imageComponents := strings.Split(image, "/") - newVersionUrl := webBaseURL() - if len(imageComponents) != 3 { - return newVersionUrl, ErrorBadRegistryURL - } - if imageComponents[0] != global.ReplicateRegistryHost { - return newVersionUrl, ErrorBadRegistryHost - } - newVersionUrl.Path = strings.Join([]string{"", "api", "models", imageComponents[1], imageComponents[2], "versions"}, "/") - return newVersionUrl, nil -} - -func webBaseURL() url.URL { - return url.URL{ - Scheme: env.SchemeFromEnvironment(), - Host: HostFromEnvironment(), - } -} - -func stripCodeFromStub(cogConfig config.Config, isPredict bool) (string, error) { - var stubComponents []string - if isPredict { - stubComponents = strings.Split(cogConfig.Predict, ":") - } else { - stubComponents = strings.Split(cogConfig.Train, ":") - } - - if len(stubComponents) < 2 { - return "", nil - } - - codeFile := stubComponents[0] - - b, err := os.ReadFile(codeFile) - if err != nil { - return "", err - } - - // TODO: We should attempt to strip the code here, in python this is done like so: - // from cog.code_xforms import strip_model_source_code - // code = strip_model_source_code( - // util.read_file(os.path.join(fs, 'src', base_file)), - // [base_class], - // ['predict', 'train'], - // ) - // Currently the behavior of the code strip attempts to strip, and if it can't it - // loads the whole file in. Here we just load the whole file in. - // We should figure out a way to call cog python from here to fulfill this. - // It could be a good idea to do this in the layer functions where we do pip freeze - // et al. - - return string(b), nil -} - func (c *Client) InitiateAndDoFileChallenge(ctx context.Context, weights []File, files []File) ([]FileChallengeAnswer, error) { var challengeAnswers []FileChallengeAnswer @@ -458,3 +456,89 @@ func (c *Client) doSingleFileChallenge(ctx context.Context, file File, fileType ChallengeID: challenge.ID, }, nil } + +func newVersionURL(image string) (url.URL, error) { + imageComponents := strings.Split(image, "/") + newVersionUrl := webBaseURL() + if len(imageComponents) != 3 { + return newVersionUrl, ErrorBadRegistryURL + } + if imageComponents[0] != global.ReplicateRegistryHost { + return newVersionUrl, ErrorBadRegistryHost + } + newVersionUrl.Path = strings.Join([]string{"", "api", "models", imageComponents[1], imageComponents[2], "versions"}, "/") + return newVersionUrl, nil +} + +func webBaseURL() url.URL { + return url.URL{ + Scheme: env.SchemeFromEnvironment(), + Host: HostFromEnvironment(), + } +} + +func codeFileName(cogConfig *config.Config, isPredict bool) (string, error) { + var stubComponents []string + if isPredict { + if cogConfig.Predict == "" { + return "", nil + } + stubComponents = strings.Split(cogConfig.Predict, ":") + } else { + if cogConfig.Train == "" { + return "", nil + } + stubComponents = strings.Split(cogConfig.Train, ":") + } + + if len(stubComponents) < 2 { + return "", errors.New("Code stub components has less than 2 entries.") + } + + return stubComponents[0], nil +} + +func readCode(cogConfig *config.Config, isPredict bool) (string, string, error) { + codeFile, err := codeFileName(cogConfig, isPredict) + if err != nil { + return "", codeFile, err + } + if codeFile == "" { + return "", "", nil + } + + b, err := os.ReadFile(codeFile) + if err != nil { + return "", codeFile, err + } + + return string(b), codeFile, nil +} + +func stripCodeFromStub(cogConfig *config.Config, isPredict bool) (string, error) { + // TODO: We should attempt to strip the code here, in python this is done like so: + // from cog.code_xforms import strip_model_source_code + // code = strip_model_source_code( + // util.read_file(os.path.join(fs, 'src', base_file)), + // [base_class], + // ['predict', 'train'], + // ) + // Currently the behavior of the code strip attempts to strip, and if it can't it + // loads the whole file in. Here we just load the whole file in. + // We should figure out a way to call cog python from here to fulfill this. + // It could be a good idea to do this in the layer functions where we do pip freeze + // et al. + + code, _, err := readCode(cogConfig, isPredict) + return code, err +} + +func readCogConfig(manifest *image.InspectResponse) (*config.Config, error) { + var cogConfig config.Config + err := json.Unmarshal([]byte(manifest.Config.Labels[command.CogConfigLabelKey]), &cogConfig) + if err != nil { + return nil, util.WrapError(err, "failed to get cog config from docker image") + } + + return &cogConfig, nil +} diff --git a/pkg/web/client_test.go b/pkg/web/client_test.go index 7fe4fcd7d5..5dfd51ccf8 100644 --- a/pkg/web/client_test.go +++ b/pkg/web/client_test.go @@ -1,6 +1,7 @@ package web import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -167,3 +168,33 @@ func TestDoFileChallenge(t *testing.T) { }, }) } + +func TestPostPipeline(t *testing.T) { + // Setup mock http server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + output := "{\"version\":\"user/test:53c740f17ce88a61c3da5b0c20e48fd48e2da537c3a1276dec63ab11fbad6bcb\"}" + w.WriteHeader(http.StatusCreated) + w.Write([]byte(output)) + })) + defer server.Close() + url, err := url.Parse(server.URL) + require.NoError(t, err) + t.Setenv(env.SchemeEnvVarName, url.Scheme) + t.Setenv(WebHostEnvVarName, url.Host) + + dir := t.TempDir() + + // Create mock predict + predictPyPath := filepath.Join(dir, "predict.py") + handle, err := os.Create(predictPyPath) + require.NoError(t, err) + handle.WriteString("import cog") + dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" + + // Setup mock command + command := dockertest.NewMockCommand() + + client := NewClient(command, http.DefaultClient) + err = client.PostNewPipeline(t.Context(), "r8.im/user/test", new(bytes.Buffer)) + require.NoError(t, err) +} diff --git a/pkg/weights/fast_weights.go b/pkg/weights/fast_weights.go index fa688f450b..7abe5bae55 100644 --- a/pkg/weights/fast_weights.go +++ b/pkg/weights/fast_weights.go @@ -63,24 +63,10 @@ func findFullWeights(folder string, weights []Weight, weightFile string) ([]Weig if err != nil { return weights, err } - err = filepath.Walk(folder, func(path string, info os.FileInfo, err error) error { + err = dockerignore.Walk(folder, matcher, func(path string, info os.FileInfo, err error) error { if err != nil { return err } - - // We ignore files ignored by .dockerignore - if matcher != nil && matcher.MatchesPath(path) { - if info.IsDir() { - return filepath.SkipDir - } - return nil - } - - // Skip the .cog directory when looking for weights - this is where we store cog generated files - if info.IsDir() && info.Name() == ".cog" { - return filepath.SkipDir - } - relPath, err := filepath.Rel(folder, path) if err != nil { return err