Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
)
from .types import Secret as CogSecret

if PYDANTIC_V2:
from pydantic.fields import PydanticUndefined # type: ignore
else:
from pydantic.fields import Undefined as PydanticUndefined

log = structlog.get_logger("cog.server.predictor")

ALLOWED_INPUT_TYPES: List[Type[Any]] = [
Expand Down Expand Up @@ -178,17 +183,27 @@ def load_predictor_from_ref(ref: str) -> BasePredictor:
return predictor


def is_union(type: Type[Any]) -> bool:
if get_origin(type) is Union:
return True
if hasattr(types, "UnionType") and get_origin(type) is types.UnionType:
return True
return False


def is_optional(type: Type[Any]) -> bool:
args = get_args(type)
if len(args) != 2 or not is_union(type):
return False
if sys.version_info >= (3, 10):
return args[1] is NoneType
return args[1] is None.__class__


def validate_input_type(
type: Type[Any], # pylint: disable=redefined-builtin
name: str,
) -> None:
def is_union(type: Type[Any]) -> bool:
if get_origin(type) is Union:
return True
if hasattr(types, "UnionType") and get_origin(type) is types.UnionType:
return True
return False

if type is inspect.Signature.empty:
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
Expand All @@ -199,15 +214,7 @@ def is_union(type: Type[Any]) -> bool:
validate_input_type(builtins.type(t), name)
elif get_origin(type) in (Union, List, list) or is_union(type): # noqa: E721
args = get_args(type)

def is_optional() -> bool:
if len(args) != 2 or not is_union(type):
return False
if sys.version_info >= (3, 10):
return args[1] is NoneType
return args[1] is None.__class__

if is_optional():
if is_optional(type):
validate_input_type(args[0], name)
else:
for t in args:
Expand Down Expand Up @@ -240,6 +247,17 @@ def get_input_create_model_kwargs(signature: inspect.Signature) -> Dict[str, Any
if not isinstance(parameter.default, FieldInfo):
default = Input(default=parameter.default)
else:
if is_optional(InputType):
# If we are an optional, make sure the default is None
if (
parameter.default.default is PydanticUndefined
or parameter.default.default is ...
):
if PYDANTIC_V2:
parameter.default.default = None
else:
parameter.default.default_factory = None
parameter.default.default = None
default = parameter.default

if PYDANTIC_V2:
Expand Down
31 changes: 30 additions & 1 deletion python/tests/test_predictor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
import inspect
import os
import sys
from typing import Optional
from unittest.mock import patch

from cog import File, Path
from pydantic.fields import FieldInfo

from cog import File, Input, Path
from cog.predictor import (
get_input_create_model_kwargs,
get_predict,
get_weights_type,
load_predictor_from_ref,
)
from cog.types import PYDANTIC_V2

if PYDANTIC_V2:
from pydantic.fields import PydanticUndefined
else:
from pydantic.fields import Undefined as PydanticUndefined


def is_field_required(field: FieldInfo):
if hasattr(field, "is_required"):
return field.is_required()
if hasattr(field, "required"):
return field.required
return field.default is PydanticUndefined and field.default_factory is None


def test_get_weights_type() -> None:
Expand Down Expand Up @@ -42,6 +61,16 @@ def test_load_predictor_from_ref_overrides_argv():
assert sys.argv == ["foo.py", "exec", "--giraffes=2", "--eat-cookies"]


def test_get_input_create_model_kwargs():
def predict(thing: Optional[str] = Input(description="Hello String.")) -> str:
return thing if thing is not None else "Nothing"

predict_type = get_predict(predict)
signature = inspect.signature(predict_type)
output = get_input_create_model_kwargs(signature)
assert not is_field_required(output["thing"][1])


def _fixture_path(name):
test_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Input
from typing import Optional


class Predictor(BasePredictor):
def predict(self, s: Optional[str] = Input(description="Hello String.")) -> str:
return "hello " + (s if s is not None else "No One")
15 changes: 15 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,18 @@ def test_predict_with_fast_build_with_local_image(docker_image):
os.remove(weights_file)
assert build_process.returncode == 0
assert result.returncode == 0


def test_predict_optional_project(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/optional-project"
result = subprocess.run(
["cog", "predict", "--debug"],
cwd=project_dir,
check=True,
capture_output=True,
text=True,
timeout=DEFAULT_TIMEOUT,
)
# stdout should be clean without any log messages so it can be piped to other commands
assert result.returncode == 0
assert result.stdout == "hello No One\n"
Loading