@@ -13,12 +13,14 @@ import (
1313 "github.com/spf13/cobra"
1414
1515 "github.com/replicate/cog/pkg/client"
16+ "github.com/replicate/cog/pkg/docker"
1617 "github.com/replicate/cog/pkg/logger"
1718 "github.com/replicate/cog/pkg/model"
1819 "github.com/replicate/cog/pkg/serving"
1920 "github.com/replicate/cog/pkg/util/console"
2021 "github.com/replicate/cog/pkg/util/mime"
2122 "github.com/replicate/cog/pkg/util/slices"
23+ "github.com/replicate/cog/pkg/util/terminal"
2224)
2325
2426var (
@@ -29,10 +31,15 @@ var (
2931
3032func newPredictCommand () * cobra.Command {
3133 cmd := & cobra.Command {
32- Use : "predict <id>" ,
33- Short : "Run a single prediction against a version of a model" ,
34+ Use : "predict [version id]" ,
35+ Short : "Run a prediction on a version" ,
36+ Long : `Run a prediction on a version.
37+
38+ If 'version id' is passed, it will run the prediction on that version of the
39+ model. Otherwise, it will build the model in the current directory and run
40+ the prediction on that.` ,
3441 RunE : cmdPredict ,
35- Args : cobra .MinimumNArgs (1 ),
42+ Args : cobra .MaximumNArgs (1 ),
3643 SuggestFor : []string {"infer" },
3744 }
3845 addModelFlag (cmd )
@@ -48,59 +55,114 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
4855 return fmt .Errorf ("--arch must be either 'cpu' or 'gpu'" )
4956 }
5057
51- mod , err := getModel ()
52- if err != nil {
53- return err
54- }
58+ ui := terminal .ConsoleUI (context .Background ())
59+ defer ui .Close ()
5560
56- id := args [0 ]
61+ useGPU := predictArch == "gpu"
62+ dockerImageName := ""
5763
58- client := client .NewClient ()
59- fmt .Println ("Loading package" , id )
60- version , err := client .GetVersion (mod , id )
61- if err != nil {
62- return err
63- }
64- // TODO(bfirsh): differentiate between failed builds and in-progress builds, and probably block here if there is an in-progress build
65- image := model .ImageForArch (version .Images , predictArch )
66- if image == nil {
67- return fmt .Errorf ("No %s image has been built for %s:%s" , predictArch , mod .String (), id )
64+ if len (args ) == 0 {
65+ // Local
66+
67+ config , projectDir , err := getConfig ()
68+ if err != nil {
69+ return err
70+ }
71+ logWriter := logger .NewTerminalLogger (ui , "Building Docker image from environment in cog.yaml... " )
72+ generator := docker .NewDockerfileGenerator (config , predictArch , projectDir )
73+ defer func () {
74+ if err := generator .Cleanup (); err != nil {
75+ ui .Output (fmt .Sprintf ("Error cleaning up Dockerfile generator: %s" , err ))
76+ }
77+ }()
78+ dockerfileContents , err := generator .Generate ()
79+ if err != nil {
80+ return fmt .Errorf ("Failed to generate Dockerfile for %s: %w" , predictArch , err )
81+ }
82+ dockerImageBuilder := docker .NewLocalImageBuilder ("" )
83+ dockerImageName , err = dockerImageBuilder .Build (context .Background (), projectDir , dockerfileContents , "" , useGPU , logWriter )
84+ if err != nil {
85+ return fmt .Errorf ("Failed to build Docker image: %w" , err )
86+ }
87+
88+ logWriter .Done ()
89+
90+ } else {
91+ // Remote
92+
93+ id := args [0 ]
94+ mod , err := getModel ()
95+ if err != nil {
96+ return err
97+ }
98+ client := client .NewClient ()
99+ st := ui .Status ()
100+ defer st .Close ()
101+ st .Update ("Loading version " + id )
102+ version , err := client .GetVersion (mod , id )
103+ st .Step (terminal .StatusOK , "Loaded version " + id )
104+ if err != nil {
105+ return err
106+ }
107+ image := model .ImageForArch (version .Images , predictArch )
108+ // TODO(bfirsh): differentiate between failed builds and in-progress builds, and probably block here if there is an in-progress build
109+ if image == nil {
110+ return fmt .Errorf ("No %s image has been built for %s:%s" , predictArch , mod .String (), id )
111+ }
112+ dockerImageName = image .URI
68113 }
69114
115+ st := ui .Status ()
116+ defer st .Close ()
117+ st .Update (fmt .Sprintf ("Starting Docker image %s and running setup()..." , dockerImageName ))
70118 servingPlatform , err := serving .NewLocalDockerPlatform ()
71119 if err != nil {
120+ st .Step (terminal .StatusError , "Failed to start model: " + err .Error ())
72121 return err
73122 }
74123 logWriter := logger .NewConsoleLogger ()
75- useGPU := predictArch == "gpu"
76- deployment , err := servingPlatform .Deploy (context .Background (), image .URI , useGPU , logWriter )
124+ deployment , err := servingPlatform .Deploy (context .Background (), dockerImageName , useGPU , logWriter )
77125 if err != nil {
126+ st .Step (terminal .StatusError , "Failed to start model: " + err .Error ())
78127 return err
79128 }
80129 defer func () {
81130 if err := deployment .Undeploy (); err != nil {
82131 console .Warnf ("Failed to kill Docker container: %s" , err )
83132 }
84133 }()
134+ st .Step (terminal .StatusOK , fmt .Sprintf ("Model running in Docker image %s" , dockerImageName ))
85135
86- return predictIndividualInputs (deployment , inputs , outPath , logWriter )
136+ return predictIndividualInputs (ui , deployment , inputs , outPath , logWriter )
87137}
88138
89- func predictIndividualInputs (deployment serving.Deployment , inputs []string , outputPath string , logWriter logger.Logger ) error {
139+ func predictIndividualInputs (ui terminal.UI , deployment serving.Deployment , inputs []string , outputPath string , logWriter logger.Logger ) error {
140+ st := ui .Status ()
141+ defer st .Close ()
142+ st .Update ("Running prediction..." )
90143 example := parsePredictInputs (inputs )
91144 result , err := deployment .RunPrediction (context .Background (), example , logWriter )
92145 if err != nil {
146+ st .Step (terminal .StatusError , "Failed to run prediction: " + err .Error ())
93147 return err
94148 }
149+ st .Close ()
150+
95151 // TODO(andreas): support multiple outputs?
96152 output := result .Values ["output" ]
97153
154+ ui .Output ("" )
155+
98156 // Write to stdout
99157 if outputPath == "" {
100158 // Is it something we can sensibly write to stdout?
101159 if output .MimeType == "text/plain" {
102- _ , err := io .Copy (os .Stdout , output .Buffer )
103- return err
160+ output , err := io .ReadAll (output .Buffer )
161+ if err != nil {
162+ return err
163+ }
164+ ui .Output (string (output ))
165+ return nil
104166 } else if output .MimeType == "application/json" {
105167 var obj map [string ]interface {}
106168 dec := json .NewDecoder (output .Buffer )
@@ -110,7 +172,7 @@ func predictIndividualInputs(deployment serving.Deployment, inputs []string, out
110172 f := colorjson .NewFormatter ()
111173 f .Indent = 2
112174 s , _ := f .Marshal (obj )
113- fmt . Println (string (s ))
175+ ui . Output (string (s ))
114176 return nil
115177 }
116178 // Otherwise, fall back to writing file
@@ -139,7 +201,7 @@ func predictIndividualInputs(deployment serving.Deployment, inputs []string, out
139201 return err
140202 }
141203
142- fmt . Println ("Written output to " + outputPath )
204+ ui . Output ("Written output to " + outputPath )
143205 return nil
144206}
145207
0 commit comments