Skip to content

Commit 5d69e16

Browse files
committed
Rename Model.run() to Model.predict()
Signed-off-by: Ben Firshman <[email protected]>
1 parent 3c44200 commit 5d69e16

File tree

6 files changed

+38
-38
lines changed

6 files changed

+38
-38
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ColorizationModel(cog.Model):
2525
self.model = torch.load("./weights.pth")
2626

2727
@cog.input("input", type=Path, help="Grayscale input image")
28-
def run(self, input):
28+
def predict(self, input):
2929
# ... pre-processing ...
3030
output = self.model(processed_input)
3131
# ... post-processing ...

docs/python.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class HelloWorldModel(cog.Model):
1010
self.prefix = "hello "
1111

1212
@cog.input("text", type=str, help="Text that will get prefixed by 'hello '")
13-
def run(self, text):
13+
def predict(self, text):
1414
return self.prefix + text
1515
```
1616

end-to-end-test/end_to_end_test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def setup(self):
7070
@cog.input("text", type=str)
7171
@cog.input("path", type=Path)
7272
@cog.input("output_file", type=bool, default=False)
73-
def run(self, text, path, output_file):
73+
def predict(self, text, path, output_file):
7474
time.sleep(1)
7575
with open(path) as f:
7676
output = self.foo + text + f.read()

end-to-end-test/end_to_end_test/test_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_server_end_to_end(cog_server, project_dir, tmpdir_factory):
101101
).communicate()
102102
paths = sorted(glob(str(download_dir / "*.*")))
103103
filenames = [os.path.basename(f) for f in paths]
104-
assert filenames == ["cog.yaml", "predict.py", "myfile.txt"]
104+
assert filenames == ["cog.yaml", "myfile.txt", "predict.py"]
105105

106106
with open(download_dir / "cog-example-output/output.02.txt") as f:
107107
assert f.read() == "fooquxbaz"

pkg/docker/cog.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def setup(self):
3939
pass
4040

4141
@abstractmethod
42-
def run(self, **kwargs):
42+
def predict(self, **kwargs):
4343
pass
4444

4545

@@ -70,7 +70,7 @@ def handle_request():
7070
)
7171
raw_inputs[key] = val
7272

73-
if hasattr(self.model.run, "_inputs"):
73+
if hasattr(self.model.predict, "_inputs"):
7474
try:
7575
inputs = validate_and_convert_inputs(
7676
self.model, raw_inputs, cleanup_functions
@@ -97,8 +97,8 @@ def ping():
9797
@app.route("/help")
9898
def help():
9999
args = {}
100-
if hasattr(self.model.run, "_inputs"):
101-
input_specs = self.model.run._inputs
100+
if hasattr(self.model.predict, "_inputs"):
101+
input_specs = self.model.predict._inputs
102102
for name, spec in input_specs.items():
103103
arg: Dict[str, Any] = {
104104
"type": _type_name(spec.type),
@@ -186,8 +186,8 @@ def ping():
186186
@app.route("/help")
187187
def help():
188188
args = {}
189-
if hasattr(self.model.run, "_inputs"):
190-
input_specs = self.model.run._inputs
189+
if hasattr(self.model.predict, "_inputs"):
190+
input_specs = self.model.predict._inputs
191191
for name, spec in input_specs.items():
192192
arg = {
193193
"type": _type_name(spec.type),
@@ -420,7 +420,7 @@ def upload_to_temp(self, path: Path) -> str:
420420
def validate_and_convert_inputs(
421421
model: Model, raw_inputs: Dict[str, Any], cleanup_functions: List[Callable]
422422
) -> Dict[str, Any]:
423-
input_specs = model.run._inputs
423+
input_specs = model.predict._inputs
424424
inputs = {}
425425

426426
for name, input_spec in input_specs.items():
@@ -516,7 +516,7 @@ def run_model(model, inputs, cleanup_functions):
516516
Run the model on the inputs, and append resulting paths
517517
to cleanup functions for removal.
518518
"""
519-
result = model.run(**inputs)
519+
result = model.predict(**inputs)
520520
if isinstance(result, Path):
521521
cleanup_functions.append(result.unlink)
522522
return result

pkg/docker/cog_test.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Model(cog.Model):
2626
def setup(self):
2727
self.foo = "foo"
2828

29-
def run(self):
29+
def predict(self):
3030
return self.foo + "bar"
3131

3232
client = make_client(Model())
@@ -41,7 +41,7 @@ def setup(self):
4141
self.foo = "foo"
4242

4343
@cog.input("text", type=str)
44-
def run(self, text):
44+
def predict(self, text):
4545
return self.foo + text
4646

4747
client = make_client(Model())
@@ -56,7 +56,7 @@ def setup(self):
5656
self.foo = "foo"
5757

5858
@cog.input("text", type=str)
59-
def run(self, text):
59+
def predict(self, text):
6060
return self.foo + text
6161

6262
client = make_client(Model())
@@ -73,7 +73,7 @@ def setup(self):
7373
self.foo = "foo"
7474

7575
@cog.input("text", type=str)
76-
def run(self, bad):
76+
def predict(self, bad):
7777
return self.foo + "bar"
7878

7979

@@ -83,7 +83,7 @@ def setup(self):
8383
self.foo = "foo"
8484

8585
@cog.input("num", type=int)
86-
def run(self, num):
86+
def predict(self, num):
8787
num2 = num ** 3
8888
return self.foo + str(num2)
8989

@@ -102,7 +102,7 @@ def setup(self):
102102
self.foo = "foo"
103103

104104
@cog.input("num", type=int)
105-
def run(self, num):
105+
def predict(self, num):
106106
num2 = num ** 2
107107
return self.foo + str(num2)
108108

@@ -117,7 +117,7 @@ def setup(self):
117117
self.foo = "foo"
118118

119119
@cog.input("num", type=int, default=5)
120-
def run(self, num):
120+
def predict(self, num):
121121
num2 = num ** 2
122122
return self.foo + str(num2)
123123

@@ -136,7 +136,7 @@ def setup(self):
136136
self.foo = "foo"
137137

138138
@cog.input("num", type=float)
139-
def run(self, num):
139+
def predict(self, num):
140140
num2 = num ** 3
141141
return self.foo + str(num2)
142142

@@ -158,7 +158,7 @@ def setup(self):
158158
self.foo = "foo"
159159

160160
@cog.input("num", type=float)
161-
def run(self, num):
161+
def predict(self, num):
162162
num2 = num ** 2
163163
return self.foo + str(num2)
164164

@@ -173,7 +173,7 @@ def setup(self):
173173
self.foo = "foo"
174174

175175
@cog.input("flag", type=bool)
176-
def run(self, flag):
176+
def predict(self, flag):
177177
if flag:
178178
return self.foo + "yes"
179179
else:
@@ -194,7 +194,7 @@ def setup(self):
194194
self.foo = "foo"
195195

196196
@cog.input("flag", type=bool)
197-
def run(self, flag):
197+
def predict(self, flag):
198198
if flag:
199199
return self.foo + "yes"
200200
else:
@@ -213,7 +213,7 @@ def setup(self):
213213
@cog.input("num1", type=float, min=3, max=10.5)
214214
@cog.input("num2", type=float, min=-4)
215215
@cog.input("num3", type=int, max=-4)
216-
def run(self, num1, num2, num3):
216+
def predict(self, num1, num2, num3):
217217
return num1 + num2 + num3
218218

219219
client = make_client(Model())
@@ -235,7 +235,7 @@ def setup(self):
235235

236236
@cog.input("text", type=str, options=["foo", "bar"])
237237
@cog.input("num", type=int, options=[1, 2, 3])
238-
def run(self, text, num):
238+
def predict(self, text, num):
239239
return text + ("a" * num)
240240

241241
client = make_client(Model())
@@ -252,7 +252,7 @@ def setup(self):
252252
pass
253253

254254
@cog.input("text", type=str, options=[])
255-
def run(self, text):
255+
def predict(self, text):
256256
return text
257257

258258
with pytest.raises(ValueError):
@@ -262,7 +262,7 @@ def setup(self):
262262
pass
263263

264264
@cog.input("text", type=str, options=["foo"])
265-
def run(self, text):
265+
def predict(self, text):
266266
return text
267267

268268
with pytest.raises(ValueError):
@@ -272,7 +272,7 @@ def setup(self):
272272
pass
273273

274274
@cog.input("text", type=Path, options=["foo"])
275-
def run(self, text):
275+
def predict(self, text):
276276
return text
277277

278278

@@ -283,7 +283,7 @@ def setup(self):
283283

284284
@cog.input("text", type=str, options=["foo", "bar"])
285285
@cog.input("num", type=int, options=[1, 2, 3])
286-
def run(self, text, num):
286+
def predict(self, text, num):
287287
return text + ("a" * num)
288288

289289
client = make_client(Model())
@@ -299,7 +299,7 @@ def setup(self):
299299
self.foo = "foo"
300300

301301
@cog.input("path", type=Path)
302-
def run(self, path):
302+
def predict(self, path):
303303
with open(path) as f:
304304
return self.foo + " " + f.read() + " " + os.path.basename(path)
305305

@@ -318,7 +318,7 @@ def setup(self):
318318
self.foo = "foo"
319319

320320
@cog.input("path", type=Path)
321-
def run(self, path):
321+
def predict(self, path):
322322
with open(path) as f:
323323
return self.foo + " " + f.read() + " " + os.path.basename(path)
324324

@@ -333,7 +333,7 @@ def setup(self):
333333
self.foo = "foo"
334334

335335
@cog.input("path", type=Path, default=None)
336-
def run(self, path):
336+
def predict(self, path):
337337
if path is None:
338338
return "noneee"
339339
with open(path) as f:
@@ -357,7 +357,7 @@ def setup(self):
357357
self.foo = "foo"
358358

359359
@cog.input("text", type=str)
360-
def run(self, text):
360+
def predict(self, text):
361361
temp_dir = tempfile.mkdtemp()
362362
temp_path = os.path.join(temp_dir, "my_file.txt")
363363
with open(temp_path, "w") as f:
@@ -376,7 +376,7 @@ class Model(cog.Model):
376376
def setup(self):
377377
pass
378378

379-
def run(self):
379+
def predict(self):
380380
temp_dir = tempfile.mkdtemp()
381381
temp_path = os.path.join(temp_dir, "my_file.bmp")
382382
img = Image.new("RGB", (255, 255), "red")
@@ -400,7 +400,7 @@ def setup(self):
400400
@cog.input("num1", type=int)
401401
@cog.input("num2", type=int, default=10)
402402
@cog.input("path", type=Path)
403-
def run(self, text, num1, num2, path):
403+
def predict(self, text, num1, num2, path):
404404
with open(path) as f:
405405
path_contents = f.read()
406406
return self.foo + " " + text + " " + str(num1 * num2) + " " + path_contents
@@ -425,7 +425,7 @@ def setup(self):
425425
@cog.input("num1", type=int, help="First number")
426426
@cog.input("num2", type=int, default=10, help="Second number")
427427
@cog.input("path", type=Path, help="A file path")
428-
def run(self, text, num1, num2, path):
428+
def predict(self, text, num1, num2, path):
429429
with open(path) as f:
430430
path_contents = f.read()
431431
return self.foo + " " + text + " " + str(num1 * num2) + " " + path_contents
@@ -461,15 +461,15 @@ class ModelSlow(cog.Model):
461461
def setup(self):
462462
time.sleep(0.5)
463463

464-
def run(self):
464+
def predict(self):
465465
time.sleep(0.5)
466466
return ""
467467

468468
class ModelFast(cog.Model):
469469
def setup(self):
470470
pass
471471

472-
def run(self):
472+
def predict(self):
473473
return ""
474474

475475
client = make_client(ModelSlow())

0 commit comments

Comments
 (0)