Skip to content

Commit 5ea2ce1

Browse files
committed
Add support for local predictions
Signed-off-by: Ben Firshman <[email protected]>
1 parent 7eff4ab commit 5ea2ce1

File tree

4 files changed

+135
-29
lines changed

4 files changed

+135
-29
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import subprocess
2+
3+
4+
def test_predict(tmpdir_factory):
5+
tmpdir = tmpdir_factory.mktemp("project")
6+
with open(tmpdir / "predict.py", "w") as f:
7+
f.write(
8+
"""
9+
import cog
10+
11+
class Model(cog.Model):
12+
def setup(self):
13+
pass
14+
15+
@cog.input("input", type=str)
16+
def predict(self, input):
17+
return "hello " + input
18+
"""
19+
)
20+
with open(tmpdir / "cog.yaml", "w") as f:
21+
cog_yaml = """
22+
model: "predict.py:Model"
23+
environment:
24+
python: "3.8"
25+
"""
26+
f.write(cog_yaml)
27+
28+
result = subprocess.run(
29+
["cog", "predict", "-i", "world"], cwd=tmpdir, check=True, capture_output=True
30+
)
31+
assert b"hello world" in result.stdout

pkg/cli/predict.go

Lines changed: 88 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ import (
1313
"github.com/spf13/cobra"
1414

1515
"github.com/replicate/cog/pkg/client"
16+
"github.com/replicate/cog/pkg/docker"
1617
"github.com/replicate/cog/pkg/logger"
1718
"github.com/replicate/cog/pkg/model"
1819
"github.com/replicate/cog/pkg/serving"
1920
"github.com/replicate/cog/pkg/util/console"
2021
"github.com/replicate/cog/pkg/util/mime"
2122
"github.com/replicate/cog/pkg/util/slices"
23+
"github.com/replicate/cog/pkg/util/terminal"
2224
)
2325

2426
var (
@@ -29,10 +31,15 @@ var (
2931

3032
func newPredictCommand() *cobra.Command {
3133
cmd := &cobra.Command{
32-
Use: "predict <id>",
33-
Short: "Run a single prediction against a version of a model",
34+
Use: "predict [version id]",
35+
Short: "Run a prediction on a version",
36+
Long: `Run a prediction on a version.
37+
38+
If 'version id' is passed, it will run the prediction on that version of the
39+
model. Otherwise, it will build the model in the current directory and run
40+
the prediction on that.`,
3441
RunE: cmdPredict,
35-
Args: cobra.MinimumNArgs(1),
42+
Args: cobra.MaximumNArgs(1),
3643
SuggestFor: []string{"infer"},
3744
}
3845
addModelFlag(cmd)
@@ -48,59 +55,114 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
4855
return fmt.Errorf("--arch must be either 'cpu' or 'gpu'")
4956
}
5057

51-
mod, err := getModel()
52-
if err != nil {
53-
return err
54-
}
58+
ui := terminal.ConsoleUI(context.Background())
59+
defer ui.Close()
5560

56-
id := args[0]
61+
useGPU := predictArch == "gpu"
62+
dockerImageName := ""
5763

58-
client := client.NewClient()
59-
fmt.Println("Loading package", id)
60-
version, err := client.GetVersion(mod, id)
61-
if err != nil {
62-
return err
63-
}
64-
// TODO(bfirsh): differentiate between failed builds and in-progress builds, and probably block here if there is an in-progress build
65-
image := model.ImageForArch(version.Images, predictArch)
66-
if image == nil {
67-
return fmt.Errorf("No %s image has been built for %s:%s", predictArch, mod.String(), id)
64+
if len(args) == 0 {
65+
// Local
66+
67+
config, projectDir, err := getConfig()
68+
if err != nil {
69+
return err
70+
}
71+
logWriter := logger.NewTerminalLogger(ui, "Building Docker image from environment in cog.yaml... ")
72+
generator := docker.NewDockerfileGenerator(config, predictArch, projectDir)
73+
defer func() {
74+
if err := generator.Cleanup(); err != nil {
75+
ui.Output(fmt.Sprintf("Error cleaning up Dockerfile generator: %s", err))
76+
}
77+
}()
78+
dockerfileContents, err := generator.Generate()
79+
if err != nil {
80+
return fmt.Errorf("Failed to generate Dockerfile for %s: %w", predictArch, err)
81+
}
82+
dockerImageBuilder := docker.NewLocalImageBuilder("")
83+
dockerImageName, err = dockerImageBuilder.Build(context.Background(), projectDir, dockerfileContents, "", useGPU, logWriter)
84+
if err != nil {
85+
return fmt.Errorf("Failed to build Docker image: %w", err)
86+
}
87+
88+
logWriter.Done()
89+
90+
} else {
91+
// Remote
92+
93+
id := args[0]
94+
mod, err := getModel()
95+
if err != nil {
96+
return err
97+
}
98+
client := client.NewClient()
99+
st := ui.Status()
100+
defer st.Close()
101+
st.Update("Loading version " + id)
102+
version, err := client.GetVersion(mod, id)
103+
st.Step(terminal.StatusOK, "Loaded version "+id)
104+
if err != nil {
105+
return err
106+
}
107+
image := model.ImageForArch(version.Images, predictArch)
108+
// TODO(bfirsh): differentiate between failed builds and in-progress builds, and probably block here if there is an in-progress build
109+
if image == nil {
110+
return fmt.Errorf("No %s image has been built for %s:%s", predictArch, mod.String(), id)
111+
}
112+
dockerImageName = image.URI
68113
}
69114

115+
st := ui.Status()
116+
defer st.Close()
117+
st.Update(fmt.Sprintf("Starting Docker image %s and running setup()...", dockerImageName))
70118
servingPlatform, err := serving.NewLocalDockerPlatform()
71119
if err != nil {
120+
st.Step(terminal.StatusError, "Failed to start model: "+err.Error())
72121
return err
73122
}
74123
logWriter := logger.NewConsoleLogger()
75-
useGPU := predictArch == "gpu"
76-
deployment, err := servingPlatform.Deploy(context.Background(), image.URI, useGPU, logWriter)
124+
deployment, err := servingPlatform.Deploy(context.Background(), dockerImageName, useGPU, logWriter)
77125
if err != nil {
126+
st.Step(terminal.StatusError, "Failed to start model: "+err.Error())
78127
return err
79128
}
80129
defer func() {
81130
if err := deployment.Undeploy(); err != nil {
82131
console.Warnf("Failed to kill Docker container: %s", err)
83132
}
84133
}()
134+
st.Step(terminal.StatusOK, fmt.Sprintf("Model running in Docker image %s", dockerImageName))
85135

86-
return predictIndividualInputs(deployment, inputs, outPath, logWriter)
136+
return predictIndividualInputs(ui, deployment, inputs, outPath, logWriter)
87137
}
88138

89-
func predictIndividualInputs(deployment serving.Deployment, inputs []string, outputPath string, logWriter logger.Logger) error {
139+
func predictIndividualInputs(ui terminal.UI, deployment serving.Deployment, inputs []string, outputPath string, logWriter logger.Logger) error {
140+
st := ui.Status()
141+
defer st.Close()
142+
st.Update("Running prediction...")
90143
example := parsePredictInputs(inputs)
91144
result, err := deployment.RunPrediction(context.Background(), example, logWriter)
92145
if err != nil {
146+
st.Step(terminal.StatusError, "Failed to run prediction: "+err.Error())
93147
return err
94148
}
149+
st.Close()
150+
95151
// TODO(andreas): support multiple outputs?
96152
output := result.Values["output"]
97153

154+
ui.Output("")
155+
98156
// Write to stdout
99157
if outputPath == "" {
100158
// Is it something we can sensibly write to stdout?
101159
if output.MimeType == "text/plain" {
102-
_, err := io.Copy(os.Stdout, output.Buffer)
103-
return err
160+
output, err := io.ReadAll(output.Buffer)
161+
if err != nil {
162+
return err
163+
}
164+
ui.Output(string(output))
165+
return nil
104166
} else if output.MimeType == "application/json" {
105167
var obj map[string]interface{}
106168
dec := json.NewDecoder(output.Buffer)
@@ -110,7 +172,7 @@ func predictIndividualInputs(deployment serving.Deployment, inputs []string, out
110172
f := colorjson.NewFormatter()
111173
f.Indent = 2
112174
s, _ := f.Marshal(obj)
113-
fmt.Println(string(s))
175+
ui.Output(string(s))
114176
return nil
115177
}
116178
// Otherwise, fall back to writing file
@@ -139,7 +201,7 @@ func predictIndividualInputs(deployment serving.Deployment, inputs []string, out
139201
return err
140202
}
141203

142-
fmt.Println("Written output to " + outputPath)
204+
ui.Output("Written output to " + outputPath)
143205
return nil
144206
}
145207

pkg/docker/generate.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ func (g *DockerfileGenerator) GenerateBase() (string, error) {
9494
}
9595

9696
func (g *DockerfileGenerator) Generate() (string, error) {
97+
if g.Config.Model == "" {
98+
return "", fmt.Errorf("'model' option is not set in cog.yaml")
99+
}
100+
97101
base, err := g.GenerateBase()
98102
if err != nil {
99103
return "", err

pkg/serving/local.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func NewLocalDockerPlatform() (*LocalDockerPlatform, error) {
5151
func (p *LocalDockerPlatform) Deploy(ctx context.Context, imageTag string, useGPU bool, logWriter logger.Logger) (Deployment, error) {
5252
// TODO(andreas): output container logs
5353

54-
logWriter.Infof("Deploying container %s", imageTag)
54+
logWriter.Debugf("Deploying container %s", imageTag)
5555

5656
if !docker.Exists(imageTag, logWriter) {
5757
if err := docker.Pull(imageTag, logWriter); err != nil {
@@ -106,6 +106,15 @@ func (p *LocalDockerPlatform) Deploy(ctx context.Context, imageTag string, useGP
106106
DeviceRequests: deviceRequests,
107107
},
108108
}
109+
// TODO(bfirsh): support mounting local code
110+
// if mountCodePath != "" {
111+
// hostConfig.Mounts = []mount.Mount{{
112+
// Type: mount.TypeBind,
113+
// Source: mountCodePath,
114+
// Target: "/code",
115+
// ReadOnly: false,
116+
// }}
117+
// }
109118
resp, err := p.client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "")
110119
if err != nil {
111120
return nil, fmt.Errorf("Failed to create Docker container for image %s: %w", imageTag, err)
@@ -134,7 +143,7 @@ func (p *LocalDockerPlatform) waitForContainerReady(ctx context.Context, hostPor
134143
url := fmt.Sprintf("http://localhost:%d/ping", hostPort)
135144

136145
start := time.Now()
137-
logWriter.Info("Waiting for model to become accessible")
146+
logWriter.Debug("Waiting for model to become accessible")
138147
for {
139148
now := time.Now()
140149
if now.Sub(start) > global.StartupTimeout {
@@ -158,7 +167,7 @@ func (p *LocalDockerPlatform) waitForContainerReady(ctx context.Context, hostPor
158167
if resp.StatusCode != http.StatusOK {
159168
continue
160169
}
161-
logWriter.Info("Got successful ping response from container")
170+
logWriter.Debug("Got successful ping response from container")
162171
return nil
163172
}
164173
}

0 commit comments

Comments
 (0)