Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 172 additions & 2 deletions src/mxcp/sdk/executor/plugins/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,16 @@
import hashlib
import inspect
import logging
import sys
import tempfile
import typing
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
from pydantic import BaseModel, TypeAdapter, ValidationError

from mxcp.sdk.telemetry import (
decrement_gauge,
Expand Down Expand Up @@ -555,12 +558,25 @@ async def _execute_function(
try:
from ..context import reset_execution_context, set_execution_context

# Convert parameters based on function signature
converted_params = self._convert_parameters(func, params)

# Handle validation errors
if isinstance(converted_params, dict) and "__validation_errors" in converted_params:
# Convert structured errors to string for backward compatibility
# TODO: In future, return structured errors directly for better UI
error_parts = []
for param_name, param_errors in converted_params["__validation_errors"].items():
error_details = [f"{err['field']}: {err['message']}" for err in param_errors]
error_parts.append(f"Invalid {param_name}: {', '.join(error_details)}")
return "; ".join(error_parts)

# Check if function is async
if asyncio.iscoroutinefunction(func):
# For async functions, set context and let contextvars propagate it
context_token = set_execution_context(context)
try:
result = await func(**params)
result = await func(**converted_params)
finally:
reset_execution_context(context_token)
else:
Expand All @@ -570,7 +586,7 @@ def sync_function_wrapper() -> Any:

thread_token = set_execution_context(context)
try:
return func(**params)
return func(**converted_params)
finally:
reset_execution_context(thread_token)

Expand All @@ -586,3 +602,157 @@ def sync_function_wrapper() -> Any:
except Exception as e:
logger.error(f"Function execution failed: {e}")
raise

def _convert_parameters(
self, func: Callable[..., Any], params: dict[str, Any]
) -> dict[str, Any]:
"""Convert parameters based on function signature with comprehensive type support.

Features:
- Uses TypeAdapter for robust type conversion and validation
- Supports forward references with proper module namespace context
- Collects structured validation errors per parameter for better UI
- Uses signature binding for proper parameter handling
- Handles **kwargs functions specially to maintain compatibility

Args:
func: The function to call
params: Raw parameter dictionary

Returns:
Converted parameters dictionary, or dict with __validation_errors on failure.
The __validation_errors contains structured per-parameter error details.
"""

# Get function signature
sig = inspect.signature(func)

# Resolve type hints with proper module context for forward references
try:
globalns = getattr(func, "__globals__", {})
# Use the defining module's namespace for localns (handles more forward-ref edge cases)
mod = sys.modules.get(getattr(func, "__module__", ""), None)
localns = vars(mod) if mod else globalns
type_hints = typing.get_type_hints(
func, globalns=globalns, localns=localns, include_extras=True
)
except (NameError, AttributeError, TypeError) as e:
logger.debug(f"Failed to resolve type hints for {func.__name__}: {e}")
type_hints = {}

# Use signature binding for proper parameter mapping, but handle **kwargs specially
has_var_keyword = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)

if has_var_keyword:
# For functions with **kwargs, use direct parameter mapping to maintain compatibility
bound_params = params
else:
# For regular functions, use bind_partial for proper parameter handling
try:
bound_args = sig.bind_partial(**params)
bound_args.apply_defaults()
bound_params = bound_args.arguments
except TypeError as e:
# If binding fails, fall back to direct parameter mapping
logger.debug(f"Parameter binding failed for {func.__name__}: {e}")
bound_params = params

converted_params = {}
validation_errors = {} # Dict to store per-parameter error details

for param_name, param_value in bound_params.items():
try:
# Get resolved type hint, fall back to raw annotation if available
param_type = type_hints.get(param_name)
if param_type is None and param_name in sig.parameters:
param_type = sig.parameters[param_name].annotation

# Skip if no type annotation or annotation is Any/object
if (
param_type is None
or param_type == inspect.Parameter.empty
or param_type in (typing.Any, object)
):
converted_params[param_name] = param_value
continue

# Convert the parameter value based on its type
converted_value = self._convert_parameter_value(param_name, param_value, param_type)
converted_params[param_name] = converted_value

except ValidationError as ve:
# Collect structured validation errors for this parameter
param_errors = []
for error in ve.errors():
param_errors.append(
{
"field": (
".".join(str(loc) for loc in error["loc"])
if error["loc"]
else param_name
),
"message": error["msg"],
"type": error["type"],
}
)
validation_errors[param_name] = param_errors

except Exception as e:
# Log unexpected errors with stack trace and fail validation
logger.exception(f"Unexpected error converting parameter '{param_name}'")
validation_errors[param_name] = [
{
"field": param_name,
"message": f"Unexpected conversion error: {str(e)}",
"type": "unexpected_error",
}
]

# Return structured validation errors if any occurred
if validation_errors:
logger.error(
f"Parameter validation failed for parameters: {list(validation_errors.keys())}"
)
return {"__validation_errors": validation_errors}

return converted_params

def _convert_parameter_value(self, param_name: str, param_value: Any, param_type: Any) -> Any:
"""Convert a single parameter value based on its type annotation.

Uses TypeAdapter to handle all type conversions uniformly:
- Direct models: User
- Optional models: Optional[User], User | None
- Container types: list[User], dict[str, User], etc.
- Complex nested types: dict[str, list[User]]
- Non-model types: TypedDicts, dataclasses, primitives

Args:
param_name: Parameter name (for logging/errors)
param_value: Raw parameter value
param_type: Resolved type annotation

Returns:
Converted parameter value

Raises:
ValidationError: If pydantic validation fails
"""

# Fast path: if value is already a BaseModel instance, keep it as-is
if isinstance(param_value, BaseModel):
logger.debug(f"Parameter '{param_name}' is already a BaseModel instance")
return param_value

# Try to create TypeAdapter for the parameter type
try:
adapter = TypeAdapter(param_type)
logger.debug(f"Converting parameter '{param_name}' using TypeAdapter for {param_type}")
return adapter.validate_python(param_value)
except (TypeError, ValueError) as e:
# TypeAdapter couldn't be created or type doesn't need validation
# This happens for types that don't require special handling
logger.debug(f"No TypeAdapter needed for parameter '{param_name}': {e}")
return param_value
34 changes: 33 additions & 1 deletion src/mxcp/sdk/validator/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,26 @@ def validate_output(value: Any, schema: TypeSchema) -> None:
raise ValidationError(f"Array item {i}: {str(e)}") from e

elif return_type == "object":
# Handle pydantic models by converting them to dict first
if not isinstance(value, dict):
raise ValidationError(f"Expected object, got {type(value).__name__}")
if hasattr(value, "model_dump"):
# Pydantic v2
try:
value = value.model_dump()
except Exception:
raise ValidationError(
f"Expected object, got {type(value).__name__}"
) from None
elif hasattr(value, "dict"):
# Pydantic v1
try:
value = value.dict()
except Exception:
raise ValidationError(
f"Expected object, got {type(value).__name__}"
) from None
else:
raise ValidationError(f"Expected object, got {type(value).__name__}")

properties = schema.properties or {}
required = schema.required or []
Expand Down Expand Up @@ -366,6 +384,20 @@ def serialize_for_output(obj: Any) -> Any:
elif hasattr(obj, "isoformat"):
# Handle any other datetime-like objects
return obj.isoformat()
elif hasattr(obj, "model_dump"):
# Handle Pydantic v2 models
try:
serialized = obj.model_dump()
return TypeConverter.serialize_for_output(serialized)
except Exception:
return obj
elif hasattr(obj, "dict"):
# Handle Pydantic v1 models
try:
serialized = obj.dict()
return TypeConverter.serialize_for_output(serialized)
except Exception:
return obj
else:
return obj

Expand Down
84 changes: 84 additions & 0 deletions tests/server/fixtures/integration/python/test_container_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import List, Dict, Optional
from pydantic import BaseModel


class User(BaseModel):
"""Simple user model for testing container types."""

name: str
age: int
email: str


def process_user_list(users: List[User]) -> dict:
"""Test function that takes a list of pydantic models."""
total_age = sum(user.age for user in users)
avg_age = total_age / len(users) if users else 0
names = [user.name for user in users]

return {
"user_count": len(users),
"average_age": avg_age,
"names": names,
"total_age": total_age,
}


def process_user_dict(user_map: Dict[str, User]) -> dict:
"""Test function that takes a dict of pydantic models."""
user_count = len(user_map)
names = list(user_map.keys())
ages = [user.age for user in user_map.values()]
avg_age = sum(ages) / len(ages) if ages else 0

return {
"user_count": user_count,
"user_keys": sorted(names),
"average_age": avg_age,
"oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None,
}


def process_optional_user(user: Optional[User] = None) -> dict:
"""Test function with optional pydantic model."""
if user is None:
return {"has_user": False, "message": "No user provided"}

return {
"has_user": True,
"user_name": user.name,
"user_age": user.age,
"user_email": user.email,
}


# Test functions using built-in container types (Python 3.9+)
def process_builtin_user_list(users: list[User]) -> dict:
"""Test function using built-in list[User] (Python 3.9+)."""
total_age = sum(user.age for user in users)
avg_age = total_age / len(users) if users else 0
names = [user.name for user in users]

return {
"user_count": len(users),
"average_age": avg_age,
"names": names,
"total_age": total_age,
"type_used": "builtin_list",
}


def process_builtin_user_dict(user_map: dict[str, User]) -> dict:
"""Test function using built-in dict[str, User] (Python 3.9+)."""
user_count = len(user_map)
names = list(user_map.keys())
ages = [user.age for user in user_map.values()]
avg_age = sum(ages) / len(ages) if ages else 0

return {
"user_count": user_count,
"user_keys": sorted(names),
"average_age": avg_age,
"oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None,
"type_used": "builtin_dict",
}
Loading
Loading