diff --git a/pkg/cli/benchmark.go b/pkg/cli/benchmark.go index 5005378650..02be597c88 100644 --- a/pkg/cli/benchmark.go +++ b/pkg/cli/benchmark.go @@ -102,7 +102,11 @@ func runBenchmarkInference(mod *model.Model, modelDir string, results *Benchmark logWriter := logger.NewConsoleLogger() bootStart := time.Now() - deployment, err := servingPlatform.Deploy(mod, benchmarkTarget, logWriter) + artifact, ok := mod.ArtifactFor(benchmarkTarget) + if !ok { + return fmt.Errorf("Target %s is not defined for model", benchmarkTarget) + } + deployment, err := servingPlatform.Deploy(artifact.URI, logWriter) if err != nil { return err } diff --git a/pkg/cli/infer.go b/pkg/cli/infer.go index 68a7aff188..9e3c283fbd 100644 --- a/pkg/cli/infer.go +++ b/pkg/cli/infer.go @@ -63,7 +63,11 @@ func cmdInfer(cmd *cobra.Command, args []string) error { } logWriter := logger.NewConsoleLogger() // TODO(andreas): GPU inference - deployment, err := servingPlatform.Deploy(mod, model.TargetDockerCPU, logWriter) + artifact, ok := mod.ArtifactFor(model.TargetDockerCPU) + if !ok { + return fmt.Errorf("Target %s is not defined for model", model.TargetDockerCPU) + } + deployment, err := servingPlatform.Deploy(artifact.URI, logWriter) if err != nil { return err } diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 61196796c0..8c3a94c401 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -36,6 +36,7 @@ func NewRootCommand() (*cobra.Command, error) { rootCmd.AddCommand( newBuildCommand(), + newTestCommand(), newDebugCommand(), newInferCommand(), newServerCommand(), diff --git a/pkg/cli/test.go b/pkg/cli/test.go new file mode 100644 index 0000000000..03f7d4adf8 --- /dev/null +++ b/pkg/cli/test.go @@ -0,0 +1,87 @@ +package cli + +import ( + "fmt" + "path/filepath" + "os" + + "github.com/spf13/cobra" + + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/files" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/logger" + "github.com/replicate/cog/pkg/model" + "github.com/replicate/cog/pkg/serving" +) + +func newTestCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "test", + Short: "Test the model locally", + RunE: Test, + Args: cobra.NoArgs, + } + cmd.Flags().StringP("arch", "a", "cpu", "Test architecture") + + return cmd +} + +func Test(cmd *cobra.Command, args []string) error { + arch, err := cmd.Flags().GetString("arch") + if err != nil { + return err + } + projectDir, err := os.Getwd() + if err != nil { + return err + } + logWriter := logger.NewConsoleLogger() + + configPath := filepath.Join(projectDir, global.ConfigFilename) + exists, err := files.Exists(configPath) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, projectDir) + } + configRaw, err := os.ReadFile(filepath.Join(projectDir, global.ConfigFilename)) + if err != nil { + return fmt.Errorf("Failed to read %s: %w", global.ConfigFilename, err) + } + config, err := model.ConfigFromYAML(configRaw) + if err != nil { + return err + } + if err := config.ValidateAndCompleteConfig(); err != nil { + return err + } + archMap := map[string]bool{} + for _, confArch := range config.Environment.Architectures { + archMap[confArch] = true + } + if _, ok := archMap[arch]; !ok { + return fmt.Errorf("Architecture %s is not defined for model", arch) + } + generator := &docker.DockerfileGenerator{Config: config, Arch: arch} + dockerfileContents, err := generator.Generate() + if err != nil { + return fmt.Errorf("Failed to generate Dockerfile for %s: %w", arch, err) + } + dockerImageBuilder := docker.NewLocalImageBuilder("") + servingPlatform, err := serving.NewLocalDockerPlatform() + if err != nil { + return err + } + tag, err := dockerImageBuilder.Build(projectDir, dockerfileContents, "", logWriter) + if err != nil { + return fmt.Errorf("Failed to build Docker image: %w", err) + } + + if _, err := serving.TestModel(servingPlatform, tag, config, projectDir, logWriter); err != nil { + return err + } + + return nil +} diff --git a/pkg/docker/builder.go b/pkg/docker/builder.go index 2e03bf5119..f2eb9a83d1 100644 --- a/pkg/docker/builder.go +++ b/pkg/docker/builder.go @@ -5,5 +5,6 @@ import ( ) type ImageBuilder interface { - BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error) + Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error) + Push(tag string, logWriter logger.Logger) error } diff --git a/pkg/docker/local_builder.go b/pkg/docker/local_builder.go index 3386013ecb..b80c859ae4 100644 --- a/pkg/docker/local_builder.go +++ b/pkg/docker/local_builder.go @@ -5,11 +5,11 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "regexp" "strings" "github.com/replicate/cog/pkg/console" - "github.com/replicate/cog/pkg/logger" "github.com/replicate/cog/pkg/shell" ) @@ -27,26 +27,16 @@ func NewLocalImageBuilder(registry string) *LocalImageBuilder { return &LocalImageBuilder{registry: registry} } -func (b *LocalImageBuilder) BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error) { - tag, err := b.build(dir, dockerfilePath, logWriter) - if err != nil { - return "", err - } - fullImageTag = fmt.Sprintf("%s/%s:%s", b.registry, name, tag) - if err := b.tag(tag, fullImageTag, logWriter); err != nil { - return "", err - } - if b.registry != noRegistry { - if err := b.push(fullImageTag, logWriter); err != nil { - return "", err - } - } - return fullImageTag, nil -} - -func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter logger.Logger) (tag string, err error) { +func (b *LocalImageBuilder) Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error) { console.Debugf("Building in %s", dir) + // TODO(andreas): pipe dockerfile contents to builder + relDockerfilePath := "Dockerfile" + dockerfilePath := filepath.Join(dir, relDockerfilePath) + if err := os.WriteFile(dockerfilePath, []byte(dockerfileContents), 0644); err != nil { + return "", fmt.Errorf("Failed to write Dockerfile") + } + cmd := exec.Command( "docker", "build", ".", "--progress", "plain", @@ -76,25 +66,41 @@ func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter l dockerTag := <-tagChan + if err != nil { + return "", err + } + logWriter.Infof("Successfully built %s", dockerTag) - return dockerTag, err + tag = dockerTag + if name != "" { + tag = fmt.Sprintf("%s/%s:%s", b.registry, name, dockerTag) + if err := b.tag(dockerTag, tag, logWriter); err != nil { + return "", err + } + } + + return tag, nil } -func (b *LocalImageBuilder) tag(tag string, fullImageTag string, logWriter logger.Logger) error { - console.Debugf("Tagging %s as %s", tag, fullImageTag) +func (b *LocalImageBuilder) tag(dockerTag string, tag string, logWriter logger.Logger) error { + console.Debugf("Tagging %s as %s", dockerTag, tag) - cmd := exec.Command("docker", "tag", tag, fullImageTag) + cmd := exec.Command("docker", "tag", dockerTag, tag) cmd.Env = os.Environ() if _, err := cmd.Output(); err != nil { ee := err.(*exec.ExitError) stderr := string(ee.Stderr) - return fmt.Errorf("Failed to tag %s as %s, got error: %s", tag, fullImageTag, stderr) + return fmt.Errorf("Failed to tag %s as %s, got error: %s", dockerTag, tag, stderr) } return nil } -func (b *LocalImageBuilder) push(tag string, logWriter logger.Logger) error { +func (b *LocalImageBuilder) Push(tag string, logWriter logger.Logger) error { + if b.registry == noRegistry { + return nil + } + logWriter.Infof("Pushing %s to registry", tag) args := []string{"push", tag} diff --git a/pkg/model/compatibility.go b/pkg/model/compatibility.go index ce28b2b2e0..c9b1588417 100644 --- a/pkg/model/compatibility.go +++ b/pkg/model/compatibility.go @@ -4,6 +4,7 @@ import ( _ "embed" "encoding/json" "fmt" + "runtime" "sort" "strings" @@ -251,9 +252,10 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err func torchCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) { for _, compat := range TorchCompatibilityMatrix { if compat.TorchVersion() == ver && compat.CUDA == nil { - return "torch", compat.Torch, compat.IndexURL, nil + return "torch", torchStripCPUSuffixForM1(compat.Torch), compat.IndexURL, nil } } + return "", "", "", fmt.Errorf("No matching Torch CPU package for version %s", ver) } @@ -297,7 +299,7 @@ func torchGPUPackage(ver string, cuda string) (name string, cpuVersion string, i func torchvisionCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) { for _, compat := range TorchCompatibilityMatrix { if compat.TorchvisionVersion() == ver && compat.CUDA == nil { - return "torchvision", compat.Torchvision, compat.IndexURL, nil + return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision), compat.IndexURL, nil } } return "", "", "", fmt.Errorf("No matching torchvision CPU package for version %s", ver) @@ -339,3 +341,13 @@ func torchvisionGPUPackage(ver string, cuda string) (name string, cpuVersion str return "torchvision", latest.Torchvision, latest.IndexURL, nil } + +// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html +// TODO(andreas): clean up this hack by actually parsing the torch_stable.html list in the generator +func torchStripCPUSuffixForM1(version string) string { + // TODO(andreas): clean up this hack + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + return strings.ReplaceAll(version, "+cpu", "") + } + return version +} diff --git a/pkg/server/build.go b/pkg/server/build.go index 7ef3ce805b..33c997c4aa 100644 --- a/pkg/server/build.go +++ b/pkg/server/build.go @@ -10,7 +10,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "time" "github.com/mholt/archiver/v3" @@ -107,88 +106,39 @@ func (s *Server) ReceiveModel(r *http.Request, logWriter logger.Logger, user str Created: time.Now(), } - runArgs, err := s.testModel(mod, dir, logWriter) - if err != nil { - // TODO(andreas): return other response than 500 if validation fails - return nil, err - } - mod.RunArguments = runArgs - - logWriter.WriteStatus("Inserting into database") - if err := s.db.InsertModel(user, name, id, mod); err != nil { - return nil, fmt.Errorf("Failed to insert into database: %w", err) - } - - if err := s.runWebHooks(user, name, mod, dir, logWriter); err != nil { - return nil, err - } - - return mod, nil -} - -func (s *Server) testModel(mod *model.Model, dir string, logWriter logger.Logger) (map[string]*model.RunArgument, error) { - logWriter.WriteStatus("Testing model") - target := model.TargetDockerCPU + testTarget := model.TargetDockerCPU if _, ok := mod.ArtifactFor(model.TargetDockerCPU); !ok { if _, ok := mod.ArtifactFor(model.TargetDockerGPU); ok { - target = model.TargetDockerGPU + testTarget = model.TargetDockerGPU } else { return nil, fmt.Errorf("Model has neither CPU or GPU target") } } - - deployment, err := s.servingPlatform.Deploy(mod, target, logWriter) + testArtifact, ok := mod.ArtifactFor(testTarget) + if !ok { + return nil, fmt.Errorf("Model has no %s target", testTarget) + } + runArgs, err := serving.TestModel(s.servingPlatform, testArtifact.URI, mod.Config, dir, logWriter) if err != nil { + // TODO(andreas): return other response than 500 if validation fails return nil, err } - defer deployment.Undeploy() + mod.RunArguments = runArgs - help, err := deployment.Help(logWriter) - if err != nil { + if err := s.pushDockerImages(dir, mod, logWriter); err != nil { return nil, err } - for _, example := range mod.Config.Examples { - if err := validateServingExampleInput(help, example.Input); err != nil { - return nil, fmt.Errorf("Example input doesn't match run arguments: %w", err) - } - var expectedOutput []byte = nil - outputIsFile := false - if example.Output != "" { - if strings.HasPrefix(example.Output, "@") { - outputIsFile = true - expectedOutput, err = os.ReadFile(filepath.Join(dir, example.Output[1:])) - if err != nil { - return nil, fmt.Errorf("Failed to read example output file %s: %w", example.Output[1:], err) - } - } else { - expectedOutput = []byte(example.Output) - } - } - - input := serving.NewExampleWithBaseDir(example.Input, dir) + logWriter.WriteStatus("Inserting into database") + if err := s.db.InsertModel(user, name, id, mod); err != nil { + return nil, fmt.Errorf("Failed to insert into database: %w", err) + } - result, err := deployment.RunInference(input, logWriter) - if err != nil { - return nil, err - } - output := result.Values["output"] - outputBytes, err := io.ReadAll(output.Buffer) - if err != nil { - return nil, fmt.Errorf("Failed to read output: %w", err) - } - logWriter.Infof(fmt.Sprintf("Inference result length: %d, mime type: %s", len(outputBytes), output.MimeType)) - if expectedOutput != nil { - if outputIsFile && !bytes.Equal(expectedOutput, outputBytes) { - return nil, fmt.Errorf("Output file contents doesn't match expected %s", example.Output[1:]) - } else if !outputIsFile && strings.TrimSpace(string(outputBytes)) != strings.TrimSpace(example.Output) { - // TODO(andreas): are there cases where space is significant? - return nil, fmt.Errorf("Output %s doesn't match expected: %s", string(outputBytes), example.Output) - } - } + if err := s.runWebHooks(user, name, mod, dir, logWriter); err != nil { + return nil, err } - return help.Arguments, nil + return mod, nil } // TODO(andreas): include user in docker image name? @@ -196,7 +146,6 @@ func (s *Server) buildDockerImages(dir string, config *model.Config, name string // TODO(andreas): parallelize artifacts := []*model.Artifact{} for _, arch := range config.Environment.Architectures { - logWriter.WriteStatus("Building %s image", arch) generator := &docker.DockerfileGenerator{Config: config, Arch: arch} @@ -204,14 +153,7 @@ func (s *Server) buildDockerImages(dir string, config *model.Config, name string if err != nil { return nil, fmt.Errorf("Failed to generate Dockerfile for %s: %w", arch, err) } - // TODO(andreas): pipe dockerfile contents to builder - relDockerfilePath := "Dockerfile." + arch - dockerfilePath := filepath.Join(dir, relDockerfilePath) - if err := os.WriteFile(dockerfilePath, []byte(dockerfileContents), 0644); err != nil { - return nil, fmt.Errorf("Failed to write Dockerfile for %s", arch) - } - - tag, err := s.dockerImageBuilder.BuildAndPush(dir, relDockerfilePath, name, logWriter) + tag, err := s.dockerImageBuilder.Build(dir, dockerfileContents, name, logWriter) if err != nil { return nil, fmt.Errorf("Failed to build Docker image: %w", err) } @@ -231,6 +173,15 @@ func (s *Server) buildDockerImages(dir string, config *model.Config, name string return artifacts, nil } +func (s *Server) pushDockerImages(dir string, model *model.Model, logWriter logger.Logger) error { + for _, artifact := range model.Artifacts { + if err := s.dockerImageBuilder.Push(artifact.URI, logWriter); err != nil { + return err + } + } + return nil +} + func (s *Server) GetCacheHashes(w http.ResponseWriter, r *http.Request) { user, name, _ := getRepoVars(r) console.Infof("Received cache-hashes request for %s/%s", user, name) @@ -258,34 +209,6 @@ func (s *Server) GetCacheHashes(w http.ResponseWriter, r *http.Request) { } } -func validateServingExampleInput(help *serving.HelpResponse, input map[string]string) error { - // TODO(andreas): validate types - missingNames := []string{} - extraneousNames := []string{} - - for name, arg := range help.Arguments { - if _, ok := input[name]; !ok && arg.Default == nil { - missingNames = append(missingNames, name) - } - } - for name := range input { - if _, ok := help.Arguments[name]; !ok { - extraneousNames = append(extraneousNames, name) - } - } - errParts := []string{} - if len(missingNames) > 0 { - errParts = append(errParts, "Missing arguments: "+strings.Join(missingNames, ", ")) - } - if len(extraneousNames) > 0 { - errParts = append(errParts, "Extraneous arguments: "+strings.Join(extraneousNames, ", ")) - } - if len(errParts) > 0 { - return fmt.Errorf(strings.Join(errParts, "; ")) - } - return nil -} - func computeID(dir string) (string, error) { hasher := sha1.New() err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error { diff --git a/pkg/serving/local.go b/pkg/serving/local.go index ca94e4c0e8..cfc3787d29 100644 --- a/pkg/serving/local.go +++ b/pkg/serving/local.go @@ -23,7 +23,6 @@ import ( "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/logger" - "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/shell" ) @@ -47,16 +46,10 @@ func NewLocalDockerPlatform() (*LocalDockerPlatform, error) { }, nil } -func (p *LocalDockerPlatform) Deploy(mod *model.Model, target string, logWriter logger.Logger) (Deployment, error) { +func (p *LocalDockerPlatform) Deploy(imageTag string, logWriter logger.Logger) (Deployment, error) { // TODO(andreas): output container logs - artifact, ok := mod.ArtifactFor(target) - if !ok { - return nil, fmt.Errorf("Model has no %s target", target) - } - imageTag := artifact.URI - - logWriter.Infof("Deploying container %s for target %s", imageTag, artifact.Target) + logWriter.Infof("Deploying container %s", imageTag) if !docker.Exists(imageTag, logWriter) { if err := docker.Pull(imageTag, logWriter); err != nil { diff --git a/pkg/serving/platform.go b/pkg/serving/platform.go index 853946e85b..fae0583ab8 100644 --- a/pkg/serving/platform.go +++ b/pkg/serving/platform.go @@ -10,7 +10,7 @@ import ( ) type Platform interface { - Deploy(mod *model.Model, target string, logWriter logger.Logger) (Deployment, error) + Deploy(imageTag string, logWriter logger.Logger) (Deployment, error) } type Deployment interface { diff --git a/pkg/serving/test.go b/pkg/serving/test.go new file mode 100644 index 0000000000..3ee72e58cd --- /dev/null +++ b/pkg/serving/test.go @@ -0,0 +1,98 @@ +package serving + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/replicate/cog/pkg/logger" + "github.com/replicate/cog/pkg/model" +) + +func TestModel(servingPlatform Platform, imageTag string, config *model.Config, dir string, logWriter logger.Logger) (map[string]*model.RunArgument, error) { + logWriter.WriteStatus("Testing model") + + deployment, err := servingPlatform.Deploy(imageTag, logWriter) + if err != nil { + return nil, err + } + defer deployment.Undeploy() + + help, err := deployment.Help(logWriter) + if err != nil { + return nil, err + } + + for _, example := range config.Examples { + if err := validateServingExampleInput(help, example.Input); err != nil { + return nil, fmt.Errorf("Example input doesn't match run arguments: %w", err) + } + var expectedOutput []byte = nil + outputIsFile := false + if example.Output != "" { + if strings.HasPrefix(example.Output, "@") { + outputIsFile = true + expectedOutput, err = os.ReadFile(filepath.Join(dir, example.Output[1:])) + if err != nil { + return nil, fmt.Errorf("Failed to read example output file %s: %w", example.Output[1:], err) + } + } else { + expectedOutput = []byte(example.Output) + } + } + + input := NewExampleWithBaseDir(example.Input, dir) + + result, err := deployment.RunInference(input, logWriter) + if err != nil { + return nil, err + } + output := result.Values["output"] + outputBytes, err := io.ReadAll(output.Buffer) + if err != nil { + return nil, fmt.Errorf("Failed to read output: %w", err) + } + logWriter.Infof(fmt.Sprintf("Inference result length: %d, mime type: %s", len(outputBytes), output.MimeType)) + if expectedOutput != nil { + if outputIsFile && !bytes.Equal(expectedOutput, outputBytes) { + return nil, fmt.Errorf("Output file contents doesn't match expected %s", example.Output[1:]) + } else if !outputIsFile && strings.TrimSpace(string(outputBytes)) != strings.TrimSpace(example.Output) { + // TODO(andreas): are there cases where space is significant? + return nil, fmt.Errorf("Output %s doesn't match expected: %s", string(outputBytes), example.Output) + } + } + } + + return help.Arguments, nil +} + +func validateServingExampleInput(help *HelpResponse, input map[string]string) error { + // TODO(andreas): validate types + missingNames := []string{} + extraneousNames := []string{} + + for name, arg := range help.Arguments { + if _, ok := input[name]; !ok && arg.Default == nil { + missingNames = append(missingNames, name) + } + } + for name := range input { + if _, ok := help.Arguments[name]; !ok { + extraneousNames = append(extraneousNames, name) + } + } + errParts := []string{} + if len(missingNames) > 0 { + errParts = append(errParts, "Missing arguments: "+strings.Join(missingNames, ", ")) + } + if len(extraneousNames) > 0 { + errParts = append(errParts, "Extraneous arguments: "+strings.Join(extraneousNames, ", ")) + } + if len(errParts) > 0 { + return fmt.Errorf(strings.Join(errParts, "; ")) + } + return nil +}