Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/prefect/_internal/pydantic/validated_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,20 @@ def __init__(
)

# Build the validation model
fields, takes_args, takes_kwargs = self._build_fields()
self._create_model(fields, takes_args, takes_kwargs, config)
fields, takes_args, takes_kwargs, has_forward_refs = self._build_fields()
self._create_model(fields, takes_args, takes_kwargs, config, has_forward_refs)

def _build_fields(self) -> tuple[dict[str, Any], bool, bool]:
def _build_fields(self) -> tuple[dict[str, Any], bool, bool, bool]:
"""
Build field definitions from function signature.

Returns:
Tuple of (fields_dict, takes_args, takes_kwargs)
Tuple of (fields_dict, takes_args, takes_kwargs, has_forward_refs)
"""
fields: dict[str, Any] = {}
takes_args = False
takes_kwargs = False
has_forward_refs = False
position = 0

for param_name, param in self.signature.parameters.items():
Expand Down Expand Up @@ -131,6 +132,10 @@ def _build_fields(self) -> tuple[dict[str, Any], bool, bool]:
param.annotation if param.annotation != inspect.Parameter.empty else Any
)

# Check if annotation is a string (forward reference)
if isinstance(annotation, str):
has_forward_refs = True

if param.default == inspect.Parameter.empty:
# Required parameter
fields[param_name] = (annotation, Field(...))
Expand All @@ -146,14 +151,15 @@ def _build_fields(self) -> tuple[dict[str, Any], bool, bool]:
fields[V_POSITIONAL_ONLY_NAME] = (Optional[list[str]], Field(default=None))
fields[V_DUPLICATE_KWARGS] = (Optional[list[str]], Field(default=None))

return fields, takes_args, takes_kwargs
return fields, takes_args, takes_kwargs, has_forward_refs

def _create_model(
self,
fields: dict[str, Any],
takes_args: bool,
takes_kwargs: bool,
config: ConfigDict | None,
has_forward_refs: bool,
) -> None:
"""Create the Pydantic validation model."""
pos_args = len(self.arg_mapping)
Expand Down Expand Up @@ -226,6 +232,13 @@ def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
**fields,
)

# Rebuild the model with the original function's namespace to resolve forward references
# This is necessary when using `from __future__ import annotations` or when
# parameters reference types not in the current scope
# Only rebuild if we detected forward references to avoid performance overhead
if has_forward_refs:
self.model.model_rebuild(_types_namespace=self.raw_function.__globals__)

def validate_call_args(
self, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> dict[str, Any]:
Expand Down
109 changes: 109 additions & 0 deletions tests/_internal/pydantic/test_validated_func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the pure Pydantic v2 validated function implementation."""

from unittest.mock import patch

import pytest
from pydantic import BaseModel, ValidationError

Expand Down Expand Up @@ -411,3 +413,110 @@ def func(v__duplicate_kwargs):
ValueError, match="Function parameters conflict with internal field names"
):
ValidatedFunction(func)


class TestForwardReferences:
"""Test handling of forward references and `from __future__ import annotations`."""

def test_pydantic_model_with_future_annotations(self):
"""Test that Pydantic models work with forward reference annotations.

This is a regression test for issue #19288.
When using `from __future__ import annotations`, type hints become strings
and need to be resolved via model_rebuild() with the proper namespace.
"""
# Define a test module namespace that simulates using future annotations
namespace = {}

# Create a model in that namespace
exec(
"""
from pydantic import BaseModel, Field

class TestModel(BaseModel):
name: str = Field(..., description="Test name")
value: int = 42
""",
namespace,
)

TestModel = namespace["TestModel"]

# Define a function with the model as a parameter using string annotation
# This simulates what happens with `from __future__ import annotations`
def process_model(model: "TestModel") -> dict: # noqa: F821
return {"name": model.name, "value": model.value}

# Update the function's globals to include the TestModel
process_model.__globals__.update(namespace)

# Create validated function
vf = ValidatedFunction(process_model)

# Create an instance of the model
test_instance = TestModel(name="test")

# This should work without raising PydanticUserError about undefined models
result = vf.validate_call_args((test_instance,), {})

assert isinstance(result["model"], TestModel)
assert result["model"].name == "test"
assert result["model"].value == 42

def test_nested_pydantic_models_with_forward_refs(self):
"""Test nested Pydantic models with forward references work correctly."""

class Inner(BaseModel):
value: int

class Outer(BaseModel):
inner: Inner
name: str

# Simulate forward reference by using string annotation
def process_nested(data: "Outer") -> str: # noqa: F821
return data.name

# Add the types to the function's globals
process_nested.__globals__["Outer"] = Outer
process_nested.__globals__["Inner"] = Inner

vf = ValidatedFunction(process_nested)

# Create nested structure
outer_instance = Outer(inner=Inner(value=42), name="test")

result = vf.validate_call_args((outer_instance,), {})

assert isinstance(result["data"], Outer)
assert result["data"].name == "test"
assert result["data"].inner.value == 42

def test_no_rebuild_without_forward_refs(self):
"""Test that model_rebuild is not called when there are no forward references.

This is a performance optimization test - we should avoid the overhead
of model_rebuild() when it's not necessary.
"""

class MyModel(BaseModel):
name: str

# Function with concrete type annotations (no forward refs)
def process_data(model: MyModel, count: int = 0) -> dict:
return {"name": model.name, "count": count}

# Spy on model_rebuild to ensure it's NOT called
with patch.object(BaseModel, "model_rebuild") as mock_rebuild:
vf = ValidatedFunction(process_data)

# model_rebuild should NOT have been called since there are no forward refs
mock_rebuild.assert_not_called()

# The model should work correctly without rebuild
instance = MyModel(name="test")
result = vf.validate_call_args((instance,), {"count": 5})

assert isinstance(result["model"], MyModel)
assert result["model"].name == "test"
assert result["count"] == 5