Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ def validate_input_type(
type: Type[Any], # pylint: disable=redefined-builtin
name: str,
) -> None:
def is_union(type: Type[Any]) -> bool:
if get_origin(type) is Union:
return True
if hasattr(types, "UnionType") and get_origin(type) is types.UnionType:
return True
return False

if type is inspect.Signature.empty:
raise TypeError(
f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types."
Expand All @@ -190,13 +197,11 @@ def validate_input_type(
if get_origin(type) is Literal:
for t in get_args(type):
validate_input_type(builtins.type(t), name)
elif get_origin(type) in (Union, List, list) or (
hasattr(types, "UnionType") and get_origin(type) is types.UnionType
): # noqa: E721
elif get_origin(type) in (Union, List, list) or is_union(type): # noqa: E721
args = get_args(type)

def is_optional() -> bool:
if len(args) != 2 or get_origin(type) is not Union:
if len(args) != 2 or not is_union(type):
return False
if sys.version_info >= (3, 10):
return args[1] is NoneType
Expand Down
7 changes: 7 additions & 0 deletions python/tests/server/fixtures/input_path_or_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Path


class Predictor(BasePredictor):
def predict(self, file: Path | None) -> str:
print(f"file: {file}")
return "hello"
12 changes: 12 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import os
import sys
import threading
import time

Expand Down Expand Up @@ -333,3 +334,14 @@ def test_input_with_unsupported_type():
assert "TypeError: Unsupported input type input_unsupported_type" in "".join(
app.state.setup_result.logs
)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or newer")
@uses_predictor("input_path_or_none")
def test_path_or_none(client, httpserver, match):
httpserver.expect_request("/foo.txt").respond_with_data("hello")
resp = client.post(
"/predictions",
json={"input": {"file": httpserver.url_for("/foo.txt")}},
)
assert resp.json() == match({"output": "hello", "status": "succeeded"})
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.10"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from cog import BasePredictor, Input


class Predictor(BasePredictor):
def setup(self):
self.prefix = "hello"

def predict(
self,
text: str | None = Input(
description="Text to prefix with 'hello '", default=None
),
) -> str:
return self.prefix + " " + text
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
build:
python_version: "3.10"
python_packages:
- "pillow"

predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from cog import BasePredictor, Path, Input
from PIL import Image


class Predictor(BasePredictor):
def predict(self,test_image: Path | None = Input(description="Test image", default=None)) -> Path:
"""Run a single prediction on the model"""
im = Image.new("RGB", (100, 100), color="red")
im.save(Path("./hello.webp"))
return Path("./hello.webp")

12 changes: 12 additions & 0 deletions test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,15 @@ def test_bad_dockerignore(docker_image):
"The .cog tmp path cannot be ignored by docker in .dockerignore"
in build_process.stderr.decode()
)


def test_pydantic1_none(docker_image):
project_dir = Path(__file__).parent / "fixtures/pydantic1-none"

build_process = subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
capture_output=True,
)

assert build_process.returncode == 0
15 changes: 15 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,18 @@ async def make_request(i: int) -> httpx.Response:
for i, task in enumerate(tasks):
assert task.result().status_code == 200
assert task.result().json()["output"] == f"wake up sleepyhead{i}"


def test_predict_new_union_project(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/new-union-project"
result = subprocess.run(
["cog", "predict", "--debug", "-i", "text=world"],
cwd=project_dir,
check=True,
capture_output=True,
text=True,
timeout=DEFAULT_TIMEOUT,
)
# stdout should be clean without any log messages so it can be piped to other commands
assert result.returncode == 0
assert result.stdout == "hello world\n"
Loading