Skip to content

Commit 0f4cf3d

Browse files
authored
Support kwargs predictors (#2345)
This change allows a predictor input to be defined exclusively as `**kwargs`. This is an experimental change that allows a model to be used as a proxy where any inputs will be passed through directly without modification. ```py import json def predict(**kwargs) -> str: return json.dumps(kwargs) ``` This change restricts the use of `**kwargs` to functions/methods that accept only that as an argument to avoid the complexity of having to support both dynamic and fixed inputs. The OpenAPI schema for such a model will be: ```json { "properties": {}, "additionalProperties": true, "type": "object", "title": "Input" } ```
1 parent 79a6fa8 commit 0f4cf3d

File tree

4 files changed

+47
-5
lines changed

4 files changed

+47
-5
lines changed

.tool-versions

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
golang 1.24.2
2+
python 3.13.2

python/cog/predictor.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,38 @@ def validate_input_type(
231231

232232

233233
def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any]:
234-
create_model_kwargs = {}
234+
create_model_kwargs: Dict[str, Any] = {
235+
"__base__": BaseInput,
236+
"__config__": None,
237+
}
235238

236239
order = 0
237240

238241
for name, parameter in signature.parameters.items():
239242
InputType = parameter.annotation
240243

244+
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
245+
raise TypeError(f"Unsupported variadic positional parameter *{name}.")
246+
247+
if parameter.kind == inspect.Parameter.VAR_KEYWORD:
248+
if order != 0:
249+
raise TypeError(f"Unsupported variadic keyword parameter **{name}")
250+
251+
class ExtraKeywordInput(BaseInput):
252+
if PYDANTIC_V2:
253+
model_config = pydantic.ConfigDict(extra="allow")
254+
else:
255+
256+
class Config:
257+
extra = "allow"
258+
259+
create_model_kwargs["__base__"] = ExtraKeywordInput
260+
name = "__pydantic_extra__"
261+
InputType = Dict[str, Any]
262+
263+
create_model_kwargs[name] = (InputType, Input())
264+
continue
265+
241266
validate_input_type(InputType, name)
242267

243268
# if no default is specified, create an empty, required input
@@ -325,8 +350,6 @@ class Input(BaseModel):
325350

326351
return create_model(
327352
"Input",
328-
__config__=None,
329-
__base__=BaseInput,
330353
__module__=__name__,
331354
__validators__=None,
332355
**get_input_create_model_kwargs(signature),
@@ -431,8 +454,6 @@ class TrainingInput(BaseModel):
431454

432455
return create_model(
433456
"TrainingInput",
434-
__config__=None,
435-
__base__=BaseInput,
436457
__module__=__name__,
437458
__validators__=None,
438459
**get_input_create_model_kwargs(signature),
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import Dict
2+
3+
from cog import BasePredictor
4+
5+
6+
class Predictor(BasePredictor):
7+
def predict(self, **kwargs) -> Dict:
8+
return kwargs

python/tests/server/test_http_input.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def test_empty_input(client, match):
3939
assert resp.json() == match({"status": "succeeded", "output": "foobar"})
4040

4141

42+
@uses_predictor("input_kwargs")
43+
def test_kwargs_input(client, match):
44+
"""Check we support kwargs input fields"""
45+
input = {"animal": "giraffe", "no": 5}
46+
resp = client.post("/predictions", json={"input": input})
47+
assert resp.json() == match({"status": "succeeded"})
48+
assert resp.status_code == 200
49+
50+
result = resp.json()["output"]
51+
assert result == input
52+
53+
4254
@uses_predictor("input_integer")
4355
def test_good_int_input(client, match):
4456
resp = client.post("/predictions", json={"input": {"num": 3}})

0 commit comments

Comments
 (0)