Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/marvin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .settings import settings

from .components import ai_fn, ai_model, ai_classifier
from .components import ai_fn, ai_model, ai_classifier, ai_image
from .components.prompt.fn import prompt_fn

try:
Expand All @@ -12,6 +12,7 @@
"ai_fn",
"ai_model",
"ai_classifier",
"ai_image",
"prompt_fn",
"settings",
]
2 changes: 2 additions & 0 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
- {{ arg }}: {{ value }}
{% endfor %}



What is its output?
"""
)
Expand Down
27 changes: 15 additions & 12 deletions src/marvin/components/ai_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import textwrap
from functools import partial, wraps
from typing import (
Any,
Expand All @@ -24,9 +25,15 @@
)

T = TypeVar("T")

P = ParamSpec("P")

DEFAULT_PROMPT = textwrap.dedent(
"""
{{_doc}}
{{_return_value}}
"""
)


class AIImageKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
Expand All @@ -38,7 +45,7 @@ class AIImageKwargs(TypedDict):
class AIImageKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = None
prompt: Optional[str] = DEFAULT_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None

Expand All @@ -47,7 +54,7 @@ class AIImage(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=None)
prompt: Optional[str] = Field(default=DEFAULT_PROMPT)
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

Expand All @@ -73,16 +80,12 @@ def as_prompt(
*args: P.args,
**kwargs: P.kwargs,
) -> str:
return (
PromptFunction[BaseModel]
.as_tool_call(
fn=self.fn,
environment=self.environment,
prompt=self.prompt,
)(*args, **kwargs)
.messages[0]
.content
tool_call = PromptFunction[BaseModel].as_tool_call(
fn=self.fn,
environment=self.environment,
prompt=self.prompt,
)
return tool_call(*args, **kwargs).messages[0].content

@overload
@classmethod
Expand Down
78 changes: 46 additions & 32 deletions src/marvin/components/prompt/fn.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a pattern I like of keeping type mappings in a mapping layer

I can move this to our mapping layer in the future, but colocating all these mappings is 🔥

Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@
U = TypeVar("U", bound=BaseModel)


def fn_to_messages(
fn: Callable,
fn_args,
fn_kwargs,
prompt=None,
render_kwargs=None,
call_fn: bool = True,
) -> list[Message]:
prompt = prompt or fn.__doc__ or ""

signature = inspect.signature(fn)
params = signature.bind(*fn_args, **fn_kwargs)
params.apply_defaults()
return_annotation = inspect.signature(fn).return_annotation
return_value = fn(*fn_args, **fn_kwargs) if call_fn else None

messages = Transcript(content=prompt).render_to_messages(
**fn_kwargs | params.arguments,
_arguments=params.arguments,
_doc=inspect.getdoc(fn),
_return_value=return_value,
_return_annotation=return_annotation,
_source_code=("\ndef" + "def".join(re.split("def", inspect.getsource(fn))[1:])),
**(render_kwargs or {}),
)
return messages


class PromptFunction(Prompt[U]):
model_config = pydantic.ConfigDict(
extra="allow",
Expand Down Expand Up @@ -99,35 +127,24 @@ def as_grammar(
Callable[[Callable[P, Any]], Callable[P, Self]],
Callable[P, Self],
]:
def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
# Get the signature of the function
signature = inspect.signature(func)
params = signature.bind(*args, **kwargs)
params.apply_defaults()

def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
vocabulary = create_vocabulary_from_type(
inspect.signature(func).return_annotation
)

messages = fn_to_messages(
fn=fn,
fn_args=args,
fn_kwargs=kwargs_,
prompt=prompt,
render_kwargs=dict(_options=vocabulary),
)
grammar = create_grammar_from_vocabulary(
vocabulary=vocabulary,
encoder=encoder,
_enumerate=enumerate,
max_tokens=max_tokens,
)

messages = Transcript(
content=prompt or func.__doc__ or ""
).render_to_messages(
**kwargs | params.arguments,
_arguments=params.arguments,
_options=vocabulary,
_doc=func.__doc__,
_source_code=(
"\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
),
)

return cls(
messages=messages,
temperature=temperature,
Expand All @@ -154,6 +171,7 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
) -> Callable[[Callable[P, Any]], Callable[P, Self]]:
pass

Expand All @@ -169,6 +187,7 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
) -> Callable[P, Self]:
pass

Expand All @@ -183,15 +202,13 @@ def as_tool_call(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
render_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Union[
Callable[[Callable[P, Any]], Callable[P, Self]],
Callable[P, Self],
]:
def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
signature = inspect.signature(func)
params = signature.bind(*args, **kwargs_)
params.apply_defaults()
_type = inspect.signature(func).return_annotation
if _type is inspect._empty:
_type = str
Expand All @@ -204,16 +221,13 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
field_description=field_description,
)

messages = Transcript(
content=prompt or func.__doc__ or ""
).render_to_messages(
**kwargs_ | params.arguments,
_doc=func.__doc__,
_arguments=params.arguments,
_response_model=toolset.tools[0], # type: ignore
_source_code=(
"\ndef" + "def".join(re.split("def", inspect.getsource(func))[1:])
),
messages = fn_to_messages(
fn=fn,
fn_args=args,
fn_kwargs=kwargs_,
prompt=prompt,
render_kwargs=(render_kwargs or {})
| dict(_response_model=toolset.tools[0]),
)

return cls(
Expand Down