Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def load_predictor_from_ref(ref: str) -> BasePredictor:
return predictor


def is_none(type_arg: Any) -> bool:
if sys.version_info >= (3, 10):
return type_arg is NoneType
return type_arg is None.__class__


def is_union(type: Type[Any]) -> bool:
if get_origin(type) is Union:
return True
Expand All @@ -195,9 +201,7 @@ def is_optional(type: Type[Any]) -> bool:
args = get_args(type)
if len(args) != 2 or not is_union(type):
return False
if sys.version_info >= (3, 10):
return args[1] is NoneType
return args[1] is None.__class__
return is_none(args[1])


def validate_input_type(
Expand Down
131 changes: 115 additions & 16 deletions python/cog/server/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,26 @@
import threading
import uuid
from types import TracebackType
from typing import Any, BinaryIO, Callable, Dict, List, Sequence, TextIO, Union
from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Sequence,
TextIO,
Union,
get_args,
get_type_hints,
)

import pydantic
from fastapi import FastAPI
from fastapi.routing import APIRoute
from pydantic import BaseModel
from typing_extensions import Self # added to typing in python 3.11

from ..predictor import is_none, is_optional
from ..types import PYDANTIC_V2
from .errors import CogRuntimeError, CogTimeoutError

Expand Down Expand Up @@ -361,12 +376,87 @@ def unwrap_pydantic_serialization_iterators(obj: Any) -> Any:
return [unwrap_pydantic_serialization_iterators(value) for value in obj]
return obj

else:

def get_annotations(tp) -> dict[str, Any]:
if sys.version_info >= (3, 10):
return get_type_hints(tp)
return tp.__annotations__

def is_pydantic_model_type(tp) -> bool:
try:
return isinstance(tp, type) and issubclass(tp, BaseModel)
except TypeError:
return False

def update_nullable_optional(openapi_schema: Dict[str, Any], app: FastAPI) -> None:
def fetch_referenced_schema(schema: Dict[str, Any], ref: str) -> Dict[str, Any]:
input_path = ref.replace("#/", "").split("/")
referenced_schema = schema
while input_path:
referenced_schema = referenced_schema[input_path[0]]
input_path = input_path[1:]
return referenced_schema

for route in app.routes:
if not isinstance(route, APIRoute):
continue

for dep in route.dependant.body_params:
model = getattr(dep, "type_", None)
if not is_pydantic_model_type(model):
continue
input_model_union = get_annotations(model).get("input")
if input_model_union is None:
continue
input_model = get_args(input_model_union)[0]
schema_node = openapi_schema["components"]["schemas"].get(
model.__name__
)
referenced_schema = fetch_referenced_schema(
openapi_schema, schema_node["properties"]["input"]["$ref"]
)
for k, v in referenced_schema["properties"].items():
annotated_type = get_annotations(input_model)[k]
if is_optional(annotated_type):
v["nullable"] = True

response_model = getattr(route, "response_model", None)
if is_pydantic_model_type(response_model):
output_model_union = get_annotations(response_model).get("output")
if output_model_union is None:
continue
output_model = get_args(output_model_union)[0]
schema_node = openapi_schema["components"]["schemas"].get(
output_model.__name__
)
root = get_annotations(output_model).get("__root__")
for type_arg in get_args(root):
if not is_none(type_arg):
continue
schema_node["nullable"] = True
break
for count, type_node in enumerate(schema_node.get("anyOf", [])):
ref_node = type_node.get("$ref")
if ref_node is None:
continue
referenced_schema = fetch_referenced_schema(
openapi_schema, ref_node
)
output_model = get_args(root)[count]
for k, v in referenced_schema["properties"].items():
annotated_type = get_annotations(output_model)[k]
if is_optional(annotated_type):
v["nullable"] = True

return None


def update_openapi_schema_for_pydantic_2(
openapi_schema: Dict[str, Any],
) -> None:
_remove_webhook_events_filter_title(openapi_schema)
_remove_empty_or_nullable_anyof(openapi_schema)
_update_nullable_anyof(openapi_schema)
_flatten_selected_allof_refs(openapi_schema)
_extract_enum_properties(openapi_schema)
_set_default_enumeration_description(openapi_schema)
Expand All @@ -384,27 +474,36 @@ def _remove_webhook_events_filter_title(
pass


def _remove_empty_or_nullable_anyof(
def _update_nullable_anyof(
openapi_schema: Union[Dict[str, Any], List[Dict[str, Any]]],
in_header: Union[bool, None] = None,
) -> None:
# Version 3.0.X of OpenAPI doesn't support a `null` type, expecting
# `nullable` to be set instead.
if isinstance(openapi_schema, dict):
if in_header is None:
if "in" in openapi_schema:
in_header = openapi_schema.get("in") == "header"
for key, value in list(openapi_schema.items()):
if key == "anyOf" and isinstance(value, list):
non_null_types = [item for item in value if item.get("type") != "null"]
if len(non_null_types) == 0:
del openapi_schema[key]
elif len(non_null_types) == 1:
openapi_schema.update(non_null_types[0])
del openapi_schema[key]
if key != "anyOf" or not isinstance(value, list):
_update_nullable_anyof(value, in_header=in_header)
continue

non_null_items = [item for item in value if item.get("type") != "null"]
if len(non_null_items) == 0:
del openapi_schema[key]
elif len(non_null_items) == 1:
openapi_schema.update(non_null_items[0])
del openapi_schema[key]
else:
openapi_schema[key] = non_null_items

# FIXME: Update tests to expect nullable
# openapi_schema["nullable"] = True
if len(non_null_items) < len(value) and not in_header:
openapi_schema["nullable"] = True

else:
_remove_empty_or_nullable_anyof(value)
elif isinstance(openapi_schema, list): # pyright: ignore
elif isinstance(openapi_schema, list):
for item in openapi_schema:
_remove_empty_or_nullable_anyof(item)
_update_nullable_anyof(item, in_header=in_header)


def _flatten_selected_allof_refs(
Expand Down
4 changes: 4 additions & 0 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
unwrap_pydantic_serialization_iterators,
update_openapi_schema_for_pydantic_2,
)
else:
from .helpers import update_nullable_optional

from .probes import ProbeHelper
from .runner import (
Expand Down Expand Up @@ -136,6 +138,8 @@ def custom_openapi() -> Dict[str, Any]:
# See: https://github.com/tiangolo/fastapi/pull/9873#issuecomment-1997105091
if PYDANTIC_V2:
update_openapi_schema_for_pydantic_2(openapi_schema)
else:
update_nullable_optional(openapi_schema, app)

app.openapi_schema = openapi_schema

Expand Down
4 changes: 3 additions & 1 deletion python/tests/server/fixtures/openapi_complex_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cog import BasePredictor, File, Input, Path
from typing import Optional

from cog import BaseModel, BasePredictor, File, Input, Path

class Predictor(BasePredictor):
def predict(
Expand All @@ -11,5 +12,6 @@ def predict(
image: File = Input(description="Some path"),
choices: str = Input(choices=["foo", "bar"]),
int_choices: int = Input(choices=[3, 4, 5]),
optional_str: Optional[str] = Input(),
) -> str:
pass
16 changes: 16 additions & 0 deletions python/tests/server/fixtures/openapi_optional_output_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional, Union

from cog import BasePredictor
from pydantic import BaseModel


class ModelOutput(BaseModel):
foo_number: int = "42"
foo_string: Optional[str] = None


class Predictor(BasePredictor):
def predict(
self,
) -> Union[ModelOutput, Optional[str]]:
pass
96 changes: 75 additions & 21 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def test_openapi_specification(client):
"in": "header",
"name": "prefer",
"required": False,
"schema": {"title": "Prefer", "type": "string"},
"schema": {
"title": "Prefer",
"type": "string",
},
}
],
"requestBody": {
Expand Down Expand Up @@ -164,45 +167,59 @@ def test_openapi_specification(client):
],
"type": "object",
"properties": {
"no_default": {
"title": "No Default",
"type": "string",
"x-order": 0,
"choices": {
"allOf": [
{
"$ref": "#/components/schemas/choices",
}
],
"x-order": 5,
},
"default_without_input": {
"title": "Default Without Input",
"type": "string",
"default": "default",
"x-order": 1,
},
"image": {
"title": "Image",
"description": "Some path",
"type": "string",
"format": "uri",
"x-order": 4,
},
"input_with_default": {
"title": "Input With Default",
"type": "integer",
"default": -10,
"x-order": 2,
},
"int_choices": {
"allOf": [
{
"$ref": "#/components/schemas/int_choices",
}
],
"x-order": 6,
},
"no_default": {
"title": "No Default",
"type": "string",
"x-order": 0,
},
"optional_str": {
"nullable": True,
"title": "Optional Str",
"type": "string",
"x-order": 7,
},
"path": {
"title": "Path",
"description": "Some path",
"type": "string",
"format": "uri",
"x-order": 3,
},
"image": {
"title": "Image",
"description": "Some path",
"type": "string",
"format": "uri",
"x-order": 4,
},
"choices": {
"allOf": [{"$ref": "#/components/schemas/choices"}],
"x-order": 5,
},
"int_choices": {
"allOf": [{"$ref": "#/components/schemas/int_choices"}],
"x-order": 6,
},
},
}
assert schema["components"]["schemas"]["Output"] == {
Expand Down Expand Up @@ -253,14 +270,51 @@ def test_openapi_specification_with_custom_user_defined_output_type(
}


@uses_predictor("openapi_optional_output_type")
def test_openapi_specification_with_optional_output_type(
client,
):
resp = client.get("/openapi.json")
assert resp.status_code == 200
schema = resp.json()
assert schema["components"]["schemas"]["Output"] == {
"anyOf": [
{
"$ref": "#/components/schemas/ModelOutput",
},
{
"type": "string",
},
],
"nullable": True,
"title": "Output",
}
assert schema["components"]["schemas"]["ModelOutput"] == {
"properties": {
"foo_number": {
"default": "42",
"title": "Foo Number",
"type": "integer",
},
"foo_string": {
"title": "Foo String",
"type": "string",
"nullable": True,
},
},
"type": "object",
"title": "ModelOutput",
}


@uses_predictor("openapi_output_type")
def test_openapi_specification_with_custom_user_defined_output_type_called_output(
client,
):
resp = client.get("/openapi.json")
assert resp.status_code == 200
schema = resp.json()
assert resp.json()["components"]["schemas"]["Output"] == {
assert schema["components"]["schemas"]["Output"] == {
"properties": {
"foo_number": {"default": "42", "title": "Foo Number", "type": "integer"},
"foo_string": {
Expand Down