Skip to content

Commit ecea1a2

Browse files
Local test command
Signed-off-by: andreasjansson <[email protected]>
1 parent 26a60eb commit ecea1a2

File tree

11 files changed

+273
-144
lines changed

11 files changed

+273
-144
lines changed

pkg/cli/benchmark.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ func runBenchmarkInference(mod *model.Model, modelDir string, results *Benchmark
102102

103103
logWriter := logger.NewConsoleLogger()
104104
bootStart := time.Now()
105-
deployment, err := servingPlatform.Deploy(mod, benchmarkTarget, logWriter)
105+
artifact, ok := mod.ArtifactFor(benchmarkTarget)
106+
if !ok {
107+
return fmt.Errorf("Target %s is not defined for model", benchmarkTarget)
108+
}
109+
deployment, err := servingPlatform.Deploy(artifact.URI, logWriter)
106110
if err != nil {
107111
return err
108112
}

pkg/cli/infer.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ func cmdInfer(cmd *cobra.Command, args []string) error {
6363
}
6464
logWriter := logger.NewConsoleLogger()
6565
// TODO(andreas): GPU inference
66-
deployment, err := servingPlatform.Deploy(mod, model.TargetDockerCPU, logWriter)
66+
artifact, ok := mod.ArtifactFor(model.TargetDockerCPU)
67+
if !ok {
68+
return fmt.Errorf("Target %s is not defined for model", model.TargetDockerCPU)
69+
}
70+
deployment, err := servingPlatform.Deploy(artifact.URI, logWriter)
6771
if err != nil {
6872
return err
6973
}

pkg/cli/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ func NewRootCommand() (*cobra.Command, error) {
3636

3737
rootCmd.AddCommand(
3838
newBuildCommand(),
39+
newTestCommand(),
3940
newDebugCommand(),
4041
newInferCommand(),
4142
newServerCommand(),

pkg/cli/test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package cli
2+
3+
import (
4+
"fmt"
5+
"path/filepath"
6+
"os"
7+
8+
"github.com/spf13/cobra"
9+
10+
"github.com/replicate/cog/pkg/docker"
11+
"github.com/replicate/cog/pkg/files"
12+
"github.com/replicate/cog/pkg/global"
13+
"github.com/replicate/cog/pkg/logger"
14+
"github.com/replicate/cog/pkg/model"
15+
"github.com/replicate/cog/pkg/serving"
16+
)
17+
18+
func newTestCommand() *cobra.Command {
19+
cmd := &cobra.Command{
20+
Use: "test",
21+
Short: "Test the model locally",
22+
RunE: Test,
23+
Args: cobra.NoArgs,
24+
}
25+
cmd.Flags().StringP("arch", "a", "cpu", "Test architecture")
26+
27+
return cmd
28+
}
29+
30+
func Test(cmd *cobra.Command, args []string) error {
31+
arch, err := cmd.Flags().GetString("arch")
32+
if err != nil {
33+
return err
34+
}
35+
projectDir, err := os.Getwd()
36+
if err != nil {
37+
return err
38+
}
39+
logWriter := logger.NewConsoleLogger()
40+
41+
configPath := filepath.Join(projectDir, global.ConfigFilename)
42+
exists, err := files.Exists(configPath)
43+
if err != nil {
44+
return err
45+
}
46+
if !exists {
47+
return fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, projectDir)
48+
}
49+
configRaw, err := os.ReadFile(filepath.Join(projectDir, global.ConfigFilename))
50+
if err != nil {
51+
return fmt.Errorf("Failed to read %s: %w", global.ConfigFilename, err)
52+
}
53+
config, err := model.ConfigFromYAML(configRaw)
54+
if err != nil {
55+
return err
56+
}
57+
if err := config.ValidateAndCompleteConfig(); err != nil {
58+
return err
59+
}
60+
archMap := map[string]bool{}
61+
for _, confArch := range config.Environment.Architectures {
62+
archMap[confArch] = true
63+
}
64+
if _, ok := archMap[arch]; !ok {
65+
return fmt.Errorf("Architecture %s is not defined for model", arch)
66+
}
67+
generator := &docker.DockerfileGenerator{Config: config, Arch: arch}
68+
dockerfileContents, err := generator.Generate()
69+
if err != nil {
70+
return fmt.Errorf("Failed to generate Dockerfile for %s: %w", arch, err)
71+
}
72+
dockerImageBuilder := docker.NewLocalImageBuilder("")
73+
servingPlatform, err := serving.NewLocalDockerPlatform()
74+
if err != nil {
75+
return err
76+
}
77+
tag, err := dockerImageBuilder.Build(projectDir, dockerfileContents, "", logWriter)
78+
if err != nil {
79+
return fmt.Errorf("Failed to build Docker image: %w", err)
80+
}
81+
82+
if _, err := serving.TestModel(servingPlatform, tag, config, projectDir, logWriter); err != nil {
83+
return err
84+
}
85+
86+
return nil
87+
}

pkg/docker/builder.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ import (
55
)
66

77
type ImageBuilder interface {
8-
BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error)
8+
Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error)
9+
Push(tag string, logWriter logger.Logger) error
910
}

pkg/docker/local_builder.go

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ import (
55
"fmt"
66
"os"
77
"os/exec"
8+
"path/filepath"
89
"regexp"
910
"strings"
1011

1112
"github.com/replicate/cog/pkg/console"
12-
1313
"github.com/replicate/cog/pkg/logger"
1414
"github.com/replicate/cog/pkg/shell"
1515
)
@@ -27,26 +27,16 @@ func NewLocalImageBuilder(registry string) *LocalImageBuilder {
2727
return &LocalImageBuilder{registry: registry}
2828
}
2929

30-
func (b *LocalImageBuilder) BuildAndPush(dir string, dockerfilePath string, name string, logWriter logger.Logger) (fullImageTag string, err error) {
31-
tag, err := b.build(dir, dockerfilePath, logWriter)
32-
if err != nil {
33-
return "", err
34-
}
35-
fullImageTag = fmt.Sprintf("%s/%s:%s", b.registry, name, tag)
36-
if err := b.tag(tag, fullImageTag, logWriter); err != nil {
37-
return "", err
38-
}
39-
if b.registry != noRegistry {
40-
if err := b.push(fullImageTag, logWriter); err != nil {
41-
return "", err
42-
}
43-
}
44-
return fullImageTag, nil
45-
}
46-
47-
func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter logger.Logger) (tag string, err error) {
30+
func (b *LocalImageBuilder) Build(dir string, dockerfileContents string, name string, logWriter logger.Logger) (tag string, err error) {
4831
console.Debugf("Building in %s", dir)
4932

33+
// TODO(andreas): pipe dockerfile contents to builder
34+
relDockerfilePath := "Dockerfile"
35+
dockerfilePath := filepath.Join(dir, relDockerfilePath)
36+
if err := os.WriteFile(dockerfilePath, []byte(dockerfileContents), 0644); err != nil {
37+
return "", fmt.Errorf("Failed to write Dockerfile")
38+
}
39+
5040
cmd := exec.Command(
5141
"docker", "build", ".",
5242
"--progress", "plain",
@@ -76,25 +66,41 @@ func (b *LocalImageBuilder) build(dir string, dockerfilePath string, logWriter l
7666

7767
dockerTag := <-tagChan
7868

69+
if err != nil {
70+
return "", err
71+
}
72+
7973
logWriter.Infof("Successfully built %s", dockerTag)
8074

81-
return dockerTag, err
75+
tag = dockerTag
76+
if name != "" {
77+
tag = fmt.Sprintf("%s/%s:%s", b.registry, name, dockerTag)
78+
if err := b.tag(dockerTag, tag, logWriter); err != nil {
79+
return "", err
80+
}
81+
}
82+
83+
return tag, nil
8284
}
8385

84-
func (b *LocalImageBuilder) tag(tag string, fullImageTag string, logWriter logger.Logger) error {
85-
console.Debugf("Tagging %s as %s", tag, fullImageTag)
86+
func (b *LocalImageBuilder) tag(dockerTag string, tag string, logWriter logger.Logger) error {
87+
console.Debugf("Tagging %s as %s", dockerTag, tag)
8688

87-
cmd := exec.Command("docker", "tag", tag, fullImageTag)
89+
cmd := exec.Command("docker", "tag", dockerTag, tag)
8890
cmd.Env = os.Environ()
8991
if _, err := cmd.Output(); err != nil {
9092
ee := err.(*exec.ExitError)
9193
stderr := string(ee.Stderr)
92-
return fmt.Errorf("Failed to tag %s as %s, got error: %s", tag, fullImageTag, stderr)
94+
return fmt.Errorf("Failed to tag %s as %s, got error: %s", dockerTag, tag, stderr)
9395
}
9496
return nil
9597
}
9698

97-
func (b *LocalImageBuilder) push(tag string, logWriter logger.Logger) error {
99+
func (b *LocalImageBuilder) Push(tag string, logWriter logger.Logger) error {
100+
if b.registry == noRegistry {
101+
return nil
102+
}
103+
98104
logWriter.Infof("Pushing %s to registry", tag)
99105

100106
args := []string{"push", tag}

pkg/model/compatibility.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
_ "embed"
55
"encoding/json"
66
"fmt"
7+
"runtime"
78
"sort"
89
"strings"
910

@@ -251,9 +252,10 @@ func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err
251252
func torchCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) {
252253
for _, compat := range TorchCompatibilityMatrix {
253254
if compat.TorchVersion() == ver && compat.CUDA == nil {
254-
return "torch", compat.Torch, compat.IndexURL, nil
255+
return "torch", torchStripCPUSuffixForM1(compat.Torch), compat.IndexURL, nil
255256
}
256257
}
258+
257259
return "", "", "", fmt.Errorf("No matching Torch CPU package for version %s", ver)
258260
}
259261

@@ -297,7 +299,7 @@ func torchGPUPackage(ver string, cuda string) (name string, cpuVersion string, i
297299
func torchvisionCPUPackage(ver string) (name string, cpuVersion string, indexURL string, err error) {
298300
for _, compat := range TorchCompatibilityMatrix {
299301
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
300-
return "torchvision", compat.Torchvision, compat.IndexURL, nil
302+
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision), compat.IndexURL, nil
301303
}
302304
}
303305
return "", "", "", fmt.Errorf("No matching torchvision CPU package for version %s", ver)
@@ -339,3 +341,13 @@ func torchvisionGPUPackage(ver string, cuda string) (name string, cpuVersion str
339341

340342
return "torchvision", latest.Torchvision, latest.IndexURL, nil
341343
}
344+
345+
// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
346+
// TODO(andreas): clean up this hack by actually parsing the torch_stable.html list in the generator
347+
func torchStripCPUSuffixForM1(version string) string {
348+
// TODO(andreas): clean up this hack
349+
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
350+
return strings.ReplaceAll(version, "+cpu", "")
351+
}
352+
return version
353+
}

0 commit comments

Comments
 (0)