Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 294 additions & 16 deletions python/cog/server/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,28 @@
import threading
import uuid
from types import TracebackType
from typing import Any, BinaryIO, Callable, Dict, List, Sequence, TextIO, Union
from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Sequence,
Set,
TextIO,
Type,
Union,
get_args,
get_origin,
)

import pydantic
from fastapi import FastAPI
from fastapi.routing import APIRoute
from pydantic import BaseModel
from typing_extensions import Self # added to typing in python 3.11

from ..predictor import is_optional
from ..types import PYDANTIC_V2
from .errors import CogRuntimeError, CogTimeoutError

Expand Down Expand Up @@ -361,12 +378,268 @@ def unwrap_pydantic_serialization_iterators(obj: Any) -> Any:
return [unwrap_pydantic_serialization_iterators(value) for value in obj]
return obj

else:

def update_nullable_optional(openapi_schema: Dict[str, Any], app: FastAPI) -> None:
def patch_nullable_parameters(openapi_schema: Dict[str, Any]) -> None:
for _, methods in openapi_schema.get("paths", {}).items():
for _, operation in methods.items():
for param in operation.get("parameters", []):
# If the parameter is optional (required: false), make it nullable
if not param.get("required", True):
schema = param.get("schema", {})
if "nullable" not in schema:
schema["nullable"] = True

def patch_nullable_union_outputs(openapi_schema: Dict[str, Any]) -> None:
for _, schema in (
openapi_schema.get("components", {}).get("schemas", {}).items()
):
# Look for anyOf with more than one entry
if (
"anyOf" in schema
and isinstance(schema["anyOf"], list)
and len(schema["anyOf"]) > 1
):
# If it's missing nullable, and it's meant to represent an Optional/Union output
if "nullable" not in schema:
schema["nullable"] = True

def is_pydantic_model_type(tp) -> bool:
try:
return isinstance(tp, type) and issubclass(tp, BaseModel)
except TypeError:
return False

def extract_nullable_fields_recursive(
model: BaseModel, prefix: str = "", is_output: bool = False
) -> Dict[str, bool]:
nullable_map = {}
for field_name, field in model.__fields__.items():
full_field_name = f"{prefix}.{field_name}" if prefix else field_name
type_hint = field.annotation

if is_optional(type_hint) and (
full_field_name.startswith("input.") or is_output
):
nullable_map[full_field_name] = True

inner_type = (
get_args(type_hint)[0] if is_optional(type_hint) else type_hint
)
if is_pydantic_model_type(inner_type):
nested = extract_nullable_fields_recursive(
inner_type, prefix=full_field_name, is_output=is_output
)
nullable_map.update(nested)
return nullable_map

def resolve_schema_ref(
ref: str, openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
parts = ref.lstrip("#/").split("/")
node = openapi_schema
for part in parts:
node = node.get(part, {})
return node

def patch_nullable_fields_for_model(
model: BaseModel,
schema: Dict[str, Any],
openapi_schema: Dict[str, Any],
is_output: bool = False,
) -> None:
nullable_fields = extract_nullable_fields_recursive(
model, is_output=is_output
)

for field_path in nullable_fields:
parts = field_path.split(".")
node = schema

for i, part in enumerate(parts):
if "properties" not in node:
break

prop = node["properties"].get(part)
if prop is None:
break

# Handle nested $ref
if "$ref" in prop:
node = resolve_schema_ref(prop["$ref"], openapi_schema)
else:
node = prop

if i == len(parts) - 1:
node["nullable"] = True

def extract_pydantic_models_from_type(tp) -> Set[Type[BaseModel]]:
"""Recursively extract all Pydantic models from a response_model type."""
models = set()

origin = get_origin(tp)
args = get_args(tp)

if origin is Union or origin is list or origin is List:
for arg in args:
models.update(extract_pydantic_models_from_type(arg))
elif isinstance(tp, type) and issubclass(tp, BaseModel):
models.add(tp)

return models

def collect_nested_models_from_pydantic_model(
model: Type[BaseModel], visited=None
) -> Set[Type[BaseModel]]:
"""Recursively collect all nested Pydantic models inside a given model."""
if visited is None:
visited = set()

if model in visited:
return set()
visited.add(model)

models = {model}
for field in model.__fields__.values():
field_type = field.annotation
origin = get_origin(field_type)

if origin is Union:
args = get_args(field_type)
else:
args = [field_type]

for arg in args:
if is_pydantic_model_type(arg):
models.update(
collect_nested_models_from_pydantic_model(arg, visited)
)

return models

for route in app.routes:
if not isinstance(route, APIRoute):
continue

for method in route.methods:
method = method.lower()
operation = openapi_schema["paths"].get(route.path, {}).get(method, {})
if not operation:
continue

response_model = getattr(route, "response_model", None)
if response_model:
for model in extract_pydantic_models_from_type(response_model):
ref_name = model.__name__
schema_node = (
openapi_schema.get("components", {})
.get("schemas", {})
.get(ref_name)
)
if schema_node:
patch_nullable_fields_for_model(
model,
schema_node,
openapi_schema,
)
# Also patch any properties that reference other models
properties = schema_node.get("properties", {})

all_models = collect_nested_models_from_pydantic_model(
model
)

for field_schema in properties.values():
if "$ref" in field_schema:
ref = field_schema["$ref"]
nested_model_name = ref.split("/")[-1]
nested_model = next(
(
m
for m in all_models
if m.__name__ == nested_model_name
),
None,
)
if nested_model:
nested_schema_node = resolve_schema_ref(
ref, openapi_schema
)
if "anyOf" in nested_schema_node:
for item in nested_schema_node["anyOf"]:
if "$ref" in item:
inner_ref = item["$ref"]
inner_model_name = inner_ref.split(
"/"
)[-1]

inner_model = next(
(
m
for m in all_models
if m.__name__
== inner_model_name
),
None,
)
if inner_model:
actual_schema = (
resolve_schema_ref(
inner_ref,
openapi_schema,
)
)
patch_nullable_fields_for_model(
inner_model,
actual_schema,
openapi_schema,
is_output=True,
)
patch_nullable_fields_for_model(
nested_model,
nested_schema_node,
openapi_schema,
)

request_body = (
operation.get("requestBody", {})
.get("content", {})
.get("application/json", {})
)
schema = request_body.get("schema", {})

for dep in route.dependant.body_params:
model = getattr(dep, "type_", None)
if not model or not issubclass(model, BaseModel):
continue

# Resolve schema node for this model
if "$ref" in schema:
schema_node = resolve_schema_ref(schema["$ref"], openapi_schema)
elif "allOf" in schema:
for item in schema["allOf"]:
if "$ref" in item:
schema_node = resolve_schema_ref(
item["$ref"], openapi_schema
)
break
else:
schema_node = schema
else:
schema_node = schema

patch_nullable_fields_for_model(model, schema_node, openapi_schema)

patch_nullable_parameters(openapi_schema)
patch_nullable_union_outputs(openapi_schema)


def update_openapi_schema_for_pydantic_2(
openapi_schema: Dict[str, Any],
) -> None:
_remove_webhook_events_filter_title(openapi_schema)
_remove_empty_or_nullable_anyof(openapi_schema)
_update_nullable_anyof(openapi_schema)
_flatten_selected_allof_refs(openapi_schema)
_extract_enum_properties(openapi_schema)
_set_default_enumeration_description(openapi_schema)
Expand All @@ -384,27 +657,32 @@ def _remove_webhook_events_filter_title(
pass


def _remove_empty_or_nullable_anyof(
def _update_nullable_anyof(
openapi_schema: Union[Dict[str, Any], List[Dict[str, Any]]],
) -> None:
# Version 3.0.X of OpenAPI doesn't support a `null` type, expecting
# `nullable` to be set instead.
if isinstance(openapi_schema, dict):
for key, value in list(openapi_schema.items()):
if key == "anyOf" and isinstance(value, list):
non_null_types = [item for item in value if item.get("type") != "null"]
if len(non_null_types) == 0:
del openapi_schema[key]
elif len(non_null_types) == 1:
openapi_schema.update(non_null_types[0])
del openapi_schema[key]
if key != "anyOf" or not isinstance(value, list):
_update_nullable_anyof(value)
continue

non_null_items = [item for item in value if item.get("type") != "null"]
if len(non_null_items) == 0:
del openapi_schema[key]
elif len(non_null_items) == 1:
openapi_schema.update(non_null_items[0])
del openapi_schema[key]
else:
openapi_schema[key] = non_null_items

# FIXME: Update tests to expect nullable
# openapi_schema["nullable"] = True
if len(non_null_items) < len(value):
openapi_schema["nullable"] = True

else:
_remove_empty_or_nullable_anyof(value)
elif isinstance(openapi_schema, list): # pyright: ignore
elif isinstance(openapi_schema, list):
for item in openapi_schema:
_remove_empty_or_nullable_anyof(item)
_update_nullable_anyof(item)


def _flatten_selected_allof_refs(
Expand Down
4 changes: 4 additions & 0 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
unwrap_pydantic_serialization_iterators,
update_openapi_schema_for_pydantic_2,
)
else:
from .helpers import update_nullable_optional

from .probes import ProbeHelper
from .runner import (
Expand Down Expand Up @@ -136,6 +138,8 @@ def custom_openapi() -> Dict[str, Any]:
# See: https://github.com/tiangolo/fastapi/pull/9873#issuecomment-1997105091
if PYDANTIC_V2:
update_openapi_schema_for_pydantic_2(openapi_schema)
else:
update_nullable_optional(openapi_schema, app)

app.openapi_schema = openapi_schema

Expand Down
4 changes: 3 additions & 1 deletion python/tests/server/fixtures/openapi_complex_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cog import BasePredictor, File, Input, Path
from typing import Optional

from cog import BaseModel, BasePredictor, File, Input, Path

class Predictor(BasePredictor):
def predict(
Expand All @@ -11,5 +12,6 @@ def predict(
image: File = Input(description="Some path"),
choices: str = Input(choices=["foo", "bar"]),
int_choices: int = Input(choices=[3, 4, 5]),
optional_str: Optional[str] = Input(),
) -> str:
pass
16 changes: 16 additions & 0 deletions python/tests/server/fixtures/openapi_optional_output_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional, Union

from cog import BasePredictor
from pydantic import BaseModel


class ModelOutput(BaseModel):
foo_number: int = "42"
foo_string: Optional[str] = None


class Predictor(BasePredictor):
def predict(
self,
) -> Union[ModelOutput, Optional[str]]:
pass
Loading