diff --git a/src/mxcp/sdk/executor/plugins/python.py b/src/mxcp/sdk/executor/plugins/python.py index cc1e2d97..54cc05f5 100644 --- a/src/mxcp/sdk/executor/plugins/python.py +++ b/src/mxcp/sdk/executor/plugins/python.py @@ -46,13 +46,16 @@ import hashlib import inspect import logging +import sys import tempfile +import typing from collections.abc import Callable from pathlib import Path from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd +from pydantic import BaseModel, TypeAdapter, ValidationError from mxcp.sdk.telemetry import ( decrement_gauge, @@ -555,12 +558,25 @@ async def _execute_function( try: from ..context import reset_execution_context, set_execution_context + # Convert parameters based on function signature + converted_params = self._convert_parameters(func, params) + + # Handle validation errors + if isinstance(converted_params, dict) and "__validation_errors" in converted_params: + # Convert structured errors to string for backward compatibility + # TODO: In future, return structured errors directly for better UI + error_parts = [] + for param_name, param_errors in converted_params["__validation_errors"].items(): + error_details = [f"{err['field']}: {err['message']}" for err in param_errors] + error_parts.append(f"Invalid {param_name}: {', '.join(error_details)}") + return "; ".join(error_parts) + # Check if function is async if asyncio.iscoroutinefunction(func): # For async functions, set context and let contextvars propagate it context_token = set_execution_context(context) try: - result = await func(**params) + result = await func(**converted_params) finally: reset_execution_context(context_token) else: @@ -570,7 +586,7 @@ def sync_function_wrapper() -> Any: thread_token = set_execution_context(context) try: - return func(**params) + return func(**converted_params) finally: reset_execution_context(thread_token) @@ -586,3 +602,157 @@ def sync_function_wrapper() -> Any: except Exception as e: logger.error(f"Function execution failed: {e}") raise + + def _convert_parameters( + self, func: Callable[..., Any], params: dict[str, Any] + ) -> dict[str, Any]: + """Convert parameters based on function signature with comprehensive type support. + + Features: + - Uses TypeAdapter for robust type conversion and validation + - Supports forward references with proper module namespace context + - Collects structured validation errors per parameter for better UI + - Uses signature binding for proper parameter handling + - Handles **kwargs functions specially to maintain compatibility + + Args: + func: The function to call + params: Raw parameter dictionary + + Returns: + Converted parameters dictionary, or dict with __validation_errors on failure. + The __validation_errors contains structured per-parameter error details. 
+ """ + + # Get function signature + sig = inspect.signature(func) + + # Resolve type hints with proper module context for forward references + try: + globalns = getattr(func, "__globals__", {}) + # Use the defining module's namespace for localns (handles more forward-ref edge cases) + mod = sys.modules.get(getattr(func, "__module__", ""), None) + localns = vars(mod) if mod else globalns + type_hints = typing.get_type_hints( + func, globalns=globalns, localns=localns, include_extras=True + ) + except (NameError, AttributeError, TypeError) as e: + logger.debug(f"Failed to resolve type hints for {func.__name__}: {e}") + type_hints = {} + + # Use signature binding for proper parameter mapping, but handle **kwargs specially + has_var_keyword = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + if has_var_keyword: + # For functions with **kwargs, use direct parameter mapping to maintain compatibility + bound_params = params + else: + # For regular functions, use bind_partial for proper parameter handling + try: + bound_args = sig.bind_partial(**params) + bound_args.apply_defaults() + bound_params = bound_args.arguments + except TypeError as e: + # If binding fails, fall back to direct parameter mapping + logger.debug(f"Parameter binding failed for {func.__name__}: {e}") + bound_params = params + + converted_params = {} + validation_errors = {} # Dict to store per-parameter error details + + for param_name, param_value in bound_params.items(): + try: + # Get resolved type hint, fall back to raw annotation if available + param_type = type_hints.get(param_name) + if param_type is None and param_name in sig.parameters: + param_type = sig.parameters[param_name].annotation + + # Skip if no type annotation or annotation is Any/object + if ( + param_type is None + or param_type == inspect.Parameter.empty + or param_type in (typing.Any, object) + ): + converted_params[param_name] = param_value + continue + + # Convert the parameter value based on its type + converted_value = self._convert_parameter_value(param_name, param_value, param_type) + converted_params[param_name] = converted_value + + except ValidationError as ve: + # Collect structured validation errors for this parameter + param_errors = [] + for error in ve.errors(): + param_errors.append( + { + "field": ( + ".".join(str(loc) for loc in error["loc"]) + if error["loc"] + else param_name + ), + "message": error["msg"], + "type": error["type"], + } + ) + validation_errors[param_name] = param_errors + + except Exception as e: + # Log unexpected errors with stack trace and fail validation + logger.exception(f"Unexpected error converting parameter '{param_name}'") + validation_errors[param_name] = [ + { + "field": param_name, + "message": f"Unexpected conversion error: {str(e)}", + "type": "unexpected_error", + } + ] + + # Return structured validation errors if any occurred + if validation_errors: + logger.error( + f"Parameter validation failed for parameters: {list(validation_errors.keys())}" + ) + return {"__validation_errors": validation_errors} + + return converted_params + + def _convert_parameter_value(self, param_name: str, param_value: Any, param_type: Any) -> Any: + """Convert a single parameter value based on its type annotation. + + Uses TypeAdapter to handle all type conversions uniformly: + - Direct models: User + - Optional models: Optional[User], User | None + - Container types: list[User], dict[str, User], etc. 
+ - Complex nested types: dict[str, list[User]] + - Non-model types: TypedDicts, dataclasses, primitives + + Args: + param_name: Parameter name (for logging/errors) + param_value: Raw parameter value + param_type: Resolved type annotation + + Returns: + Converted parameter value + + Raises: + ValidationError: If pydantic validation fails + """ + + # Fast path: if value is already a BaseModel instance, keep it as-is + if isinstance(param_value, BaseModel): + logger.debug(f"Parameter '{param_name}' is already a BaseModel instance") + return param_value + + # Try to create TypeAdapter for the parameter type + try: + adapter = TypeAdapter(param_type) + logger.debug(f"Converting parameter '{param_name}' using TypeAdapter for {param_type}") + return adapter.validate_python(param_value) + except (TypeError, ValueError) as e: + # TypeAdapter couldn't be created or type doesn't need validation + # This happens for types that don't require special handling + logger.debug(f"No TypeAdapter needed for parameter '{param_name}': {e}") + return param_value diff --git a/src/mxcp/sdk/validator/converters.py b/src/mxcp/sdk/validator/converters.py index 51f0bedd..d10b779f 100644 --- a/src/mxcp/sdk/validator/converters.py +++ b/src/mxcp/sdk/validator/converters.py @@ -323,8 +323,26 @@ def validate_output(value: Any, schema: TypeSchema) -> None: raise ValidationError(f"Array item {i}: {str(e)}") from e elif return_type == "object": + # Handle pydantic models by converting them to dict first if not isinstance(value, dict): - raise ValidationError(f"Expected object, got {type(value).__name__}") + if hasattr(value, "model_dump"): + # Pydantic v2 + try: + value = value.model_dump() + except Exception: + raise ValidationError( + f"Expected object, got {type(value).__name__}" + ) from None + elif hasattr(value, "dict"): + # Pydantic v1 + try: + value = value.dict() + except Exception: + raise ValidationError( + f"Expected object, got {type(value).__name__}" + ) from None + else: + raise ValidationError(f"Expected object, got {type(value).__name__}") properties = schema.properties or {} required = schema.required or [] @@ -366,6 +384,20 @@ def serialize_for_output(obj: Any) -> Any: elif hasattr(obj, "isoformat"): # Handle any other datetime-like objects return obj.isoformat() + elif hasattr(obj, "model_dump"): + # Handle Pydantic v2 models + try: + serialized = obj.model_dump() + return TypeConverter.serialize_for_output(serialized) + except Exception: + return obj + elif hasattr(obj, "dict"): + # Handle Pydantic v1 models + try: + serialized = obj.dict() + return TypeConverter.serialize_for_output(serialized) + except Exception: + return obj else: return obj diff --git a/tests/server/fixtures/integration/python/test_container_types.py b/tests/server/fixtures/integration/python/test_container_types.py new file mode 100644 index 00000000..af2cbf6e --- /dev/null +++ b/tests/server/fixtures/integration/python/test_container_types.py @@ -0,0 +1,84 @@ +from typing import List, Dict, Optional +from pydantic import BaseModel + + +class User(BaseModel): + """Simple user model for testing container types.""" + + name: str + age: int + email: str + + +def process_user_list(users: List[User]) -> dict: + """Test function that takes a list of pydantic models.""" + total_age = sum(user.age for user in users) + avg_age = total_age / len(users) if users else 0 + names = [user.name for user in users] + + return { + "user_count": len(users), + "average_age": avg_age, + "names": names, + "total_age": total_age, + } + + +def 
process_user_dict(user_map: Dict[str, User]) -> dict: + """Test function that takes a dict of pydantic models.""" + user_count = len(user_map) + names = list(user_map.keys()) + ages = [user.age for user in user_map.values()] + avg_age = sum(ages) / len(ages) if ages else 0 + + return { + "user_count": user_count, + "user_keys": sorted(names), + "average_age": avg_age, + "oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None, + } + + +def process_optional_user(user: Optional[User] = None) -> dict: + """Test function with optional pydantic model.""" + if user is None: + return {"has_user": False, "message": "No user provided"} + + return { + "has_user": True, + "user_name": user.name, + "user_age": user.age, + "user_email": user.email, + } + + +# Test functions using built-in container types (Python 3.9+) +def process_builtin_user_list(users: list[User]) -> dict: + """Test function using built-in list[User] (Python 3.9+).""" + total_age = sum(user.age for user in users) + avg_age = total_age / len(users) if users else 0 + names = [user.name for user in users] + + return { + "user_count": len(users), + "average_age": avg_age, + "names": names, + "total_age": total_age, + "type_used": "builtin_list", + } + + +def process_builtin_user_dict(user_map: dict[str, User]) -> dict: + """Test function using built-in dict[str, User] (Python 3.9+).""" + user_count = len(user_map) + names = list(user_map.keys()) + ages = [user.age for user in user_map.values()] + avg_age = sum(ages) / len(ages) if ages else 0 + + return { + "user_count": user_count, + "user_keys": sorted(names), + "average_age": avg_age, + "oldest_user": max(user_map.values(), key=lambda u: u.age).name if user_map else None, + "type_used": "builtin_dict", + } diff --git a/tests/server/fixtures/integration/python/test_endpoints.py b/tests/server/fixtures/integration/python/test_endpoints.py index 3bbd3b1b..1fe9f592 100644 --- a/tests/server/fixtures/integration/python/test_endpoints.py +++ b/tests/server/fixtures/integration/python/test_endpoints.py @@ -1,5 +1,6 @@ from typing import Dict, Any from mxcp.runtime import config, on_init +from pydantic import BaseModel global_var = None @@ -121,3 +122,53 @@ def process_user_data(user_data: Dict[str, Any]) -> Dict[str, Any]: } return {"original_data": user_data, "analysis": analysis, "processing_status": "success"} + + +# Pydantic models for testing +class UserProfile(BaseModel): + """User profile model.""" + + name: str + age: int + email: str + is_premium: bool = False + + +class UserStats(BaseModel): + """User statistics model.""" + + total_users: int + active_users: int + premium_users: int + average_age: float + + +def validate_user_profile(profile: UserProfile) -> str: + """Take a pydantic model as parameter and return a primitive result.""" + # Validate and process the user profile + if profile.age < 0: + return "Invalid age: must be non-negative" + + if not profile.email or "@" not in profile.email: + return "Invalid email format" + + status = "premium" if profile.is_premium else "regular" + return f"User {profile.name} ({profile.age} years old, {profile.email}) is a {status} user - validation passed" + + +def get_user_stats(user_count: int) -> UserStats: + """Take a primitive argument and return a pydantic model.""" + # Generate some mock statistics based on the user count + active_ratio = 0.8 + premium_ratio = 0.3 + + active_users = int(user_count * active_ratio) + premium_users = int(user_count * premium_ratio) + average_age = 32.5 # Mock average age + 
+ return UserStats( + total_users=user_count, + active_users=active_users, + premium_users=premium_users, + average_age=average_age, + ) diff --git a/tests/server/fixtures/integration/tools/get_user_stats.yml b/tests/server/fixtures/integration/tools/get_user_stats.yml new file mode 100644 index 00000000..e8c32855 --- /dev/null +++ b/tests/server/fixtures/integration/tools/get_user_stats.yml @@ -0,0 +1,32 @@ +mxcp: 1 +tool: + name: get_user_stats + description: Get user statistics based on user count, returns pydantic model + language: python + source: + file: ../python/test_endpoints.py + parameters: + - name: user_count + type: integer + description: Total number of users to generate stats for + return: + type: object + description: User statistics object + properties: + total_users: + type: integer + description: Total number of users + active_users: + type: integer + description: Number of active users + premium_users: + type: integer + description: Number of premium users + average_age: + type: number + description: Average age of users + required: + - total_users + - active_users + - premium_users + - average_age diff --git a/tests/server/fixtures/integration/tools/process_user_dict.yml b/tests/server/fixtures/integration/tools/process_user_dict.yml new file mode 100644 index 00000000..2b781f1d --- /dev/null +++ b/tests/server/fixtures/integration/tools/process_user_dict.yml @@ -0,0 +1,37 @@ +mxcp: 1 +tool: + name: process_user_dict + description: Process a dictionary of users (pydantic models in dict containers) + language: python + source: + file: ../python/test_container_types.py + parameters: + - name: user_map + type: object + additionalProperties: + type: object + properties: + name: + type: string + age: + type: integer + email: + type: string + required: + - name + - age + - email + description: Dictionary mapping keys to user objects + return: + type: object + properties: + user_count: + type: integer + user_keys: + type: array + items: + type: string + average_age: + type: number + oldest_user: + type: string diff --git a/tests/server/fixtures/integration/tools/process_user_list.yml b/tests/server/fixtures/integration/tools/process_user_list.yml new file mode 100644 index 00000000..e3d6dd03 --- /dev/null +++ b/tests/server/fixtures/integration/tools/process_user_list.yml @@ -0,0 +1,37 @@ +mxcp: 1 +tool: + name: process_user_list + description: Process a list of users (pydantic models in containers) + language: python + source: + file: ../python/test_container_types.py + parameters: + - name: users + type: array + items: + type: object + properties: + name: + type: string + age: + type: integer + email: + type: string + required: + - name + - age + - email + description: List of user objects + return: + type: object + properties: + user_count: + type: integer + average_age: + type: number + names: + type: array + items: + type: string + total_age: + type: integer diff --git a/tests/server/fixtures/integration/tools/validate_user_profile.yml b/tests/server/fixtures/integration/tools/validate_user_profile.yml new file mode 100644 index 00000000..fbf80910 --- /dev/null +++ b/tests/server/fixtures/integration/tools/validate_user_profile.yml @@ -0,0 +1,31 @@ +mxcp: 1 +tool: + name: validate_user_profile + description: Validate a user profile using pydantic model and return validation result + language: python + source: + file: ../python/test_endpoints.py + parameters: + - name: profile + type: object + description: User profile to validate + properties: + name: + type: string + 
description: User's full name + age: + type: integer + description: User's age + email: + type: string + description: User's email address + is_premium: + type: boolean + description: Whether user has premium status + required: + - name + - age + - email + return: + type: string + description: Validation result message diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 8e590fe9..a139f12d 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -744,3 +744,77 @@ async def test_complex_object_input(self, integration_fixture_dir): assert analysis["is_premium"] is True assert "123 Main St, San Francisco, USA" in analysis["full_address"] assert "John Doe is a 25-year-old premium user" in analysis["summary"] + + @pytest.mark.asyncio + async def test_pydantic_model_input(self, integration_fixture_dir): + """Test tool that takes a pydantic model as input parameter and returns primitive.""" + with ServerProcess(integration_fixture_dir) as server: + server.start() + + async with MCPTestClient(server.port) as client: + # Create a user profile object matching the pydantic model + profile_data = { + "name": "Jane Smith", + "age": 29, + "email": "jane.smith@example.com", + "is_premium": True, + } + + # Call the tool with the pydantic model data + result = await client.call_tool("validate_user_profile", {"profile": profile_data}) + + # Verify the result is a string (primitive) with expected content + assert isinstance(result, str) + assert "Jane Smith" in result + assert "29 years old" in result + assert "jane.smith@example.com" in result + assert "premium user" in result + assert "validation passed" in result + + @pytest.mark.asyncio + async def test_pydantic_model_output(self, integration_fixture_dir): + """Test tool that takes primitive input and returns a pydantic model.""" + with ServerProcess(integration_fixture_dir) as server: + server.start() + + async with MCPTestClient(server.port) as client: + # Call the tool with a primitive integer + result = await client.call_tool("get_user_stats", {"user_count": 100}) + + # Verify the result is an object (pydantic model serialized) with expected structure + assert isinstance(result, dict) + assert "total_users" in result + assert "active_users" in result + assert "premium_users" in result + assert "average_age" in result + + # Verify the calculated values + assert result["total_users"] == 100 + assert result["active_users"] == 80 # 80% of 100 + assert result["premium_users"] == 30 # 30% of 100 + assert result["average_age"] == 32.5 + + @pytest.mark.asyncio + async def test_pydantic_validation_error(self, integration_fixture_dir): + """Test pydantic model validation with invalid data.""" + with ServerProcess(integration_fixture_dir) as server: + server.start() + + async with MCPTestClient(server.port) as client: + # Create invalid profile data (negative age) + invalid_profile = { + "name": "Invalid User", + "age": -5, + "email": "invalid@example.com", + "is_premium": False, + } + + # Call the tool with invalid data + result = await client.call_tool( + "validate_user_profile", {"profile": invalid_profile} + ) + + # Verify the validation error is returned + assert isinstance(result, str) + assert "Invalid age" in result + assert "must be non-negative" in result
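
A minimal sketch of the `TypeAdapter`-based conversion that `_convert_parameter_value` performs, assuming pydantic v2 and an illustrative `User` model. One detail worth keeping in mind: pydantic v2's `ValidationError` is a subclass of `ValueError`, so a combined `except (TypeError, ValueError)` wrapped around `validate_python` also swallows validation failures; the sketch below guards adapter construction and validation separately so that `ValidationError` can propagate to the structured `__validation_errors` handling.

```python
# Sketch only (not part of the change): the TypeAdapter conversion path that
# _convert_parameter_value relies on, assuming pydantic v2. `User` is a stand-in model.
from pydantic import BaseModel, TypeAdapter, ValidationError


class User(BaseModel):
    name: str
    age: int
    email: str


def convert(value, annotation):
    """Convert value to the annotated type, passing it through when no adapter applies."""
    try:
        adapter = TypeAdapter(annotation)
    except (TypeError, ValueError):
        # Annotation isn't something pydantic can build a schema for; keep the raw value.
        return value
    # ValidationError subclasses ValueError in pydantic v2, so it must be allowed to
    # propagate from validate_python rather than being caught with the adapter errors.
    return adapter.validate_python(value)


# Container types are converted element by element.
users = convert([{"name": "Ada", "age": 36, "email": "ada@example.com"}], list[User])
assert isinstance(users[0], User)

# Invalid payloads surface as ValidationError with structured .errors() entries.
try:
    convert({"name": "Bob", "age": "not-a-number", "email": "bob@example.com"}, User)
except ValidationError as exc:
    first = exc.errors()[0]
    print(first["loc"], first["msg"], first["type"])
```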
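
For reference, the per-parameter error payload that `_convert_parameters` returns on failure, and the flattening that `_execute_function` applies for backward compatibility, look roughly like this (field, message, and type values are illustrative):

```python
# Illustrative shape of the structured error payload returned by _convert_parameters
# when pydantic validation fails for one or more parameters.
converted_params = {
    "__validation_errors": {
        "profile": [
            {
                "field": "age",
                "message": "Input should be a valid integer, unable to parse string as an integer",
                "type": "int_parsing",
            }
        ]
    }
}

# _execute_function currently flattens this into a single string.
error_parts = []
for param_name, param_errors in converted_params["__validation_errors"].items():
    details = [f"{err['field']}: {err['message']}" for err in param_errors]
    error_parts.append(f"Invalid {param_name}: {', '.join(details)}")
print("; ".join(error_parts))
# Invalid profile: age: Input should be a valid integer, unable to parse string as an integer
```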
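
On the output side, the converter changes dump a returned model to a plain dict before schema validation (`model_dump()` for pydantic v2, `.dict()` for v1). A minimal sketch of that behaviour, using a stand-in `UserStats` model; the diff additionally wraps each call in try/except and falls back to the original error or value:

```python
# Sketch only (not part of the change): the duck-typed model-to-dict step applied by
# the converter changes before schema validation. UserStats mirrors the fixture model.
from pydantic import BaseModel


class UserStats(BaseModel):
    total_users: int
    active_users: int
    premium_users: int
    average_age: float


def to_plain_dict(obj):
    """Dump pydantic v2 models via model_dump(); fall back to v1-style .dict()."""
    if hasattr(obj, "model_dump"):  # pydantic v2
        return obj.model_dump()
    if hasattr(obj, "dict"):  # pydantic v1
        return obj.dict()
    return obj


stats = UserStats(total_users=100, active_users=80, premium_users=30, average_age=32.5)
assert to_plain_dict(stats) == {
    "total_users": 100,
    "active_users": 80,
    "premium_users": 30,
    "average_age": 32.5,
}
```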