Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions python/cog/server/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def update_openapi_schema_for_pydantic_2(
openapi_schema: Dict[str, Any],
) -> None:
_remove_webhook_events_filter_title(openapi_schema)
_remove_empty_or_nullable_anyof(openapi_schema)
_update_nullable_anyof(openapi_schema)
_flatten_selected_allof_refs(openapi_schema)
_extract_enum_properties(openapi_schema)
_set_default_enumeration_description(openapi_schema)
Expand All @@ -384,27 +384,32 @@ def _remove_webhook_events_filter_title(
pass


def _remove_empty_or_nullable_anyof(
def _update_nullable_anyof(
openapi_schema: Union[Dict[str, Any], List[Dict[str, Any]]],
) -> None:
# Version 3.0.X of OpenAPI doesn't support a `null` type, expecting
# `nullable` to be set instead.
if isinstance(openapi_schema, dict):
for key, value in list(openapi_schema.items()):
if key == "anyOf" and isinstance(value, list):
non_null_types = [item for item in value if item.get("type") != "null"]
if len(non_null_types) == 0:
del openapi_schema[key]
elif len(non_null_types) == 1:
openapi_schema.update(non_null_types[0])
del openapi_schema[key]
if key != "anyOf" or not isinstance(value, list):
_update_nullable_anyof(value)
continue

non_null_items = [item for item in value if item.get("type") != "null"]
if len(non_null_items) == 0:
del openapi_schema[key]
elif len(non_null_items) == 1:
openapi_schema.update(non_null_items[0])
del openapi_schema[key]
else:
openapi_schema[key] = non_null_items

# FIXME: Update tests to expect nullable
# openapi_schema["nullable"] = True
if len(non_null_items) < len(value):
openapi_schema["nullable"] = True

else:
_remove_empty_or_nullable_anyof(value)
elif isinstance(openapi_schema, list): # pyright: ignore
elif isinstance(openapi_schema, list):
for item in openapi_schema:
_remove_empty_or_nullable_anyof(item)
_update_nullable_anyof(item)


def _flatten_selected_allof_refs(
Expand Down
6 changes: 6 additions & 0 deletions python/tests/server/fixtures/input_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor, Seed


class Predictor(BasePredictor):
def predict(self, seed: Seed) -> int:
return seed
4 changes: 3 additions & 1 deletion python/tests/server/fixtures/openapi_complex_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cog import BasePredictor, File, Input, Path
from typing import Optional

from cog import BaseModel, BasePredictor, File, Input, Path

class Predictor(BasePredictor):
def predict(
Expand All @@ -11,5 +12,6 @@ def predict(
image: File = Input(description="Some path"),
choices: str = Input(choices=["foo", "bar"]),
int_choices: int = Input(choices=[3, 4, 5]),
optional_str: Optional[str] = Input(),
) -> str:
pass
16 changes: 16 additions & 0 deletions python/tests/server/fixtures/openapi_optional_output_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Optional, Union

from cog import BasePredictor
from pydantic import BaseModel


class ModelOutput(BaseModel):
foo_number: int = "42"
foo_string: Optional[str] = None


class Predictor(BasePredictor):
def predict(
self,
) -> Union[ModelOutput, Optional[str]]:
pass
Loading