Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/cli/benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ func runBenchmarkInference(mod *model.Model, modelDir string, results *Benchmark

logWriter := logger.NewConsoleLogger()
bootStart := time.Now()
deployment, err := servingPlatform.Deploy(mod, benchmarkTarget, logWriter)
artifact, ok := mod.ArtifactFor(benchmarkTarget)
if !ok {
return fmt.Errorf("Target %s is not defined for model", benchmarkTarget)
}
deployment, err := servingPlatform.Deploy(artifact.URI, logWriter)
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/cli/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ func cmdInfer(cmd *cobra.Command, args []string) error {
}
logWriter := logger.NewConsoleLogger()
// TODO(andreas): GPU inference
deployment, err := servingPlatform.Deploy(mod, model.TargetDockerCPU, logWriter)
artifact, ok := mod.ArtifactFor(model.TargetDockerCPU)
if !ok {
return fmt.Errorf("Target %s is not defined for model", model.TargetDockerCPU)
}
deployment, err := servingPlatform.Deploy(artifact.URI, logWriter)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions pkg/cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func NewRootCommand() (*cobra.Command, error) {

rootCmd.AddCommand(
newBuildCommand(),
newTestCommand(),
newDebugCommand(),
newInferCommand(),
newServerCommand(),
Expand Down
87 changes: 87 additions & 0 deletions pkg/cli/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package cli

import (
"fmt"
"path/filepath"
"os"

"github.com/spf13/cobra"

"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/files"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/logger"
"github.com/replicate/cog/pkg/model"
"github.com/replicate/cog/pkg/serving"
)

func newTestCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "test",
Short: "Test the model locally",
RunE: Test,
Args: cobra.NoArgs,
}
cmd.Flags().StringP("arch", "a", "cpu", "Test architecture")

return cmd
}

func Test(cmd *cobra.Command, args []string) error {
arch, err := cmd.Flags().GetString("arch")
if err != nil {
return err
}
projectDir, err := os.Getwd()
if err != nil {
return err
}
logWriter := logger.NewConsoleLogger()

configPath := filepath.Join(projectDir, global.ConfigFilename)
exists, err := files.Exists(configPath)
if err != nil {
return err
}
if !exists {
return fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, projectDir)
}
configRaw, err := os.ReadFile(filepath.Join(projectDir, global.ConfigFilename))
if err != nil {
return fmt.Errorf("Failed to read %s: %w", global.ConfigFilename, err)
}
config, err := model.ConfigFromYAML(configRaw)
if err != nil {
return err
}
if err := config.ValidateAndCompleteConfig(); err != nil {
return err
}
archMap := map[string]bool{}
for _, confArch := range config.Environment.Architectures {
archMap[confArch] = true
}
if _, ok := archMap[arch]; !ok {
return fmt.Errorf("Architecture %s is not defined for model", arch)
}
generator := &docker.DockerfileGenerator{Config: config, Arch: arch}
dockerfileContents, err := generator.Generate()
if err != nil {
return fmt.Errorf("Failed to generate Dockerfile for %s: %w", arch, err)
}
dockerImageBuilder := docker.NewLocalImageBuilder("")
servingPlatform, err := serving.NewLocalDockerPlatform()
if err != nil {
return err
}
tag, err := dockerImageBuilder.Build(projectDir, dockerfileContents, "", logWriter)
if err != nil {
return fmt.Errorf("Failed to build Docker image: %w", err)
}

if _, err := serving.TestModel(servingPlatform, tag, config, projectDir, logWriter); err != nil {
return err
}

return nil
}
3 changes: 2 additions & 1 deletion pkg/docker/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ import (
)

type ImageBuilder interface {
BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error)
Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error)
Push(tag string, logWriter logger.Logger) error
}
56 changes: 31 additions & 25 deletions pkg/docker/local_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"

"github.com/replicate/cog/pkg/console"

"github.com/replicate/cog/pkg/logger"
"github.com/replicate/cog/pkg/shell"
)
Expand All @@ -27,26 +27,16 @@ func NewLocalImageBuilder(registry string) *LocalImageBuilder {
return &LocalImageBuilder{registry: registry}
}

func (b *LocalImageBuilder) BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error) {
tag, err := b.build(dir, dockerfilePath, logWriter)
if err != nil {
return "", err
}
fullImageTag = fmt.Sprintf("%s/%s:%s", b.registry, name, tag)
if err := b.tag(tag, fullImageTag, logWriter); err != nil {
return "", err
}
if b.registry != noRegistry {
if err := b.push(fullImageTag, logWriter); err != nil {
return "", err
}
}
return fullImageTag, nil
}

func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter logger.Logger) (tag string, err error) {
func (b *LocalImageBuilder) Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error) {
console.Debugf("Building in %s", dir)

// TODO(andreas): pipe dockerfile contents to builder
relDockerfilePath := "Dockerfile"
dockerfilePath := filepath.Join(dir, relDockerfilePath)
if err := os.WriteFile(dockerfilePath, []byte(dockerfileContents), 0644); err != nil {
return "", fmt.Errorf("Failed to write Dockerfile")
}

cmd := exec.Command(
"docker", "build", ".",
"--progress", "plain",
Expand Down Expand Up @@ -76,25 +66,41 @@ func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter l

dockerTag := <-tagChan

if err != nil {
return "", err
}

logWriter.Infof("Successfully built %s", dockerTag)

return dockerTag, err
tag = dockerTag
if name != "" {
tag = fmt.Sprintf("%s/%s:%s", b.registry, name, dockerTag)
if err := b.tag(dockerTag, tag, logWriter); err != nil {
return "", err
}
}

return tag, nil
}

func (b *LocalImageBuilder) tag(tag string, fullImageTag string, logWriter logger.Logger) error {
console.Debugf("Tagging %s as %s", tag, fullImageTag)
func (b *LocalImageBuilder) tag(dockerTag string, tag string, logWriter logger.Logger) error {
console.Debugf("Tagging %s as %s", dockerTag, tag)

cmd := exec.Command("docker", "tag", tag, fullImageTag)
cmd := exec.Command("docker", "tag", dockerTag, tag)
cmd.Env = os.Environ()
if _, err := cmd.Output(); err != nil {
ee := err.(*exec.ExitError)
stderr := string(ee.Stderr)
return fmt.Errorf("Failed to tag %s as %s, got error: %s", tag, fullImageTag, stderr)
return fmt.Errorf("Failed to tag %s as %s, got error: %s", dockerTag, tag, stderr)
}
return nil
}

func (b *LocalImageBuilder) push(tag string, logWriter logger.Logger) error {
func (b *LocalImageBuilder) Push(tag string, logWriter logger.Logger) error {
if b.registry == noRegistry {
return nil
}

logWriter.Infof("Pushing %s to registry", tag)

args := []string{"push", tag}
Expand Down
16 changes: 14 additions & 2 deletions pkg/model/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
_ "embed"
"encoding/json"
"fmt"
"runtime"
"sort"
"strings"

Expand Down Expand Up @@ -251,9 +252,10 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err
func torchCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchVersion() == ver && compat.CUDA == nil {
return "torch", compat.Torch, compat.IndexURL, nil
return "torch", torchStripCPUSuffixForM1(compat.Torch), compat.IndexURL, nil
}
}

return "", "", "", fmt.Errorf("No matching Torch CPU package for version %s", ver)
}

Expand Down Expand Up @@ -297,7 +299,7 @@ func torchGPUPackage(ver string, cuda string) (name string, cpuVersion string, i
func torchvisionCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) {
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
return "torchvision", compat.Torchvision, compat.IndexURL, nil
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision), compat.IndexURL, nil
}
}
return "", "", "", fmt.Errorf("No matching torchvision CPU package for version %s", ver)
Expand Down Expand Up @@ -339,3 +341,13 @@ func torchvisionGPUPackage(ver string, cuda string) (name string, cpuVersion str

return "torchvision", latest.Torchvision, latest.IndexURL, nil
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
// TODO(andreas): clean up this hack by actually parsing the torch_stable.html list in the generator
func torchStripCPUSuffixForM1(version string) string {
// TODO(andreas): clean up this hack
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return strings.ReplaceAll(version, "+cpu", "")
}
return version
}
Loading