diff --git a/src/marvin/_mappings/chat_completion.py b/src/marvin/_mappings/chat_completion.py index b19457181..fef3f94f3 100644 --- a/src/marvin/_mappings/chat_completion.py +++ b/src/marvin/_mappings/chat_completion.py @@ -6,7 +6,6 @@ Callable, TypeVar, Union, - cast, ) from pydantic import BaseModel, TypeAdapter, ValidationError @@ -41,9 +40,6 @@ def chat_completion_to_model( data: dict[str, Any] = {} data[field_name] = json.loads(tool_arguments[0]) return response_model.model_validate_json(json.dumps(data)) - else: - data: dict[str, Any] = json.loads(tool_arguments[0]) - return cast(T, data) def chat_completion_to_type(response_type: U, completion: "ChatCompletion") -> "U": diff --git a/src/marvin/_mappings/types.py b/src/marvin/_mappings/types.py index f10aec0f4..d1a1f49b9 100644 --- a/src/marvin/_mappings/types.py +++ b/src/marvin/_mappings/types.py @@ -26,6 +26,7 @@ def cast_type_to_model( return create_model( model_name, + __doc__=model_description, __config__=None, __base__=None, __module__=__name__, diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index ae7aa7b37..676678132 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -17,10 +17,12 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import NotRequired, ParamSpec, Self, Unpack +import marvin from marvin._mappings.chat_completion import chat_completion_to_model from marvin.client.openai import AsyncMarvinClient, MarvinClient from marvin.components.prompt.fn import PromptFunction from marvin.utilities.jinja import BaseEnvironment +from marvin.utilities.logging import get_logger if TYPE_CHECKING: from openai.types.chat import ChatCompletion @@ -44,17 +46,17 @@ class AIFunctionKwargs(TypedDict): class AIFunctionKwargsDefaults(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) environment: Optional[BaseEnvironment] = None prompt: Optional[str] = None model_name: str = "FormatResponse" model_description: str = "Formats the response." field_name: str = "data" field_description: str = "The data to format." - model: Optional[str] = None + model: str = marvin.settings.openai.chat.completions.model client: Optional[Client] = None aclient: Optional[AsyncClient] = None - temperature: Optional[float] = None + temperature: Optional[float] = marvin.settings.openai.chat.completions.temperature class AIFunction( @@ -89,6 +91,10 @@ class AIFunction( client: Client = Field(default_factory=lambda: MarvinClient().client) aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client) + @property + def logger(self): + return get_logger(self.__class__.__name__) + def __call__( self, *args: P.args, **kwargs: P.kwargs ) -> Union[T, Coroutine[Any, Any, T]]: @@ -101,6 +107,7 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> T: response: ChatCompletion = MarvinClient(client=self.client).chat( **prompt.serialize() ) + self.logger.debug_kv("Calling", f"{self.fn.__name__}({args}, {kwargs})", "blue") return getattr( chat_completion_to_model(model, response, field_name=self.field_name), self.field_name, diff --git a/src/marvin/settings.py b/src/marvin/settings.py index fb3a13b0a..3a754d41b 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -22,6 +22,7 @@ class MarvinSettings(BaseSettings): def __setattr__(self, name: str, value: Any) -> None: + # wrap bare strings in SecretStr if the field is annotated with SecretStr field = self.model_fields.get(name) if field: annotation = field.annotation @@ -46,6 +47,10 @@ class ChatCompletionSettings(MarvinSettings): description="The default chat model to use.", default="gpt-3.5-turbo" ) + temperature: float = Field( + description="The default temperature to use.", default=0.1 + ) + @property def encoder(self): import tiktoken