From 99d03331a05b1eb74fea1b2ff446671dc767e83f Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 27 Oct 2025 09:04:36 -0500 Subject: [PATCH 1/4] Fix forward reference resolution in ValidatedFunction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #19288 This PR fixes an issue where Flow.validate_parameters() fails when using Pydantic models as parameters with `from __future__ import annotations`. The problem occurred because the dynamically created validation model didn't have access to the original function's namespace to resolve forward references. The fix adds model_rebuild() with the function's global namespace, allowing proper resolution of forward references to Pydantic models. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../_internal/pydantic/validated_func.py | 5 ++ .../_internal/pydantic/test_validated_func.py | 78 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/src/prefect/_internal/pydantic/validated_func.py b/src/prefect/_internal/pydantic/validated_func.py index 03e8ec195197..445b0ad6ca95 100644 --- a/src/prefect/_internal/pydantic/validated_func.py +++ b/src/prefect/_internal/pydantic/validated_func.py @@ -226,6 +226,11 @@ def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None: **fields, ) + # Rebuild the model with the original function's namespace to resolve forward references + # This is necessary when using `from __future__ import annotations` or when + # parameters reference types not in the current scope + self.model.model_rebuild(_types_namespace=self.raw_function.__globals__) + def validate_call_args( self, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> dict[str, Any]: diff --git a/tests/_internal/pydantic/test_validated_func.py b/tests/_internal/pydantic/test_validated_func.py index 6a002d8fd758..a4fa2710a365 100644 --- a/tests/_internal/pydantic/test_validated_func.py +++ b/tests/_internal/pydantic/test_validated_func.py @@ -411,3 +411,81 @@ def func(v__duplicate_kwargs): ValueError, match="Function parameters conflict with internal field names" ): ValidatedFunction(func) + + +class TestForwardReferences: + """Test handling of forward references and `from __future__ import annotations`.""" + + def test_pydantic_model_with_future_annotations(self): + """Test that Pydantic models work with forward reference annotations. + + This is a regression test for issue #19288. + When using `from __future__ import annotations`, type hints become strings + and need to be resolved via model_rebuild() with the proper namespace. + """ + # Define a test module namespace that simulates using future annotations + namespace = {} + + # Create a model in that namespace + exec( + """ +from pydantic import BaseModel, Field + +class TestModel(BaseModel): + name: str = Field(..., description="Test name") + value: int = 42 +""", + namespace, + ) + + TestModel = namespace["TestModel"] + + # Define a function with the model as a parameter using string annotation + # This simulates what happens with `from __future__ import annotations` + def process_model(model: "TestModel") -> dict: # noqa: F821 + return {"name": model.name, "value": model.value} + + # Update the function's globals to include the TestModel + process_model.__globals__.update(namespace) + + # Create validated function + vf = ValidatedFunction(process_model) + + # Create an instance of the model + test_instance = TestModel(name="test") + + # This should work without raising PydanticUserError about undefined models + result = vf.validate_call_args((test_instance,), {}) + + assert isinstance(result["model"], TestModel) + assert result["model"].name == "test" + assert result["model"].value == 42 + + def test_nested_pydantic_models_with_forward_refs(self): + """Test nested Pydantic models with forward references work correctly.""" + + class Inner(BaseModel): + value: int + + class Outer(BaseModel): + inner: Inner + name: str + + # Simulate forward reference by using string annotation + def process_nested(data: "Outer") -> str: # noqa: F821 + return data.name + + # Add the types to the function's globals + process_nested.__globals__["Outer"] = Outer + process_nested.__globals__["Inner"] = Inner + + vf = ValidatedFunction(process_nested) + + # Create nested structure + outer_instance = Outer(inner=Inner(value=42), name="test") + + result = vf.validate_call_args((outer_instance,), {}) + + assert isinstance(result["data"], Outer) + assert result["data"].name == "test" + assert result["data"].inner.value == 42 From 9427aa2cf4324c09092901783ef59a31fb6ba89e Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 27 Oct 2025 09:19:53 -0500 Subject: [PATCH 2/4] Optimize model_rebuild to only run when forward refs detected MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds a performance optimization to avoid calling model_rebuild() when there are no forward references in the function signature. The optimization: - Detects string annotations during field building - Only calls model_rebuild() if forward references are found - Avoids unnecessary overhead for functions without forward refs 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../_internal/pydantic/validated_func.py | 20 +++++++++++----- .../_internal/pydantic/test_validated_func.py | 24 +++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/prefect/_internal/pydantic/validated_func.py b/src/prefect/_internal/pydantic/validated_func.py index 445b0ad6ca95..d41eaf20ff2d 100644 --- a/src/prefect/_internal/pydantic/validated_func.py +++ b/src/prefect/_internal/pydantic/validated_func.py @@ -90,19 +90,20 @@ def __init__( ) # Build the validation model - fields, takes_args, takes_kwargs = self._build_fields() - self._create_model(fields, takes_args, takes_kwargs, config) + fields, takes_args, takes_kwargs, has_forward_refs = self._build_fields() + self._create_model(fields, takes_args, takes_kwargs, config, has_forward_refs) - def _build_fields(self) -> tuple[dict[str, Any], bool, bool]: + def _build_fields(self) -> tuple[dict[str, Any], bool, bool, bool]: """ Build field definitions from function signature. Returns: - Tuple of (fields_dict, takes_args, takes_kwargs) + Tuple of (fields_dict, takes_args, takes_kwargs, has_forward_refs) """ fields: dict[str, Any] = {} takes_args = False takes_kwargs = False + has_forward_refs = False position = 0 for param_name, param in self.signature.parameters.items(): @@ -131,6 +132,10 @@ def _build_fields(self) -> tuple[dict[str, Any], bool, bool]: param.annotation if param.annotation != inspect.Parameter.empty else Any ) + # Check if annotation is a string (forward reference) + if isinstance(annotation, str): + has_forward_refs = True + if param.default == inspect.Parameter.empty: # Required parameter fields[param_name] = (annotation, Field(...)) @@ -146,7 +151,7 @@ def _build_fields(self) -> tuple[dict[str, Any], bool, bool]: fields[V_POSITIONAL_ONLY_NAME] = (Optional[list[str]], Field(default=None)) fields[V_DUPLICATE_KWARGS] = (Optional[list[str]], Field(default=None)) - return fields, takes_args, takes_kwargs + return fields, takes_args, takes_kwargs, has_forward_refs def _create_model( self, @@ -154,6 +159,7 @@ def _create_model( takes_args: bool, takes_kwargs: bool, config: ConfigDict | None, + has_forward_refs: bool, ) -> None: """Create the Pydantic validation model.""" pos_args = len(self.arg_mapping) @@ -229,7 +235,9 @@ def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None: # Rebuild the model with the original function's namespace to resolve forward references # This is necessary when using `from __future__ import annotations` or when # parameters reference types not in the current scope - self.model.model_rebuild(_types_namespace=self.raw_function.__globals__) + # Only rebuild if we detected forward references to avoid performance overhead + if has_forward_refs: + self.model.model_rebuild(_types_namespace=self.raw_function.__globals__) def validate_call_args( self, args: tuple[Any, ...], kwargs: dict[str, Any] diff --git a/tests/_internal/pydantic/test_validated_func.py b/tests/_internal/pydantic/test_validated_func.py index a4fa2710a365..c748ac6a529c 100644 --- a/tests/_internal/pydantic/test_validated_func.py +++ b/tests/_internal/pydantic/test_validated_func.py @@ -489,3 +489,27 @@ def process_nested(data: "Outer") -> str: # noqa: F821 assert isinstance(result["data"], Outer) assert result["data"].name == "test" assert result["data"].inner.value == 42 + + def test_no_rebuild_without_forward_refs(self): + """Test that model_rebuild is not called when there are no forward references. + + This is a performance optimization test - we should avoid the overhead + of model_rebuild() when it's not necessary. + """ + + class MyModel(BaseModel): + name: str + + # Function with concrete type annotations (no forward refs) + def process_data(model: MyModel, count: int = 0) -> dict: + return {"name": model.name, "count": count} + + vf = ValidatedFunction(process_data) + + # The model should work correctly without rebuild + instance = MyModel(name="test") + result = vf.validate_call_args((instance,), {"count": 5}) + + assert isinstance(result["model"], MyModel) + assert result["model"].name == "test" + assert result["count"] == 5 From 13b9746a4f7c9e7a71db901c0caa1877963cb56a Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 27 Oct 2025 09:39:19 -0500 Subject: [PATCH 3/4] Add spy tests to verify model_rebuild optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improves test coverage by using unittest.mock to verify that: - model_rebuild() is NOT called when there are no forward references - model_rebuild() IS called with correct namespace when forward refs exist This ensures the performance optimization is actually working as intended. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../_internal/pydantic/test_validated_func.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/_internal/pydantic/test_validated_func.py b/tests/_internal/pydantic/test_validated_func.py index c748ac6a529c..461e1b2eb68e 100644 --- a/tests/_internal/pydantic/test_validated_func.py +++ b/tests/_internal/pydantic/test_validated_func.py @@ -1,5 +1,7 @@ """Tests for the pure Pydantic v2 validated function implementation.""" +from unittest.mock import patch + import pytest from pydantic import BaseModel, ValidationError @@ -504,7 +506,12 @@ class MyModel(BaseModel): def process_data(model: MyModel, count: int = 0) -> dict: return {"name": model.name, "count": count} - vf = ValidatedFunction(process_data) + # Spy on model_rebuild to ensure it's NOT called + with patch.object(BaseModel, "model_rebuild") as mock_rebuild: + vf = ValidatedFunction(process_data) + + # model_rebuild should NOT have been called since there are no forward refs + mock_rebuild.assert_not_called() # The model should work correctly without rebuild instance = MyModel(name="test") @@ -513,3 +520,38 @@ def process_data(model: MyModel, count: int = 0) -> dict: assert isinstance(result["model"], MyModel) assert result["model"].name == "test" assert result["count"] == 5 + + def test_rebuild_called_with_forward_refs(self): + """Test that model_rebuild IS called when forward references exist. + + This verifies that the optimization correctly detects forward refs + and calls model_rebuild when needed. + """ + + class MyModel(BaseModel): + name: str + + # Function with string annotation (forward reference) + def process_data(model: "MyModel", count: int = 0) -> dict: # noqa: F821 + return {"name": model.name, "count": count} + + # Add the type to the function's globals + process_data.__globals__["MyModel"] = MyModel + + # Spy on model_rebuild to ensure it IS called + with patch( + "prefect._internal.pydantic.validated_func.create_model" + ) as mock_create: + # Mock the created model + mock_model = type("MockModel", (BaseModel,), {}) + mock_create.return_value = mock_model + + with patch.object(mock_model, "model_rebuild") as mock_rebuild: + _vf = ValidatedFunction(process_data) + + # model_rebuild should have been called since there are forward refs + mock_rebuild.assert_called_once() + # Verify it was called with the function's globals + call_kwargs = mock_rebuild.call_args[1] + assert "_types_namespace" in call_kwargs + assert call_kwargs["_types_namespace"] is process_data.__globals__ From 1810a4e69bdd20a6a5d2a5c5f3004f21f1492b1c Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Mon, 27 Oct 2025 09:42:47 -0500 Subject: [PATCH 4/4] Remove useless test --- .../_internal/pydantic/test_validated_func.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/tests/_internal/pydantic/test_validated_func.py b/tests/_internal/pydantic/test_validated_func.py index 461e1b2eb68e..02230bd7ad64 100644 --- a/tests/_internal/pydantic/test_validated_func.py +++ b/tests/_internal/pydantic/test_validated_func.py @@ -520,38 +520,3 @@ def process_data(model: MyModel, count: int = 0) -> dict: assert isinstance(result["model"], MyModel) assert result["model"].name == "test" assert result["count"] == 5 - - def test_rebuild_called_with_forward_refs(self): - """Test that model_rebuild IS called when forward references exist. - - This verifies that the optimization correctly detects forward refs - and calls model_rebuild when needed. - """ - - class MyModel(BaseModel): - name: str - - # Function with string annotation (forward reference) - def process_data(model: "MyModel", count: int = 0) -> dict: # noqa: F821 - return {"name": model.name, "count": count} - - # Add the type to the function's globals - process_data.__globals__["MyModel"] = MyModel - - # Spy on model_rebuild to ensure it IS called - with patch( - "prefect._internal.pydantic.validated_func.create_model" - ) as mock_create: - # Mock the created model - mock_model = type("MockModel", (BaseModel,), {}) - mock_create.return_value = mock_model - - with patch.object(mock_model, "model_rebuild") as mock_rebuild: - _vf = ValidatedFunction(process_data) - - # model_rebuild should have been called since there are forward refs - mock_rebuild.assert_called_once() - # Verify it was called with the function's globals - call_kwargs = mock_rebuild.call_args[1] - assert "_types_namespace" in call_kwargs - assert call_kwargs["_types_namespace"] is process_data.__globals__