132 changes: 96 additions & 36 deletions src/marvin/components/ai_function.py
@@ -1,6 +1,6 @@
+import asyncio
 import inspect
 import json
-from functools import partial, wraps
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -18,6 +18,11 @@
 from marvin.components.prompt import PromptFunction
 from marvin.serializers import create_tool_from_type
+from marvin.utilities.asyncio import (
+    ExposeSyncMethodsMixin,
+    expose_sync_method,
+    run_async,
+)
 from marvin.utilities.jinja import (
     BaseEnvironment,
 )
@@ -30,7 +35,7 @@
 P = ParamSpec("P")
 
 
-class AIFunction(BaseModel, Generic[P, T]):
+class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
     fn: Optional[Callable[P, T]] = None
     environment: Optional[BaseEnvironment] = None
     prompt: Optional[str] = Field(default=inspect.cleandoc("""
@@ -57,14 +62,32 @@ class AIFunction(BaseModel, Generic[P, T]):
 
     create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None)
 
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
-        create = self.create
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]:
         if self.fn is None:
             raise NotImplementedError
-        if create is None:
-            from marvin.settings import settings
-
-            create = settings.openai.chat.completions.create
+        from marvin import settings
+
+        is_async_fn = asyncio.iscoroutinefunction(self.fn)
+
+        call = "async_call" if is_async_fn else "sync_call"
+        create = self.create or (
+            settings.openai.chat.completions.acreate
+            if is_async_fn
+            else settings.openai.chat.completions.create
+        )
+
+        return getattr(self, call)(create, *args, **kwargs)
+
+    async def async_call(
+        self, acreate: Callable[..., Awaitable[Any]], *args: P.args, **kwargs: P.kwargs
+    ) -> T:
+        _response = await acreate(**self.as_prompt(*args, **kwargs).serialize())
+        return self.parse(_response)
+
+    def sync_call(
+        self, create: Callable[..., Any], *args: P.args, **kwargs: P.kwargs
+    ) -> T:
+        _response = create(**self.as_prompt(*args, **kwargs).serialize())
+        return self.parse(_response)

@@ -93,6 +116,46 @@ def parse(self, response: "ChatCompletion") -> T:
         _arguments: str = json.dumps({self.field_name: json.loads(arguments)})
         return getattr(tool.model.model_validate_json(_arguments), self.field_name)
 
+    @expose_sync_method("map")
+    async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]:
+        """
+        Map the AI function over a sequence of arguments. Runs concurrently.
+
+        A `map` twin method is provided by the `expose_sync_method` decorator.
+
+        You can use `map` or `amap` synchronously or asynchronously, respectively,
+        regardless of whether the user function is synchronous or asynchronous.
+
+        Arguments should be provided as if calling the function normally, but
+        each argument must be a list. The function is called once for each item
+        in the list, and the results are returned in a list.
+
+        For example, `fn.map([1, 2])` is equivalent to `[fn(1), fn(2)]`.
+
+        `fn.map([1, 2], x=['a', 'b'])` is equivalent to `[fn(1, x='a'), fn(2, x='b')]`.
+        """
+        tasks: list[Any] = []
+        if map_args and map_kwargs:
+            max_length = max(
+                len(arg) for arg in (map_args + tuple(map_kwargs.values()))
+            )
+        elif map_args:
+            max_length = max(len(arg) for arg in map_args)
+        else:
+            max_length = max(len(v) for v in map_kwargs.values())
+
+        for i in range(max_length):
+            call_args = [arg[i] if i < len(arg) else None for arg in map_args]
+            call_kwargs = (
+                {k: v[i] if i < len(v) else None for k, v in map_kwargs.items()}
+                if map_kwargs
+                else {}
+            )
+
+            tasks.append(run_async(self, *call_args, **call_kwargs))
+
+        return await asyncio.gather(*tasks)
+
     def as_prompt(
         self,
         *args: P.args,
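
Review note: a sketch of the mapping surface this adds, reusing the hypothetical `fruits` function from the earlier note; `map` is the synchronous twin generated by `expose_sync_method("map")`:

    import asyncio

    # Synchronous twin; the calls still run concurrently under the hood:
    print(fruits.map([1, 2, 3]))                  # ~ [fruits(1), fruits(2), fruits(3)]

    # Asynchronous original; keyword sequences work the same way:
    print(asyncio.run(fruits.amap(n=[1, 2, 3])))

Note that when the provided sequences have unequal lengths, missing positions are filled with None before the call.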
@@ -153,33 +216,24 @@ def as_decorator(
         model_description: str = "Formats the response.",
         field_name: str = "data",
         field_description: str = "The data to format.",
-        acreate: Optional[Callable[..., Awaitable[Any]]] = None,
         **render_kwargs: Any,
-    ) -> Union[Self, Callable[[Callable[P, T]], Self]]:
-        if fn is None:
-            return partial(
-                cls,
+    ) -> Union[Callable[[Callable[P, T]], Self], Self]:
+        def decorator(func: Callable[P, T]) -> Self:
+            return cls(
+                fn=func,
                 environment=environment,
-                prompt=prompt,
-                model_name=model_name,
-                model_description=model_description,
+                name=model_name,
+                description=model_description,
                 field_name=field_name,
                 field_description=field_description,
-                acreate=acreate,
+                **({"prompt": prompt} if prompt else {}),
                 **render_kwargs,
             )
 
-        return cls(
-            fn=fn,
-            environment=environment,
-            name=model_name,
-            description=model_description,
-            field_name=field_name,
-            field_description=field_description,
-            **({"prompt": prompt} if prompt else {}),
-            **render_kwargs,
-        )
+        if fn is not None:
+            return decorator(fn)
+
+        return decorator
 
 
 @overload
@@ -221,23 +275,29 @@ def ai_fn(
     field_name: str = "data",
     field_description: str = "The data to format.",
     **render_kwargs: Any,
-) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T],]:
-    def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-        return AIFunction[P, T].as_decorator(
-            func,
+) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
+    if fn is not None:
+        return AIFunction.as_decorator(  # type: ignore
+            fn=fn,
             environment=environment,
             prompt=prompt,
             model_name=model_name,
             model_description=model_description,
             field_name=field_name,
             field_description=field_description,
             **render_kwargs,
-        )(*args, **kwargs)
-
-    if fn is not None:
-        return wraps(fn)(partial(wrapper, fn))
+        )
 
-    def decorator(fn: Callable[P, T]) -> Callable[P, T]:
-        return wraps(fn)(partial(wrapper, fn))
+    def decorator(func: Callable[P, T]) -> Callable[P, T]:
+        return AIFunction.as_decorator(  # type: ignore
+            fn=func,
+            environment=environment,
+            prompt=prompt,
+            model_name=model_name,
+            model_description=model_description,
+            field_name=field_name,
+            field_description=field_description,
+            **render_kwargs,
+        )
 
     return decorator
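
Review note: both decorator spellings now route through `AIFunction.as_decorator`, so the object returned by `@ai_fn` is an `AIFunction` carrying the new `map`/`amap` methods. A sketch of the two call forms (hypothetical functions, not part of this PR):

    from marvin import ai_fn

    @ai_fn                       # bare form: fn is not None, decorated immediately
    def sentiment(text: str) -> float:
        """Returns a sentiment score between -1 and 1."""

    @ai_fn(model_name="Tags")    # parametrized form: ai_fn returns `decorator`
    def tags(text: str) -> list[str]:
        """Extracts topic tags from the given text."""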