
Commit 0dad77b

Add pydantic model tests for MXCP integration
1 parent ae22058 commit 0dad77b

10 files changed: +2103 additions, -1561 deletions

src/mxcp/sdk/executor/plugins/python.py

Lines changed: 168 additions & 2 deletions
@@ -46,6 +46,7 @@
 import hashlib
 import inspect
 import logging
+import sys
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
@@ -555,12 +556,25 @@ async def _execute_function(
         try:
             from ..context import reset_execution_context, set_execution_context

+            # Convert parameters based on function signature
+            converted_params = self._convert_parameters(func, params)
+
+            # Handle validation errors
+            if isinstance(converted_params, dict) and "__validation_errors" in converted_params:
+                # Convert structured errors to string for backward compatibility
+                # TODO: In future, return structured errors directly for better UI
+                error_parts = []
+                for param_name, param_errors in converted_params["__validation_errors"].items():
+                    error_details = [f"{err['field']}: {err['message']}" for err in param_errors]
+                    error_parts.append(f"Invalid {param_name}: {', '.join(error_details)}")
+                return "; ".join(error_parts)
+
             # Check if function is async
             if asyncio.iscoroutinefunction(func):
                 # For async functions, set context and let contextvars propagate it
                 context_token = set_execution_context(context)
                 try:
-                    result = await func(**params)
+                    result = await func(**converted_params)
                 finally:
                     reset_execution_context(context_token)
             else:
@@ -570,7 +584,7 @@ def sync_function_wrapper() -> Any:

                 thread_token = set_execution_context(context)
                 try:
-                    return func(**params)
+                    return func(**converted_params)
                 finally:
                     reset_execution_context(thread_token)
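For context, and not part of the commit: when parameter conversion fails, _convert_parameters (added further down in this file) returns a dict under the __validation_errors key, which the branch above flattens into a single string for backward compatibility. A minimal sketch of that flattening, using a hypothetical error payload:

    # Illustrative sketch only; the error payload below is hypothetical but mirrors
    # the {"field", "message", "type"} entries built by _convert_parameters.
    converted_params = {
        "__validation_errors": {
            "profile": [
                {"field": "age", "message": "Input should be a valid integer", "type": "int_parsing"},
                {"field": "email", "message": "Field required", "type": "missing"},
            ]
        }
    }

    error_parts = []
    for param_name, param_errors in converted_params["__validation_errors"].items():
        error_details = [f"{err['field']}: {err['message']}" for err in param_errors]
        error_parts.append(f"Invalid {param_name}: {', '.join(error_details)}")

    print("; ".join(error_parts))
    # Invalid profile: age: Input should be a valid integer, email: Field required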

@@ -586,3 +600,155 @@ def sync_function_wrapper() -> Any:
         except Exception as e:
             logger.error(f"Function execution failed: {e}")
             raise
+
+    def _convert_parameters(
+        self, func: Callable[..., Any], params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Convert parameters based on function signature with comprehensive type support.
+
+        Features:
+        - Uses TypeAdapter for robust type conversion and validation
+        - Supports forward references with proper module namespace context
+        - Collects structured validation errors per parameter for better UI
+        - Uses signature binding for proper parameter handling
+        - Handles **kwargs functions specially to maintain compatibility
+
+        Args:
+            func: The function to call
+            params: Raw parameter dictionary
+
+        Returns:
+            Converted parameters dictionary, or dict with __validation_errors on failure.
+            The __validation_errors contains structured per-parameter error details.
+        """
+        import typing
+
+        from pydantic import ValidationError
+
+        # Get function signature
+        sig = inspect.signature(func)
+
+        # Resolve type hints with proper module context for forward references
+        try:
+            globalns = getattr(func, "__globals__", {})
+            # Use the defining module's namespace for localns (handles more forward-ref edge cases)
+            mod = sys.modules.get(getattr(func, "__module__", ""), None)
+            localns = vars(mod) if mod else globalns
+            type_hints = typing.get_type_hints(
+                func, globalns=globalns, localns=localns, include_extras=True
+            )
+        except (NameError, AttributeError, TypeError) as e:
+            logger.debug(f"Failed to resolve type hints for {func.__name__}: {e}")
+            type_hints = {}
+
+        # Use signature binding for proper parameter mapping, but handle **kwargs specially
+        has_var_keyword = any(
+            p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
+        )
+
+        if has_var_keyword:
+            # For functions with **kwargs, use direct parameter mapping to maintain compatibility
+            bound_params = params
+        else:
+            # For regular functions, use bind_partial for proper parameter handling
+            try:
+                bound_args = sig.bind_partial(**params)
+                bound_args.apply_defaults()
+                bound_params = bound_args.arguments
+            except TypeError as e:
+                # If binding fails, fall back to direct parameter mapping
+                logger.debug(f"Parameter binding failed for {func.__name__}: {e}")
+                bound_params = params
+
+        converted_params = {}
+        validation_errors = {}  # Dict to store per-parameter error details
+
+        for param_name, param_value in bound_params.items():
+            try:
+                # Get resolved type hint, fall back to raw annotation if available
+                param_type = type_hints.get(param_name)
+                if param_type is None and param_name in sig.parameters:
+                    param_type = sig.parameters[param_name].annotation
+
+                # Skip if no type annotation or annotation is Any/object
+                if (
+                    param_type is None
+                    or param_type == inspect.Parameter.empty
+                    or param_type in (typing.Any, object)
+                ):
+                    converted_params[param_name] = param_value
+                    continue
+
+                # Convert the parameter value based on its type
+                converted_value = self._convert_parameter_value(param_name, param_value, param_type)
+                converted_params[param_name] = converted_value
+
+            except ValidationError as ve:
+                # Collect structured validation errors for this parameter
+                param_errors = []
+                for error in ve.errors():
+                    param_errors.append(
+                        {
+                            "field": (
+                                ".".join(str(loc) for loc in error["loc"])
+                                if error["loc"]
+                                else param_name
+                            ),
+                            "message": error["msg"],
+                            "type": error["type"],
+                        }
+                    )
+                validation_errors[param_name] = param_errors
+
+            except Exception as e:
+                # Log unexpected errors with stack trace but continue processing other parameters
+                logger.exception(f"Unexpected error converting parameter '{param_name}': {e}")
+                converted_params[param_name] = param_value
+
+        # Return structured validation errors if any occurred
+        if validation_errors:
+            logger.error(
+                f"Parameter validation failed for parameters: {list(validation_errors.keys())}"
+            )
+            return {"__validation_errors": validation_errors}
+
+        return converted_params
+
+    def _convert_parameter_value(self, param_name: str, param_value: Any, param_type: Any) -> Any:
+        """Convert a single parameter value based on its type annotation.
+
+        Uses TypeAdapter to handle all type conversions uniformly:
+        - Direct models: User
+        - Optional models: Optional[User], User | None
+        - Container types: list[User], dict[str, User], etc.
+        - Complex nested types: dict[str, list[User]]
+        - Non-model types: TypedDicts, dataclasses, primitives
+
+        Args:
+            param_name: Parameter name (for logging/errors)
+            param_value: Raw parameter value
+            param_type: Resolved type annotation
+
+        Returns:
+            Converted parameter value
+
+        Raises:
+            ValidationError: If pydantic validation fails
+        """
+        from pydantic import BaseModel, TypeAdapter
+
+        # Fast path: if value is already a BaseModel instance, keep it as-is
+        if isinstance(param_value, BaseModel):
+            logger.debug(f"Parameter '{param_name}' is already a BaseModel instance")
+            return param_value
+
+        # Try to create TypeAdapter for the parameter type
+        try:
+            adapter = TypeAdapter(param_type)
+            logger.debug(f"Converting parameter '{param_name}' using TypeAdapter for {param_type}")
+            return adapter.validate_python(param_value)
+        except (TypeError, ValueError) as e:
+            # TypeAdapter couldn't be created or type doesn't need validation
+            # This happens for types that don't require special handling
+            logger.debug(f"No TypeAdapter needed for parameter '{param_name}': {e}")
+            return param_value
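A standalone sketch, not from the diff, of the two standard-library mechanisms _convert_parameters relies on: inspect.signature(...).bind_partial(...) with apply_defaults() to fill in unspecified arguments, and typing.get_type_hints(..., include_extras=True) to resolve annotations. The greet function and its input are hypothetical.

    import inspect
    import typing
    from typing import Optional

    def greet(name: str, excited: bool = False, nickname: Optional[str] = None) -> str:
        return f"Hello {nickname or name}{'!' if excited else '.'}"

    raw_params = {"name": "Ada"}            # hypothetical raw input from a caller
    sig = inspect.signature(greet)

    bound = sig.bind_partial(**raw_params)
    bound.apply_defaults()                  # fills excited=False, nickname=None
    print(dict(bound.arguments))            # {'name': 'Ada', 'excited': False, 'nickname': None}

    hints = typing.get_type_hints(greet, include_extras=True)
    print(hints["nickname"])                # typing.Optional[str]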

src/mxcp/sdk/validator/converters.py

Lines changed: 33 additions & 1 deletion
@@ -323,8 +323,26 @@ def validate_output(value: Any, schema: TypeSchema) -> None:
                 raise ValidationError(f"Array item {i}: {str(e)}") from e

     elif return_type == "object":
+        # Handle pydantic models by converting them to dict first
         if not isinstance(value, dict):
-            raise ValidationError(f"Expected object, got {type(value).__name__}")
+            if hasattr(value, "model_dump"):
+                # Pydantic v2
+                try:
+                    value = value.model_dump()
+                except Exception:
+                    raise ValidationError(
+                        f"Expected object, got {type(value).__name__}"
+                    ) from None
+            elif hasattr(value, "dict"):
+                # Pydantic v1
+                try:
+                    value = value.dict()
+                except Exception:
+                    raise ValidationError(
+                        f"Expected object, got {type(value).__name__}"
+                    ) from None
+            else:
+                raise ValidationError(f"Expected object, got {type(value).__name__}")

         properties = schema.properties or {}
         required = schema.required or []
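Illustrative only, not part of the commit: the same duck-typed fallback in isolation, assuming pydantic v2 is installed; the Point model and as_plain_dict helper are hypothetical names.

    from typing import Any
    from pydantic import BaseModel

    class Point(BaseModel):   # hypothetical model for the example
        x: int
        y: int

    def as_plain_dict(value: Any) -> dict:
        # Accept dicts as-is, unwrap pydantic v2 (model_dump) or v1 (dict) models,
        # and reject everything else, mirroring the branch added above.
        if isinstance(value, dict):
            return value
        if hasattr(value, "model_dump"):
            return value.model_dump()
        if hasattr(value, "dict"):
            return value.dict()
        raise TypeError(f"Expected object, got {type(value).__name__}")

    print(as_plain_dict(Point(x=1, y=2)))   # {'x': 1, 'y': 2}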
@@ -366,6 +384,20 @@ def serialize_for_output(obj: Any) -> Any:
     elif hasattr(obj, "isoformat"):
         # Handle any other datetime-like objects
        return obj.isoformat()
+    elif hasattr(obj, "model_dump"):
+        # Handle Pydantic v2 models
+        try:
+            serialized = obj.model_dump()
+            return TypeConverter.serialize_for_output(serialized)
+        except Exception:
+            return obj
+    elif hasattr(obj, "dict"):
+        # Handle Pydantic v1 models
+        try:
+            serialized = obj.dict()
+            return TypeConverter.serialize_for_output(serialized)
+        except Exception:
+            return obj
     else:
         return obj
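Not part of the commit: a sketch of why the new branches re-serialize the dumped dict. With pydantic v2, model_dump() keeps rich Python objects such as datetime, so a second pass is needed to make them JSON-friendly; the Order model is hypothetical.

    from datetime import datetime
    from pydantic import BaseModel

    class Order(BaseModel):   # hypothetical model for the example
        order_id: str
        created_at: datetime

    order = Order(order_id="A-1", created_at=datetime(2024, 1, 2, 3, 4, 5))

    dumped = order.model_dump()   # {'order_id': 'A-1', 'created_at': datetime(2024, 1, 2, 3, 4, 5)}
    # A recursive serialization pass, as serialize_for_output does, converts the datetime:
    serialized = {k: (v.isoformat() if hasattr(v, "isoformat") else v) for k, v in dumped.items()}
    print(serialized)             # {'order_id': 'A-1', 'created_at': '2024-01-02T03:04:05'}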

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+from typing import List, Dict, Optional
+from pydantic import BaseModel
+
+
+class User(BaseModel):
+    """Simple user model for testing container types."""
+
+    name: str
+    age: int
+    email: str
+
+
+def process_user_list(users: List[User]) -> dict:
+    """Test function that takes a list of pydantic models."""
+    total_age = sum(user.age for user in users)
+    avg_age = total_age / len(users) if users else 0
+    names = [user.name for user in users]
+
+    return {
+        "user_count": len(users),
+        "average_age": avg_age,
+        "names": names,
+        "total_age": total_age,
+    }
+
+
+def process_user_dict(user_map: Dict[str, User]) -> dict:
+    """Test function that takes a dict of pydantic models."""
+    user_count = len(user_map)
+    names = list(user_map.keys())
+    ages = [user.age for user in user_map.values()]
+    avg_age = sum(ages) / len(ages) if ages else 0
+
+    return {
+        "user_count": user_count,
+        "user_keys": sorted(names),
+        "average_age": avg_age,
+        "oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None,
+    }
+
+
+def process_optional_user(user: Optional[User] = None) -> dict:
+    """Test function with optional pydantic model."""
+    if user is None:
+        return {"has_user": False, "message": "No user provided"}
+
+    return {
+        "has_user": True,
+        "user_name": user.name,
+        "user_age": user.age,
+        "user_email": user.email,
+    }
+
+
+# Test functions using built-in container types (Python 3.9+)
+def process_builtin_user_list(users: list[User]) -> dict:
+    """Test function using built-in list[User] (Python 3.9+)."""
+    total_age = sum(user.age for user in users)
+    avg_age = total_age / len(users) if users else 0
+    names = [user.name for user in users]
+
+    return {
+        "user_count": len(users),
+        "average_age": avg_age,
+        "names": names,
+        "total_age": total_age,
+        "type_used": "builtin_list",
+    }
+
+
+def process_builtin_user_dict(user_map: dict[str, User]) -> dict:
+    """Test function using built-in dict[str, User] (Python 3.9+)."""
+    user_count = len(user_map)
+    names = list(user_map.keys())
+    ages = [user.age for user in user_map.values()]
+    avg_age = sum(ages) / len(ages) if ages else 0
+
+    return {
+        "user_count": user_count,
+        "user_keys": sorted(names),
+        "average_age": avg_age,
+        "oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None,
+        "type_used": "builtin_dict",
+    }
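Illustrative only, not part of the commit: conceptually, the executor receives plain JSON-style arguments, validates them against the annotation, and then calls the fixture. A hypothetical direct call reusing the User model and process_user_list defined above:

    from pydantic import TypeAdapter

    raw_users = [  # hypothetical JSON-style input as it would arrive from a caller
        {"name": "Ada", "age": 36, "email": "ada@example.com"},
        {"name": "Grace", "age": 45, "email": "grace@example.com"},
    ]

    # Conceptually what the executor's TypeAdapter conversion does for list[User]:
    users = TypeAdapter(list[User]).validate_python(raw_users)

    result = process_user_list(users)
    print(result["user_count"], result["average_age"], result["names"])
    # 2 40.5 ['Ada', 'Grace']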

tests/server/fixtures/integration/python/test_endpoints.py

Lines changed: 51 additions & 0 deletions
@@ -1,5 +1,6 @@
 from typing import Dict, Any
 from mxcp.runtime import config, on_init
+from pydantic import BaseModel

 global_var = None

@@ -121,3 +122,53 @@ def process_user_data(user_data: Dict[str, Any]) -> Dict[str, Any]:
     }

     return {"original_data": user_data, "analysis": analysis, "processing_status": "success"}
+
+
+# Pydantic models for testing
+class UserProfile(BaseModel):
+    """User profile model."""
+
+    name: str
+    age: int
+    email: str
+    is_premium: bool = False
+
+
+class UserStats(BaseModel):
+    """User statistics model."""
+
+    total_users: int
+    active_users: int
+    premium_users: int
+    average_age: float
+
+
+def validate_user_profile(profile: UserProfile) -> str:
+    """Take a pydantic model as parameter and return a primitive result."""
+    # Validate and process the user profile
+    if profile.age < 0:
+        return "Invalid age: must be non-negative"
+
+    if not profile.email or "@" not in profile.email:
+        return "Invalid email format"
+
+    status = "premium" if profile.is_premium else "regular"
+    return f"User {profile.name} ({profile.age} years old, {profile.email}) is a {status} user - validation passed"
+
+
+def get_user_stats(user_count: int) -> UserStats:
+    """Take a primitive argument and return a pydantic model."""
+    # Generate some mock statistics based on the user count
+    active_ratio = 0.8
+    premium_ratio = 0.3
+
+    active_users = int(user_count * active_ratio)
+    premium_users = int(user_count * premium_ratio)
+    average_age = 32.5  # Mock average age
+
+    return UserStats(
+        total_users=user_count,
+        active_users=active_users,
+        premium_users=premium_users,
+        average_age=average_age,
+    )
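Not part of the commit: a hypothetical round trip through these two endpoints, reusing the models and functions defined above, with a dict validated into a UserProfile on the way in and a UserStats model dumped back to a dict on the way out.

    # Parameter direction: the executor converts {"name": ..., "age": ..., "email": ...} into a UserProfile
    profile = UserProfile(name="Ada", age=36, email="ada@example.com")
    print(validate_user_profile(profile))
    # User Ada (36 years old, ada@example.com) is a regular user - validation passed

    # Return direction: the output validator/serializer dumps the UserStats model back to a dict
    stats = get_user_stats(10)
    print(stats.model_dump())
    # {'total_users': 10, 'active_users': 8, 'premium_users': 3, 'average_age': 32.5}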
