diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 6c73530c8b..f86753420f 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -183,6 +183,12 @@ def load_predictor_from_ref(ref: str) -> BasePredictor: return predictor +def is_none(type_arg: Any) -> bool: + if sys.version_info >= (3, 10): + return type_arg is NoneType + return type_arg is None.__class__ + + def is_union(type: Type[Any]) -> bool: if get_origin(type) is Union: return True @@ -195,9 +201,7 @@ def is_optional(type: Type[Any]) -> bool: args = get_args(type) if len(args) != 2 or not is_union(type): return False - if sys.version_info >= (3, 10): - return args[1] is NoneType - return args[1] is None.__class__ + return is_none(args[1]) def validate_input_type( diff --git a/python/cog/server/helpers.py b/python/cog/server/helpers.py index 990fa49cb0..c7af410aa2 100644 --- a/python/cog/server/helpers.py +++ b/python/cog/server/helpers.py @@ -10,11 +10,26 @@ import threading import uuid from types import TracebackType -from typing import Any, BinaryIO, Callable, Dict, List, Sequence, TextIO, Union +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + List, + Sequence, + TextIO, + Union, + get_args, + get_type_hints, +) import pydantic +from fastapi import FastAPI +from fastapi.routing import APIRoute +from pydantic import BaseModel from typing_extensions import Self # added to typing in python 3.11 +from ..predictor import is_none, is_optional from ..types import PYDANTIC_V2 from .errors import CogRuntimeError, CogTimeoutError @@ -361,12 +376,87 @@ def unwrap_pydantic_serialization_iterators(obj: Any) -> Any: return [unwrap_pydantic_serialization_iterators(value) for value in obj] return obj +else: + + def get_annotations(tp) -> dict[str, Any]: + if sys.version_info >= (3, 10): + return get_type_hints(tp) + return tp.__annotations__ + + def is_pydantic_model_type(tp) -> bool: + try: + return isinstance(tp, type) and issubclass(tp, BaseModel) + except TypeError: + return False + + def update_nullable_optional(openapi_schema: Dict[str, Any], app: FastAPI) -> None: + def fetch_referenced_schema(schema: Dict[str, Any], ref: str) -> Dict[str, Any]: + input_path = ref.replace("#/", "").split("/") + referenced_schema = schema + while input_path: + referenced_schema = referenced_schema[input_path[0]] + input_path = input_path[1:] + return referenced_schema + + for route in app.routes: + if not isinstance(route, APIRoute): + continue + + for dep in route.dependant.body_params: + model = getattr(dep, "type_", None) + if not is_pydantic_model_type(model): + continue + input_model_union = get_annotations(model).get("input") + if input_model_union is None: + continue + input_model = get_args(input_model_union)[0] + schema_node = openapi_schema["components"]["schemas"].get( + model.__name__ + ) + referenced_schema = fetch_referenced_schema( + openapi_schema, schema_node["properties"]["input"]["$ref"] + ) + for k, v in referenced_schema["properties"].items(): + annotated_type = get_annotations(input_model)[k] + if is_optional(annotated_type): + v["nullable"] = True + + response_model = getattr(route, "response_model", None) + if is_pydantic_model_type(response_model): + output_model_union = get_annotations(response_model).get("output") + if output_model_union is None: + continue + output_model = get_args(output_model_union)[0] + schema_node = openapi_schema["components"]["schemas"].get( + output_model.__name__ + ) + root = get_annotations(output_model).get("__root__") + for type_arg in get_args(root): + if not is_none(type_arg): + continue + schema_node["nullable"] = True + break + for count, type_node in enumerate(schema_node.get("anyOf", [])): + ref_node = type_node.get("$ref") + if ref_node is None: + continue + referenced_schema = fetch_referenced_schema( + openapi_schema, ref_node + ) + output_model = get_args(root)[count] + for k, v in referenced_schema["properties"].items(): + annotated_type = get_annotations(output_model)[k] + if is_optional(annotated_type): + v["nullable"] = True + + return None + def update_openapi_schema_for_pydantic_2( openapi_schema: Dict[str, Any], ) -> None: _remove_webhook_events_filter_title(openapi_schema) - _remove_empty_or_nullable_anyof(openapi_schema) + _update_nullable_anyof(openapi_schema) _flatten_selected_allof_refs(openapi_schema) _extract_enum_properties(openapi_schema) _set_default_enumeration_description(openapi_schema) @@ -384,27 +474,36 @@ def _remove_webhook_events_filter_title( pass -def _remove_empty_or_nullable_anyof( +def _update_nullable_anyof( openapi_schema: Union[Dict[str, Any], List[Dict[str, Any]]], + in_header: Union[bool, None] = None, ) -> None: + # Version 3.0.X of OpenAPI doesn't support a `null` type, expecting + # `nullable` to be set instead. if isinstance(openapi_schema, dict): + if in_header is None: + if "in" in openapi_schema: + in_header = openapi_schema.get("in") == "header" for key, value in list(openapi_schema.items()): - if key == "anyOf" and isinstance(value, list): - non_null_types = [item for item in value if item.get("type") != "null"] - if len(non_null_types) == 0: - del openapi_schema[key] - elif len(non_null_types) == 1: - openapi_schema.update(non_null_types[0]) - del openapi_schema[key] + if key != "anyOf" or not isinstance(value, list): + _update_nullable_anyof(value, in_header=in_header) + continue + + non_null_items = [item for item in value if item.get("type") != "null"] + if len(non_null_items) == 0: + del openapi_schema[key] + elif len(non_null_items) == 1: + openapi_schema.update(non_null_items[0]) + del openapi_schema[key] + else: + openapi_schema[key] = non_null_items - # FIXME: Update tests to expect nullable - # openapi_schema["nullable"] = True + if len(non_null_items) < len(value) and not in_header: + openapi_schema["nullable"] = True - else: - _remove_empty_or_nullable_anyof(value) - elif isinstance(openapi_schema, list): # pyright: ignore + elif isinstance(openapi_schema, list): for item in openapi_schema: - _remove_empty_or_nullable_anyof(item) + _update_nullable_anyof(item, in_header=in_header) def _flatten_selected_allof_refs( diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 51f5e972b8..65d583c0f5 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -41,6 +41,8 @@ unwrap_pydantic_serialization_iterators, update_openapi_schema_for_pydantic_2, ) +else: + from .helpers import update_nullable_optional from .probes import ProbeHelper from .runner import ( @@ -136,6 +138,8 @@ def custom_openapi() -> Dict[str, Any]: # See: https://github.com/tiangolo/fastapi/pull/9873#issuecomment-1997105091 if PYDANTIC_V2: update_openapi_schema_for_pydantic_2(openapi_schema) + else: + update_nullable_optional(openapi_schema, app) app.openapi_schema = openapi_schema diff --git a/python/tests/server/fixtures/openapi_complex_input.py b/python/tests/server/fixtures/openapi_complex_input.py index 98b218f935..ab3d5b310e 100644 --- a/python/tests/server/fixtures/openapi_complex_input.py +++ b/python/tests/server/fixtures/openapi_complex_input.py @@ -1,5 +1,6 @@ -from cog import BasePredictor, File, Input, Path +from typing import Optional +from cog import BaseModel, BasePredictor, File, Input, Path class Predictor(BasePredictor): def predict( @@ -11,5 +12,6 @@ def predict( image: File = Input(description="Some path"), choices: str = Input(choices=["foo", "bar"]), int_choices: int = Input(choices=[3, 4, 5]), + optional_str: Optional[str] = Input(), ) -> str: pass diff --git a/python/tests/server/fixtures/openapi_optional_output_type.py b/python/tests/server/fixtures/openapi_optional_output_type.py new file mode 100644 index 0000000000..3e22b1f461 --- /dev/null +++ b/python/tests/server/fixtures/openapi_optional_output_type.py @@ -0,0 +1,16 @@ +from typing import Optional, Union + +from cog import BasePredictor +from pydantic import BaseModel + + +class ModelOutput(BaseModel): + foo_number: int = "42" + foo_string: Optional[str] = None + + +class Predictor(BasePredictor): + def predict( + self, + ) -> Union[ModelOutput, Optional[str]]: + pass diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 6321cecd0e..34ab304538 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -88,7 +88,10 @@ def test_openapi_specification(client): "in": "header", "name": "prefer", "required": False, - "schema": {"title": "Prefer", "type": "string"}, + "schema": { + "title": "Prefer", + "type": "string", + }, } ], "requestBody": { @@ -164,10 +167,13 @@ def test_openapi_specification(client): ], "type": "object", "properties": { - "no_default": { - "title": "No Default", - "type": "string", - "x-order": 0, + "choices": { + "allOf": [ + { + "$ref": "#/components/schemas/choices", + } + ], + "x-order": 5, }, "default_without_input": { "title": "Default Without Input", @@ -175,12 +181,38 @@ def test_openapi_specification(client): "default": "default", "x-order": 1, }, + "image": { + "title": "Image", + "description": "Some path", + "type": "string", + "format": "uri", + "x-order": 4, + }, "input_with_default": { "title": "Input With Default", "type": "integer", "default": -10, "x-order": 2, }, + "int_choices": { + "allOf": [ + { + "$ref": "#/components/schemas/int_choices", + } + ], + "x-order": 6, + }, + "no_default": { + "title": "No Default", + "type": "string", + "x-order": 0, + }, + "optional_str": { + "nullable": True, + "title": "Optional Str", + "type": "string", + "x-order": 7, + }, "path": { "title": "Path", "description": "Some path", @@ -188,21 +220,6 @@ def test_openapi_specification(client): "format": "uri", "x-order": 3, }, - "image": { - "title": "Image", - "description": "Some path", - "type": "string", - "format": "uri", - "x-order": 4, - }, - "choices": { - "allOf": [{"$ref": "#/components/schemas/choices"}], - "x-order": 5, - }, - "int_choices": { - "allOf": [{"$ref": "#/components/schemas/int_choices"}], - "x-order": 6, - }, }, } assert schema["components"]["schemas"]["Output"] == { @@ -253,6 +270,43 @@ def test_openapi_specification_with_custom_user_defined_output_type( } +@uses_predictor("openapi_optional_output_type") +def test_openapi_specification_with_optional_output_type( + client, +): + resp = client.get("/openapi.json") + assert resp.status_code == 200 + schema = resp.json() + assert schema["components"]["schemas"]["Output"] == { + "anyOf": [ + { + "$ref": "#/components/schemas/ModelOutput", + }, + { + "type": "string", + }, + ], + "nullable": True, + "title": "Output", + } + assert schema["components"]["schemas"]["ModelOutput"] == { + "properties": { + "foo_number": { + "default": "42", + "title": "Foo Number", + "type": "integer", + }, + "foo_string": { + "title": "Foo String", + "type": "string", + "nullable": True, + }, + }, + "type": "object", + "title": "ModelOutput", + } + + @uses_predictor("openapi_output_type") def test_openapi_specification_with_custom_user_defined_output_type_called_output( client, @@ -260,7 +314,7 @@ def test_openapi_specification_with_custom_user_defined_output_type_called_outpu resp = client.get("/openapi.json") assert resp.status_code == 200 schema = resp.json() - assert resp.json()["components"]["schemas"]["Output"] == { + assert schema["components"]["schemas"]["Output"] == { "properties": { "foo_number": {"default": "42", "title": "Foo Number", "type": "integer"}, "foo_string": {