diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py index b71aab30a..31deceb5a 100644 --- a/src/marvin/__init__.py +++ b/src/marvin/__init__.py @@ -1,5 +1,6 @@ from .settings import settings +# legacy from .components import ai_fn, ai_model, ai_classifier from .components.prompt.fn import prompt_fn diff --git a/src/marvin/beta/applications/state/state.py b/src/marvin/beta/applications/state/state.py index 1013e3425..806e66545 100644 --- a/src/marvin/beta/applications/state/state.py +++ b/src/marvin/beta/applications/state/state.py @@ -1,5 +1,5 @@ +import inspect import json -import textwrap from typing import Optional, Union from jsonpatch import JsonPatch @@ -71,7 +71,7 @@ def as_tool(self, name: str = None) -> "Tool": name = "state" schema = self.get_schema() if schema: - description = textwrap.dedent( + description = inspect.cleandoc( f"Update the {name} object using JSON Patch documents. Updates will" " fail if they do not comply with the following" " schema:\n\n```json\n{schema}\n```" diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index 7ceb97f0a..fdaef254b 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -84,6 +84,8 @@ class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin): - {{ arg }}: {{ value }} {% endfor %} + + What is its output? """ ) diff --git a/src/marvin/components/ai_image.py b/src/marvin/components/ai_image.py index dd6e7bcdb..5818ac18b 100644 --- a/src/marvin/components/ai_image.py +++ b/src/marvin/components/ai_image.py @@ -1,4 +1,5 @@ import asyncio +import inspect from functools import partial, wraps from typing import ( Any, @@ -24,9 +25,15 @@ ) T = TypeVar("T") - P = ParamSpec("P") +DEFAULT_PROMPT = inspect.cleandoc( + """ + {{_doc}} + {{_return_value}} + """ +) + class AIImageKwargs(TypedDict): environment: NotRequired[BaseEnvironment] @@ -38,7 +45,7 @@ class AIImageKwargs(TypedDict): class AIImageKwargsDefaults(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) environment: Optional[BaseEnvironment] = None - prompt: Optional[str] = None + prompt: Optional[str] = DEFAULT_PROMPT client: Optional[Client] = None aclient: Optional[AsyncClient] = None @@ -47,7 +54,7 @@ class AIImage(BaseModel, Generic[P]): model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) fn: Optional[Callable[P, Any]] = None environment: Optional[BaseEnvironment] = None - prompt: Optional[str] = Field(default=None) + prompt: Optional[str] = Field(default=DEFAULT_PROMPT) client: Client = Field(default_factory=lambda: MarvinClient().client) aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client) @@ -73,16 +80,12 @@ def as_prompt( *args: P.args, **kwargs: P.kwargs, ) -> str: - return ( - PromptFunction[BaseModel] - .as_tool_call( - fn=self.fn, - environment=self.environment, - prompt=self.prompt, - )(*args, **kwargs) - .messages[0] - .content + tool_call = PromptFunction[BaseModel].as_tool_call( + fn=self.fn, + environment=self.environment, + prompt=self.prompt, ) + return tool_call(*args, **kwargs).messages[0].content @overload @classmethod diff --git a/src/marvin/components/prompt/fn.py b/src/marvin/components/prompt/fn.py index 1dfa70cb8..cac85777e 100644 --- a/src/marvin/components/prompt/fn.py +++ b/src/marvin/components/prompt/fn.py @@ -32,6 +32,34 @@ U = TypeVar("U", bound=BaseModel) +def fn_to_messages( + fn: Callable, + fn_args, + fn_kwargs, + prompt=None, + render_kwargs=None, + call_fn: bool = True, +) -> list[Message]: + prompt = prompt or fn.__doc__ or "" + + signature = inspect.signature(fn) + params = signature.bind(*fn_args, **fn_kwargs) + params.apply_defaults() + return_annotation = inspect.signature(fn).return_annotation + return_value = fn(*fn_args, **fn_kwargs) if call_fn else None + + messages = Transcript(content=prompt).render_to_messages( + **fn_kwargs | params.arguments, + _arguments=params.arguments, + _doc=inspect.getdoc(fn), + _return_value=return_value, + _return_annotation=return_annotation, + _source_code=("\ndef" + "def".join(re.split("def", inspect.getsource(fn))[1:])), + **(render_kwargs or {}), + ) + return messages + + class PromptFunction(Prompt[U]): model_config = pydantic.ConfigDict( extra="allow", @@ -99,16 +127,17 @@ def as_grammar( Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self], ]: - def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self: - # Get the signature of the function - signature = inspect.signature(func) - params = signature.bind(*args, **kwargs) - params.apply_defaults() - + def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self: vocabulary = create_vocabulary_from_type( inspect.signature(func).return_annotation ) - + messages = fn_to_messages( + fn=fn, + fn_args=args, + fn_kwargs=kwargs_, + prompt=prompt, + render_kwargs=dict(_options=vocabulary), + ) grammar = create_grammar_from_vocabulary( vocabulary=vocabulary, encoder=encoder, @@ -116,18 +145,6 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self: max_tokens=max_tokens, ) - messages = Transcript( - content=prompt or func.__doc__ or "" - ).render_to_messages( - **kwargs | params.arguments, - _arguments=params.arguments, - _options=vocabulary, - _doc=func.__doc__, - _source_code=( - "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:]) - ), - ) - return cls( messages=messages, temperature=temperature, @@ -154,6 +171,7 @@ def as_tool_call( model_description: str = "Formats the response.", field_name: str = "data", field_description: str = "The data to format.", + render_kwargs: Optional[dict[str, Any]] = None, ) -> Callable[[Callable[P, Any]], Callable[P, Self]]: pass @@ -169,6 +187,7 @@ def as_tool_call( model_description: str = "Formats the response.", field_name: str = "data", field_description: str = "The data to format.", + render_kwargs: Optional[dict[str, Any]] = None, ) -> Callable[P, Self]: pass @@ -183,15 +202,13 @@ def as_tool_call( model_description: str = "Formats the response.", field_name: str = "data", field_description: str = "The data to format.", + render_kwargs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Union[ Callable[[Callable[P, Any]], Callable[P, Self]], Callable[P, Self], ]: def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self: - signature = inspect.signature(func) - params = signature.bind(*args, **kwargs_) - params.apply_defaults() _type = inspect.signature(func).return_annotation if _type is inspect._empty: _type = str @@ -204,16 +221,13 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self: field_description=field_description, ) - messages = Transcript( - content=prompt or func.__doc__ or "" - ).render_to_messages( - **kwargs_ | params.arguments, - _doc=func.__doc__, - _arguments=params.arguments, - _response_model=toolset.tools[0], # type: ignore - _source_code=( - "\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:]) - ), + messages = fn_to_messages( + fn=fn, + fn_args=args, + fn_kwargs=kwargs_, + prompt=prompt, + render_kwargs=(render_kwargs or {}) + | dict(_response_model=toolset.tools[0]), ) return cls(