diff --git a/cookbook/flows/github_digest.py b/cookbook/flows/github_digest.py
index 8717f074b..96619ae78 100644
--- a/cookbook/flows/github_digest.py
+++ b/cookbook/flows/github_digest.py
@@ -3,7 +3,7 @@
import httpx
import marvin
-from marvin import ai_fn
+from marvin import fn
from marvin.utilities.strings import jinja_env
from prefect import flow, task
from prefect.artifacts import create_markdown_artifact
@@ -47,7 +47,7 @@
) # noqa: E501
-@ai_fn
+@fn
async def summarize_digest(markdown_digest: str) -> str:
"""Produce a short story based on the GitHub digest.
diff --git a/cookbook/maze.py b/cookbook/maze.py
index ff7e57fb8..349b58f4c 100644
--- a/cookbook/maze.py
+++ b/cookbook/maze.py
@@ -14,12 +14,12 @@
import random
from enum import Enum
from io import StringIO
+from typing import Literal
from marvin.beta.applications import AIApplication
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table
-from typing_extensions import Literal
GAME_INSTRUCTIONS = """
This is a TERROR game. You are the disembodied narrator of a maze. You've hidden a key somewhere in the
@@ -205,7 +205,8 @@ def move(self, direction: CardinalDirection) -> str:
if move_monster := random.random() < 0.4:
self.shuffle_monster()
return (
- f"User moved {direction} and is now at {self.user_location}.\n{self.render()}"
+ f"User moved {direction} and is now at"
+ f" {self.user_location}.\n{self.render()}"
f"\nThe user may move in any of the following {self.movable_directions()!r}"
f"\n{'The monster moved somewhere.' if move_monster else ''}"
)
diff --git a/cookbook/slackbot/keywords.py b/cookbook/slackbot/keywords.py
index 1cb7c670b..b3d46ed73 100644
--- a/cookbook/slackbot/keywords.py
+++ b/cookbook/slackbot/keywords.py
@@ -1,4 +1,4 @@
-from marvin import ai_fn
+from marvin import fn
from marvin.utilities.slack import post_slack_message
from prefect import task
from prefect.blocks.system import JSON, Secret, String
@@ -36,7 +36,7 @@ async def get_reduced_kw_relationship_map() -> dict:
}
-@ai_fn
+@fn
def activation_score(message: str, keyword: str, target_relationship: str) -> float:
"""Return a score between 0 and 1 indicating whether the target relationship exists
between the message and the keyword"""
diff --git a/cookbook/slackbot/parent_app.py b/cookbook/slackbot/parent_app.py
index 0437c3634..4c91b34b0 100644
--- a/cookbook/slackbot/parent_app.py
+++ b/cookbook/slackbot/parent_app.py
@@ -3,7 +3,7 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
-from marvin import ai_fn
+from marvin import fn
from marvin.beta.applications import AIApplication
from marvin.beta.applications.state.json_block import JSONBlockState
from marvin.beta.assistants import Assistant
@@ -28,7 +28,7 @@ class Lesson(TypedDict):
heuristic: str | None
-@ai_fn(model="gpt-3.5-turbo-1106")
+@fn(model="gpt-3.5-turbo-1106")
def take_lesson_from_interaction(
transcript: str,
assistant_instructions: str,
diff --git a/docs/api_reference/components/ai_classifier.md b/docs/api_reference/components/ai_classifier.md
deleted file mode 100644
index ec5a288dc..000000000
--- a/docs/api_reference/components/ai_classifier.md
+++ /dev/null
@@ -1 +0,0 @@
-::: marvin.components.ai_classifier
\ No newline at end of file
diff --git a/docs/api_reference/components/classifier.md b/docs/api_reference/components/classifier.md
new file mode 100644
index 000000000..75bf3dccd
--- /dev/null
+++ b/docs/api_reference/components/classifier.md
@@ -0,0 +1 @@
+::: marvin.components.classifier
\ No newline at end of file
diff --git a/docs/api_reference/components/ai_function.md b/docs/api_reference/components/ai_function.md
deleted file mode 100644
index 2811b051d..000000000
--- a/docs/api_reference/components/ai_function.md
+++ /dev/null
@@ -1 +0,0 @@
-::: marvin.components.ai_function
\ No newline at end of file
diff --git a/docs/api_reference/components/ai_model.md b/docs/api_reference/components/ai_model.md
deleted file mode 100644
index 9d01d7585..000000000
--- a/docs/api_reference/components/ai_model.md
+++ /dev/null
@@ -1 +0,0 @@
-::: marvin.components.ai_model
\ No newline at end of file
diff --git a/docs/api_reference/components/functions.md b/docs/api_reference/components/functions.md
new file mode 100644
index 000000000..d529e0876
--- /dev/null
+++ b/docs/api_reference/components/functions.md
@@ -0,0 +1 @@
+::: marvin.components.function
\ No newline at end of file
diff --git a/docs/api_reference/components/models.md b/docs/api_reference/components/models.md
new file mode 100644
index 000000000..59882a1fa
--- /dev/null
+++ b/docs/api_reference/components/models.md
@@ -0,0 +1 @@
+::: marvin.components.model
\ No newline at end of file
diff --git a/docs/api_reference/index.md b/docs/api_reference/index.md
index e92675c0a..7239a8abb 100644
--- a/docs/api_reference/index.md
+++ b/docs/api_reference/index.md
@@ -1,7 +1,7 @@
# Sections
## Components
-- [AI Classifiers](/api_reference/components/ai_classifier/)
+- [AI Classifiers](/api_reference/components/classifier/)
-- [AI Functions](/api_reference/components/ai_function/)
-- [AI Models](/api_reference/components/ai_model/)
+- [AI Functions](/api_reference/components/functions/)
+- [AI Models](/api_reference/components/models/)
diff --git a/docs/components/ai_classifier.md b/docs/components/classifier.md
similarity index 84%
rename from docs/components/ai_classifier.md
rename to docs/components/classifier.md
index f60fe121c..43768d5ab 100644
--- a/docs/components/ai_classifier.md
+++ b/docs/components/classifier.md
@@ -5,13 +5,13 @@ AI Classifiers are a high-level component, or building block, of Marvin. Like al
What it does
- @ai_classifier is a decorator that lets you use LLMs to choose options, tools, or classify input.
+    @classifier is a decorator that lets you use LLMs to choose options, select tools, or classify input.
!!! example
```python
- from marvin import ai_classifier
+ from marvin import classifier
from enum import Enum
class CustomerIntent(Enum):
@@ -26,7 +26,7 @@ AI Classifiers are a high-level component, or building block, of Marvin. Like al
ACCOUNT_CANCELLATION = 'ACCOUNT_CANCELLATION'
OPERATOR_CUSTOMER_SERVICE = 'OPERATOR_CUSTOMER_SERVICE'
- @ai_classifier
+ @classifier
def classify_intent(text: str) -> CustomerIntent:
'''Classifies the most likely intent from user input'''
@@ -67,15 +67,15 @@ AI Classifiers are a high-level component, or building block, of Marvin. Like al
## Features
#### 🚅 Bulletproof
-`ai_classifier` will always output one of the options you've given it
+`classifier` will always output one of the options you've given it
```python
-from marvin import ai_classifier
+from marvin import classifier
from enum import Enum
-@ai_classifier
+@classifier
class AppRoute(Enum):
"""Represents distinct routes command bar for a different application"""
@@ -102,8 +102,8 @@ AppRoute("update my name")
#### 🏃 Fast
-`ai_classifier` only asks your LLM to output one token, so it's blazing fast - on the order of ~200ms in testing.
+`classifier` only asks your LLM to output one token, so it's blazing fast - on the order of ~200ms in testing.
#### 🫡 Deterministic
-`ai_classifier` will be deterministic so long as the underlying model and options does not change.
+`classifier` will be deterministic so long as the underlying model and options do not change.
diff --git a/docs/components/ai_function.md b/docs/components/functions.md
similarity index 100%
rename from docs/components/ai_function.md
rename to docs/components/functions.md
diff --git a/docs/components/ai_model.md b/docs/components/models.md
similarity index 99%
rename from docs/components/ai_model.md
rename to docs/components/models.md
index 5e40b30b7..f10b0e173 100644
--- a/docs/components/ai_model.md
+++ b/docs/components/models.md
@@ -393,7 +393,7 @@ CapTable("""\
import datetime
from typing import List
from pydantic import BaseModel
-from typing_extensions import Literal
+from typing import Literal
from marvin import ai_model
diff --git a/docs/components/overview.md b/docs/components/overview.md
index 9a9ff3ac6..4cea7708e 100644
--- a/docs/components/overview.md
+++ b/docs/components/overview.md
@@ -264,10 +264,10 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
!!! example "Example"
=== "As a decorator"
- `ai_classifier` can decorate python functions whose return annotation is an `Enum` or `Literal`. The prompt is tuned for classification tasks,
+ `classifier` can decorate python functions whose return annotation is an `Enum` or `Literal`. The prompt is tuned for classification tasks,
-and uses a form of `constrained sampling` to make guarantee a fast valid choice.
+and uses a form of `constrained sampling` to guarantee a fast, valid choice.
```python
- from marvin import ai_classifier
+ from marvin import classifier
from enum import Enum
class AppRoute(Enum):
@@ -283,7 +283,7 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
PROJECTS = "/projects"
WORKSPACES = "/workspaces"
- @ai_classifier(client = client)
+ @classifier(client = client)
def classify_intent(text: str) -> AppRoute:
'''Classifies user's intent into most useful route'''
@@ -329,7 +329,7 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
=== "As a function"
```python
- from marvin import ai_classifier
+ from marvin import classifier
from enum import Enum
class AppRoute(Enum):
@@ -348,12 +348,12 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
def classify_intent(text: str) -> AppRoute:
'''Classifies user's intent into most useful route'''
- ai_classifier(classify_intent, client = client)("update my name")
+ classifier(classify_intent, client = client)("update my name")
```
??? info "Generated Prompt"
You can view and/or eject the generated prompt by simply calling
```python
- ai_classifier(classify_intent, client = client).as_prompt("update my name").serialize()
+ classifier(classify_intent, client = client).as_prompt("update my name").serialize()
```
-When you do you'll see the raw payload that's sent to the LLM. The prompt you send is fully customizable.
+When you do, you'll see the raw payload that's sent to the LLM. The prompt you send is fully customizable.
```json
diff --git a/docs/welcome/overview.md b/docs/welcome/overview.md
index 5240d7b5c..efd1486df 100644
--- a/docs/welcome/overview.md
+++ b/docs/welcome/overview.md
@@ -31,10 +31,10 @@ or fully use its engine to work with OpenAI and other providers.
Marvin exposes a number of high level components to simplify working with AI.
```python
- from marvin import ai_classifier
+ from marvin import classifier
from typing import Literal
- @ai_classifier
+ @classifier
def customer_intent(text: str) -> Literal['Store Hours', 'Pharmacy', 'Returns']:
"""Classifies incoming customer intent"""
diff --git a/docs/welcome/quickstart.md b/docs/welcome/quickstart.md
index ea1bf9573..7ee5775d0 100644
--- a/docs/welcome/quickstart.md
+++ b/docs/welcome/quickstart.md
@@ -15,7 +15,7 @@ After [installing Marvin](../installation), the fastest way to get started is by
MARVIN_OPENAI_ORGANIZATION=org-xxx
```
- - Pass your API Key to Marvin's `OpenAI` client constructor and pass it to Marvin's `ai_fn`, `ai_classifier`, or `ai_model` decorators.
+ - Pass your API key to Marvin's `OpenAI` client constructor, then pass the client to Marvin's `ai_fn`, `classifier`, or `ai_model` decorators.
```python
from marvin import ai_fn
@@ -233,10 +233,10 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
!!! example "Example"
=== "As a decorator"
- `ai_classifier` can decorate python functions whose return annotation is an `Enum` or `Literal`. The prompt is tuned for classification tasks,
+ `classifier` can decorate python functions whose return annotation is an `Enum` or `Literal`. The prompt is tuned for classification tasks,
-and uses a form of `constrained sampling` to make guarantee a fast valid choice.
+and uses a form of `constrained sampling` to guarantee a fast, valid choice.
```python
- from marvin import ai_classifier
+ from marvin import classifier
from enum import Enum
class AppRoute(Enum):
@@ -252,7 +252,7 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
PROJECTS = "/projects"
WORKSPACES = "/workspaces"
- @ai_classifier(client = client)
+ @classifier(client = client)
def classify_intent(text: str) -> AppRoute:
'''Classifies user's intent into most useful route'''
@@ -298,7 +298,7 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
=== "As a function"
```python
- from marvin import ai_classifier
+ from marvin import classifier
from enum import Enum
class AppRoute(Enum):
@@ -317,12 +317,12 @@ AI Classifiers let you build multi-label classifiers with no code and no trainin
def classify_intent(text: str) -> AppRoute:
'''Classifies user's intent into most useful route'''
- ai_classifier(classify_intent, client = client)("update my name")
+ classifier(classify_intent, client = client)("update my name")
```
??? info "Generated Prompt"
You can view and/or eject the generated prompt by simply calling
```python
- ai_classifier(classify_intent, client = client).as_prompt("update my name").serialize()
+ classifier(classify_intent, client = client).as_prompt("update my name").serialize()
```
-When you do you'll see the raw payload that's sent to the LLM. The prompt you send is fully customizable.
+When you do, you'll see the raw payload that's sent to the LLM. The prompt you send is fully customizable.
```json
diff --git a/docs/welcome/what_is_marvin.md b/docs/welcome/what_is_marvin.md
index b992b0180..bf3272edd 100644
--- a/docs/welcome/what_is_marvin.md
+++ b/docs/welcome/what_is_marvin.md
@@ -62,10 +62,10 @@ or fully use its engine to work with OpenAI and other providers.
Marvin exposes a number of high level components to simplify working with AI.
```python
- from marvin import ai_classifier
+ from marvin import classifier
from typing import Literal
- @ai_classifier
+ @classifier
def customer_intent(text: str) -> Literal['Store Hours', 'Pharmacy', 'Returns']:
"""Classifies incoming customer intent"""
diff --git a/mkdocs.yml b/mkdocs.yml
index f4a2a761f..83acc6157 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -23,9 +23,9 @@ nav:
- Components:
- Overview: components/overview.md
- - AI Function: components/ai_function.md
+ - Function: components/functions.md
-    - AI Model: components/ai_model.md
+    - Model: components/models.md
- - AI Classifier: components/ai_classifier.md
+    - Classifier: components/classifier.md
- AI Application: components/ai_application.md
- Examples:
@@ -35,7 +35,7 @@ nav:
- api_reference/index.md
- AI Components:
-      - ai_classifier: api_reference/components/ai_classifier.md
+      - classifier: api_reference/components/classifier.md
- - ai_function: api_reference/components/ai_function.md
+ - function: api_reference/components/functions.md
-      - ai_model: api_reference/components/ai_model.md
+      - model: api_reference/components/models.md
- Settings:
- settings: api_reference/settings.md
diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py
index 31deceb5a..b2ffa56d8 100644
--- a/src/marvin/__init__.py
+++ b/src/marvin/__init__.py
@@ -1,7 +1,6 @@
from .settings import settings
-# legacy
-from .components import ai_fn, ai_model, ai_classifier
+from .components import fn, image, speech, model, cast, extract, classify
from .components.prompt.fn import prompt_fn
try:
@@ -10,9 +9,18 @@
__version__ = "unknown"
__all__ = [
- "ai_fn",
- "ai_model",
- "ai_classifier",
+ "fn",
+ "image",
+ "model",
+ "cast",
+ "extract",
+ "classify",
+ "speech",
"prompt_fn",
"settings",
]
+
+
+# compatibility with Marvin v1
+from .components import fn as ai_fn, model as ai_model
+from .components.classifier import classifier as ai_classifier
diff --git a/src/marvin/_mappings/types.py b/src/marvin/_mappings/types.py
index a32472e04..609c0fe97 100644
--- a/src/marvin/_mappings/types.py
+++ b/src/marvin/_mappings/types.py
@@ -1,10 +1,9 @@
from enum import Enum
from types import GenericAlias
-from typing import Any, Callable, Optional, Union, get_args, get_origin
+from typing import Any, Callable, Literal, Optional, Union, get_args, get_origin
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
-from typing_extensions import Literal
from marvin.requests import Grammar, Tool, ToolSet
from marvin.settings import settings
diff --git a/src/marvin/ai.py b/src/marvin/ai.py
index 562678dc9..ea18a2589 100644
--- a/src/marvin/ai.py
+++ b/src/marvin/ai.py
@@ -1,6 +1,6 @@
from marvin.client.openai import paint, speak
-from marvin.components.ai_classifier import ai_classifier as classifier
-from marvin.components.ai_function import ai_fn as fn
-from marvin.components.ai_model import ai_model as model
+from marvin.components.classifier import classifier
+from marvin.components.function import fn
+from marvin.components.model import model
-__all__ = ["speak", "fn", "model", "speak", "paint", "classifier"]
+__all__ = ["speak", "fn", "model", "paint", "classifier"]
diff --git a/src/marvin/beta/applications/__init__.py b/src/marvin/beta/applications/__init__.py
index 4911249ee..1e9f86e2b 100644
--- a/src/marvin/beta/applications/__init__.py
+++ b/src/marvin/beta/applications/__init__.py
@@ -1 +1,4 @@
-from .applications import AIApplication
+from .applications import Application
+
+# Marvin v1 compatibility: the cookbook still imports this name
+AIApplication = Application
diff --git a/src/marvin/beta/applications/applications.py b/src/marvin/beta/applications/applications.py
index 0a72de174..62ba3a063 100644
--- a/src/marvin/beta/applications/applications.py
+++ b/src/marvin/beta/applications/applications.py
@@ -11,7 +11,7 @@
from marvin.utilities.tools import tool_from_function
APPLICATION_INSTRUCTIONS = """
-# AI Application
+# Application
You are the natural language interface to an application called {{ self_.name
}}. Your job is to help the user interact with the application by translating
@@ -41,12 +41,12 @@
"""
-class AIApplication(Assistant):
+class Application(Assistant):
"""
- Tools for AI Applications have a special property: if any parameter is
- annotated as `AIApplication`, then the tool will be called with the
- AIApplication instance as the value for that parameter. This allows tools to
- access the AIApplication's state and other properties.
+ Tools for Applications have a special property: if any parameter is
+ annotated as `Application`, then the tool will be called with the
+ Application instance as the value for that parameter. This allows tools to
+ access the Application's state and other properties.
"""
state: State = Field(default_factory=State)
@@ -69,7 +69,7 @@ def get_tools(self) -> list[AssistantTool]:
signature = inspect.signature(tool)
parameter = None
for parameter in signature.parameters.values():
- if parameter.annotation == AIApplication:
+ if parameter.annotation == Application:
break
if parameter is not None:
kwargs = {parameter.name: self}
diff --git a/src/marvin/beta/applications/planner.py b/src/marvin/beta/applications/planner.py
index e68a58c44..9c0155258 100644
--- a/src/marvin/beta/applications/planner.py
+++ b/src/marvin/beta/applications/planner.py
@@ -5,7 +5,7 @@
from marvin.tools.assistants import AssistantTool
from marvin.utilities.jinja import Environment as JinjaEnvironment
-from .applications import AIApplication, State
+from .applications import Application, State
PLANNER_INSTRUCTIONS = """
To assist you with long-term planning and keeping track of multiple threads, you
@@ -45,7 +45,7 @@ class TaskList(BaseModel):
tasks: list[Task] = Field([], description="The list of tasks")
-class AIPlanner(AIApplication):
+class AIPlanner(Application):
plan: State = Field(default_factory=lambda: State(value=TaskList()))
def get_instructions(self) -> str:
diff --git a/src/marvin/components/__init__.py b/src/marvin/components/__init__.py
index d7d34504b..b46613c7a 100644
--- a/src/marvin/components/__init__.py
+++ b/src/marvin/components/__init__.py
@@ -1,17 +1,21 @@
-from .ai_function import ai_fn, AIFunction
-from .ai_classifier import ai_classifier, AIClassifier
-from .ai_model import ai_model
-from .ai_image import ai_image, AIImage
+from .function import fn, Function
+from .model import model
+from .text import cast, extract, classify
+from .classifier import classifier
+from .image import image
+from .speech import speech
from .prompt.fn import prompt_fn, PromptFunction
__all__ = [
- "ai_fn",
- "ai_classifier",
- "ai_model",
- "ai_image",
+ "fn",
+ "model",
+ "image",
+ "cast",
+ "extract",
+ "classify",
+ "classifier",
+ "speech",
"prompt_fn",
- "AIImage",
- "AIFunction",
- "AIClassifier",
+ "Function",
"PromptFunction",
]
diff --git a/src/marvin/components/ai_model.py b/src/marvin/components/ai_model.py
deleted file mode 100644
index 24dc14b59..000000000
--- a/src/marvin/components/ai_model.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import inspect
-from functools import partial
-from typing import Callable, Optional, TypeVar, Union, overload
-
-from typing_extensions import Unpack
-
-from marvin.components.ai_function import (
- AIFunctionKwargs,
- AIFunctionKwargsDefaults,
- ai_fn,
-)
-
-T = TypeVar("T")
-
-prompt = inspect.cleandoc(
- "The user will provide context as text that you need to parse into a structured"
- " form. To validate your response, you must call the"
- " `{{_response_model.function.name}}` function. Use the provided text to extract or"
- " infer any parameters needed by `{{_response_model.function.name}}`, including any"
- " missing data."
- " \n\nHUMAN: The text to parse: {{text}}"
-)
-
-
-class AIModelKwargsDefaults(AIFunctionKwargsDefaults):
- prompt: Optional[str] = prompt
-
-
-@overload
-def ai_model(
- **kwargs: Unpack[AIFunctionKwargs],
-) -> Callable[[Callable[[str], T]], Callable[[str], T]]:
- pass
-
-
-@overload
-def ai_model(
- _type: type[T],
- **kwargs: Unpack[AIFunctionKwargs],
-) -> Callable[[str], T]:
- pass
-
-
-def ai_model(
- _type: Optional[type[T]] = None,
- **kwargs: Unpack[AIFunctionKwargs],
-) -> Union[
- Callable[
- [Callable[[str], T]],
- Callable[[str], T],
- ],
- partial[
- Callable[
- [Callable[[str], T]],
- Callable[[str], T],
- ]
- ],
- Callable[[str], T],
-]:
- if _type is not None:
-
- def extract(text: str) -> T:
- return _type
-
- extract.__annotations__["return"] = _type
- return ai_fn(
- fn=extract,
- **AIModelKwargsDefaults(**kwargs).model_dump(exclude_none=True),
- )
-
- return partial(
- ai_model, **AIModelKwargsDefaults(**kwargs).model_dump(exclude_none=True)
- )
diff --git a/src/marvin/components/ai_classifier.py b/src/marvin/components/classifier.py
similarity index 73%
rename from src/marvin/components/ai_classifier.py
rename to src/marvin/components/classifier.py
index 79beb9319..2cb2bdb48 100644
--- a/src/marvin/components/ai_classifier.py
+++ b/src/marvin/components/classifier.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
from typing import (
Any,
Callable,
@@ -19,6 +18,7 @@
from marvin._mappings.chat_completion import chat_completion_to_type
from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
+from marvin.prompts.classifiers import CLASSIFIER_PROMPT
from marvin.utilities.jinja import BaseEnvironment
T = TypeVar("T")
@@ -26,7 +26,7 @@
P = ParamSpec("P")
-class AIClassifierKwargs(TypedDict):
+class ClassifierKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
encoder: NotRequired[Callable[[str], list[int]]]
@@ -35,7 +35,7 @@ class AIClassifierKwargs(TypedDict):
model: NotRequired[str]
-class AIClassifierKwargsDefaults(BaseModel):
+class ClassifierKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = None
@@ -45,37 +45,11 @@ class AIClassifierKwargsDefaults(BaseModel):
model: Optional[str] = None
-class AIClassifier(
- BaseModel,
- Generic[P, T],
-):
+class Classifier(BaseModel, Generic[P, T]):
model_config = ConfigDict(arbitrary_types_allowed=True)
fn: Optional[Callable[P, Union[T, Coroutine[Any, Any, T]]]] = None
environment: Optional[BaseEnvironment] = None
- prompt: Optional[str] = Field(
- default=inspect.cleandoc(
- """
- ## Expert Classifier
-
- **Objective**: You are an expert classifier that always chooses correctly.
-
- ### Context
- {{ _doc }}
-
- ### Response Format
- You must classify the user provided data into one of the following classes:
- {% for option in _options %}
- - Class {{ loop.index0 }} (value: {{ option }})
- {% endfor %}
- \n\nASSISTANT: ### Data
- The user provided the following data:
- {%for (arg, value) in _arguments.items()%}
- - {{ arg }}: {{ value }}
- {% endfor %}
- \n\nASSISTANT: The most likely class label for the data and context provided above is Class"
- """
- )
- ) # noqa
+ prompt: Optional[str] = Field(default=CLASSIFIER_PROMPT)
encoder: Callable[[str], list[int]] = Field(default=None)
max_tokens: int = 1
temperature: float = 0.0
@@ -140,7 +114,7 @@ def dict(
@classmethod
def as_decorator(
cls: type[Self],
- **kwargs: Unpack[AIClassifierKwargs],
+ **kwargs: Unpack[ClassifierKwargs],
) -> Callable[P, Self]:
pass
@@ -149,7 +123,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Callable[P, Union[T, Coroutine[Any, Any, T]]],
- **kwargs: Unpack[AIClassifierKwargs],
+ **kwargs: Unpack[ClassifierKwargs],
) -> Self:
pass
@@ -157,12 +131,12 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Union[T, Coroutine[Any, Any, T]]]] = None,
- **kwargs: Unpack[AIClassifierKwargs],
+ **kwargs: Unpack[ClassifierKwargs],
) -> Union[Callable[[Callable[P, Union[T, Coroutine[Any, Any, T]]]], Self], Self]:
def decorator(func: Callable[P, Union[T, Coroutine[Any, Any, T]]]) -> Self:
return cls(
fn=func,
- **AIClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True),
+ **ClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True),
)
if fn is not None:
@@ -172,23 +146,23 @@ def decorator(func: Callable[P, Union[T, Coroutine[Any, Any, T]]]) -> Self:
@overload
-def ai_classifier(
- **kwargs: Unpack[AIClassifierKwargs],
+def classifier(
+ **kwargs: Unpack[ClassifierKwargs],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
pass
@overload
-def ai_classifier(
+def classifier(
fn: Callable[P, T],
- **kwargs: Unpack[AIClassifierKwargs],
+ **kwargs: Unpack[ClassifierKwargs],
) -> Callable[P, T]:
pass
-def ai_classifier(
+def classifier(
fn: Optional[Callable[P, Union[T, Coroutine[Any, Any, T]]]] = None,
- **kwargs: Unpack[AIClassifierKwargs],
+ **kwargs: Unpack[ClassifierKwargs],
) -> Union[
Callable[
[Callable[P, Union[T, Coroutine[Any, Any, T]]]],
@@ -197,16 +171,26 @@ def ai_classifier(
Callable[P, Union[T, Coroutine[Any, Any, T]]],
]:
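+    """Decorate a function whose return annotation is an `Enum` or `Literal`;
+    the LLM chooses exactly one of the allowed labels (a single output token).
+
+    Example (illustrative; actual output depends on the model):
+        >>> @classifier
+        ... def sentiment(text: str) -> Literal["Positive", "Negative"]:
+        ...     '''Classify sentiment'''
+        >>> sentiment("This is great!")
+        'Positive'
+    """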
if fn is not None:
- return AIClassifier[P, T].as_decorator(
- fn=fn, **AIClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True)
+ return Classifier[P, T].as_decorator(
+ fn=fn, **ClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)
def decorator(
func: Callable[P, Union[T, Coroutine[Any, Any, T]]],
) -> Callable[P, Union[T, Coroutine[Any, Any, T]]]:
- return AIClassifier[P, T].as_decorator(
+ return Classifier[P, T].as_decorator(
fn=func,
- **AIClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True),
+ **ClassifierKwargsDefaults(**kwargs).model_dump(exclude_none=True),
)
return decorator
diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/function.py
similarity index 82%
rename from src/marvin/components/ai_function.py
rename to src/marvin/components/function.py
index fdaef254b..19def4cec 100644
--- a/src/marvin/components/ai_function.py
+++ b/src/marvin/components/function.py
@@ -1,5 +1,4 @@
import asyncio
-import inspect
from typing import (
TYPE_CHECKING,
Any,
@@ -21,6 +20,7 @@
from marvin._mappings.chat_completion import chat_completion_to_model
from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
+from marvin.prompts.functions import FUNCTION_PROMPT
from marvin.utilities.asyncio import (
ExposeSyncMethodsMixin,
expose_sync_method,
@@ -37,7 +37,7 @@
P = ParamSpec("P")
-class AIFunctionKwargs(TypedDict):
+class FunctionKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
model_name: NotRequired[str]
@@ -50,7 +50,7 @@ class AIFunctionKwargs(TypedDict):
temperature: NotRequired[float]
-class AIFunctionKwargsDefaults(BaseModel):
+class FunctionKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = None
@@ -58,38 +58,21 @@ class AIFunctionKwargsDefaults(BaseModel):
model_description: str = "Formats the response."
field_name: str = "data"
field_description: str = "The data to format."
- model: str = marvin.settings.openai.chat.completions.model
+ model: str = Field(
+ default_factory=lambda: marvin.settings.openai.chat.completions.model
+ )
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None
- temperature: Optional[float] = marvin.settings.openai.chat.completions.temperature
+ temperature: Optional[float] = Field(
+ default_factory=lambda: marvin.settings.openai.chat.completions.temperature
+ )
-class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
+class Function(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, T]] = None
environment: Optional[BaseEnvironment] = None
- prompt: Optional[str] = Field(
- default=inspect.cleandoc(
- """
- Your job is to generate likely outputs for a Python function with the
- following signature and docstring:
-
- {{_source_code}}
-
- The user will provide function inputs (if any) and you must respond with
- the most likely result.
-
- \n\nHUMAN: The function was called with the following inputs:
- {%for (arg, value) in _arguments.items()%}
- - {{ arg }}: {{ value }}
- {% endfor %}
-
-
-
- What is its output?
- """
- )
- )
+    prompt: Optional[str] = Field(default=FUNCTION_PROMPT)
name: str = "FormatResponse"
description: str = "Formats the response."
field_name: str = "data"
@@ -195,7 +178,7 @@ def dict(
@classmethod
def as_decorator(
cls: type[Self],
- **kwargs: Unpack[AIFunctionKwargs],
+ **kwargs: Unpack[FunctionKwargs],
) -> Callable[P, Self]:
pass
@@ -204,7 +187,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Callable[P, Union[T, Coroutine[Any, Any, T]]],
- **kwargs: Unpack[AIFunctionKwargs],
+ **kwargs: Unpack[FunctionKwargs],
) -> Self:
pass
@@ -212,12 +195,12 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Union[T, Coroutine[Any, Any, T]]]] = None,
- **kwargs: Unpack[AIFunctionKwargs],
+ **kwargs: Unpack[FunctionKwargs],
) -> Union[Callable[[Callable[P, Union[T, Coroutine[Any, Any, T]]]], Self], Self]:
def decorator(func: Callable[P, Union[T, Coroutine[Any, Any, T]]]) -> Self:
return cls(
fn=func,
- **AIFunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True),
+ **FunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True),
)
if fn is not None:
@@ -227,23 +210,23 @@ def decorator(func: Callable[P, Union[T, Coroutine[Any, Any, T]]]) -> Self:
@overload
-def ai_fn(
- **kwargs: Unpack[AIFunctionKwargs],
+def fn(
+ **kwargs: Unpack[FunctionKwargs],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
pass
@overload
-def ai_fn(
+def fn(
fn: Callable[P, T],
- **kwargs: Unpack[AIFunctionKwargs],
+ **kwargs: Unpack[FunctionKwargs],
) -> Callable[P, T]:
pass
-def ai_fn(
+def fn(
fn: Optional[Callable[P, Union[T, Coroutine[Any, Any, T]]]] = None,
- **kwargs: Unpack[AIFunctionKwargs],
+ **kwargs: Unpack[FunctionKwargs],
) -> Union[
Callable[
[Callable[P, Union[T, Coroutine[Any, Any, T]]]],
@@ -252,16 +235,26 @@ def ai_fn(
Callable[P, Union[T, Coroutine[Any, Any, T]]],
]:
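+    """Use an LLM to generate likely outputs for a function, given its
+    signature and docstring.
+
+    Example (illustrative; actual output depends on the model):
+        >>> @fn
+        ... def list_fruit(n: int) -> list[str]:
+        ...     '''Returns a list of `n` fruit'''
+        >>> list_fruit(2)
+        ['apple', 'banana']
+    """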
if fn is not None:
- return AIFunction[P, T].as_decorator(
- fn=fn, **AIFunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True)
+ return Function[P, T].as_decorator(
+ fn=fn, **FunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)
def decorator(
func: Callable[P, Union[T, Coroutine[Any, Any, T]]],
) -> Callable[P, Union[T, Coroutine[Any, Any, T]]]:
- return AIFunction[P, T].as_decorator(
+ return Function[P, T].as_decorator(
fn=func,
- **AIFunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True),
+ **FunctionKwargsDefaults(**kwargs).model_dump(exclude_none=True),
)
return decorator
diff --git a/src/marvin/components/ai_image.py b/src/marvin/components/image.py
similarity index 90%
rename from src/marvin/components/ai_image.py
rename to src/marvin/components/image.py
index d5c059f1d..e8e11ab8c 100644
--- a/src/marvin/components/ai_image.py
+++ b/src/marvin/components/image.py
@@ -28,14 +28,14 @@
P = ParamSpec("P")
-class AIImageKwargs(TypedDict):
+class ImageKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
client: NotRequired[Client]
aclient: NotRequired[AsyncClient]
-class AIImageKwargsDefaults(BaseModel):
+class ImageKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = IMAGE_PROMPT
@@ -43,7 +43,7 @@ class AIImageKwargsDefaults(BaseModel):
aclient: Optional[AsyncClient] = None
-class AIImage(BaseModel, Generic[P]):
+class Image(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
@@ -84,7 +84,7 @@ def as_prompt(
@classmethod
def as_decorator(
cls: type[Self],
- **kwargs: Unpack[AIImageKwargs],
+ **kwargs: Unpack[ImageKwargs],
) -> Callable[P, Self]:
pass
@@ -93,7 +93,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Callable[P, Any],
- **kwargs: Unpack[AIImageKwargs],
+ **kwargs: Unpack[ImageKwargs],
) -> Self:
pass
@@ -101,7 +101,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Any]] = None,
- **kwargs: Unpack[AIImageKwargs],
+ **kwargs: Unpack[ImageKwargs],
) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
passed_kwargs: dict[str, Any] = {
k: v for k, v in kwargs.items() if v is not None
@@ -118,9 +118,9 @@ def as_decorator(
)
-def ai_image(
+def image(
fn: Optional[Callable[P, Any]] = None,
- **kwargs: Unpack[AIImageKwargs],
+ **kwargs: Unpack[ImageKwargs],
) -> Union[
Callable[
[Callable[P, Any]],
@@ -131,8 +131,8 @@ def ai_image(
def wrapper(
func: Callable[P, Any], *args_: P.args, **kwargs_: P.kwargs
) -> Union["ImagesResponse", Coroutine[Any, Any, "ImagesResponse"]]:
- return AIImage[P].as_decorator(
- func, **AIImageKwargsDefaults(**kwargs).model_dump(exclude_none=True)
+ return Image[P].as_decorator(
+ func, **ImageKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)(*args_, **kwargs_)
if fn is not None:
diff --git a/src/marvin/components/model.py b/src/marvin/components/model.py
new file mode 100644
index 000000000..329a9afd1
--- /dev/null
+++ b/src/marvin/components/model.py
@@ -0,0 +1,59 @@
+from functools import partial
+from typing import Callable, Optional, TypeVar, Union, overload
+
+from typing_extensions import Unpack
+
+from marvin.components.function import (
+ FunctionKwargs,
+ FunctionKwargsDefaults,
+ fn,
+)
+from marvin.prompts.models import MODEL_PROMPT
+
+T = TypeVar("T")
+
+
+class ModelKwargsDefaults(FunctionKwargsDefaults):
+ prompt: Optional[str] = MODEL_PROMPT
+
+
+@overload
+def model(
+ **kwargs: Unpack[FunctionKwargs],
+) -> Callable[[Callable[[str], T]], Callable[[str], T]]:
+ pass
+
+
+@overload
+def model(
+ _type: type[T],
+ **kwargs: Unpack[FunctionKwargs],
+) -> Callable[[str], T]:
+ pass
+
+
+def model(
+ _type: Optional[type[T]] = None,
+ **kwargs: Unpack[FunctionKwargs],
+) -> Union[
+ Callable[[Callable[[str], T]], Callable[[str], T]],
+ partial[Callable[[Callable[[str], T]], Callable[[str], T]]],
+ Callable[[str], T],
+]:
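+    """Wrap a type (or decorate a pydantic model) so instances can be inferred from text.
+
+    Example (illustrative; actual output depends on the model):
+        >>> model(Location)("I live in New York, NY")  # Location: a pydantic model
+        Location(city='New York', state='NY')
+    """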
+ if _type is not None:
+
+        def extract(text: str, instructions: Optional[str] = None) -> _type:
+ pass
+
+ return fn(
+ fn=extract,
+ **ModelKwargsDefaults(**kwargs).model_dump(exclude_none=True),
+ )
+
+ return partial(model, **ModelKwargsDefaults(**kwargs).model_dump(exclude_none=True))
diff --git a/src/marvin/components/prompt/fn.py b/src/marvin/components/prompt/fn.py
index cac85777e..3872a3821 100644
--- a/src/marvin/components/prompt/fn.py
+++ b/src/marvin/components/prompt/fn.py
@@ -24,6 +24,7 @@
from marvin.settings import settings
from marvin.utilities.jinja import (
BaseEnvironment,
+ Environment,
Transcript,
)
@@ -39,22 +40,33 @@ def fn_to_messages(
prompt=None,
render_kwargs=None,
call_fn: bool = True,
+ environment: Optional[BaseEnvironment] = None,
) -> list[Message]:
- prompt = prompt or fn.__doc__ or ""
+ prompt = prompt or inspect.getdoc(fn) or ""
+ environment = environment or Environment
signature = inspect.signature(fn)
params = signature.bind(*fn_args, **fn_kwargs)
params.apply_defaults()
return_annotation = inspect.signature(fn).return_annotation
return_value = fn(*fn_args, **fn_kwargs) if call_fn else None
+    # str(signature) already includes any return annotation, so don't repeat it
+    function_def = f"def {fn.__name__}{signature}:"
+
+ doc = environment.render(inspect.getdoc(fn) or "", **fn_kwargs | params.arguments)
+ source = environment.render(
+ "\ndef" + "def".join(re.split("def", inspect.getsource(fn))[1:]),
+ **fn_kwargs | params.arguments,
+ )
messages = Transcript(content=prompt).render_to_messages(
**fn_kwargs | params.arguments,
_arguments=params.arguments,
- _doc=inspect.getdoc(fn),
+ _signature=function_def,
+ _doc=doc,
_return_value=return_value,
_return_annotation=return_annotation,
- _source_code=("\ndef" + "def".join(re.split("def", inspect.getsource(fn))[1:])),
+ _source_code=source,
**(render_kwargs or {}),
)
return messages
diff --git a/src/marvin/components/ai_speech.py b/src/marvin/components/speech.py
similarity index 90%
rename from src/marvin/components/ai_speech.py
rename to src/marvin/components/speech.py
index 988ccde07..3be173480 100644
--- a/src/marvin/components/ai_speech.py
+++ b/src/marvin/components/speech.py
@@ -28,14 +28,14 @@
P = ParamSpec("P")
-class AISpeechKwargs(TypedDict):
+class SpeechKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
client: NotRequired[Client]
aclient: NotRequired[AsyncClient]
-class AISpeechKwargsDefaults(BaseModel):
+class SpeechKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = SPEECH_PROMPT
@@ -43,7 +43,7 @@ class AISpeechKwargsDefaults(BaseModel):
aclient: Optional[AsyncClient] = None
-class AISpeech(BaseModel, Generic[P]):
+class Speech(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
@@ -82,7 +82,7 @@ def as_prompt(
@classmethod
def as_decorator(
cls: type[Self],
- **kwargs: Unpack[AISpeechKwargs],
+ **kwargs: Unpack[SpeechKwargs],
) -> Callable[P, Self]:
pass
@@ -91,7 +91,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Callable[P, Any],
- **kwargs: Unpack[AISpeechKwargs],
+ **kwargs: Unpack[SpeechKwargs],
) -> Self:
pass
@@ -99,7 +99,7 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Any]] = None,
- **kwargs: Unpack[AISpeechKwargs],
+ **kwargs: Unpack[SpeechKwargs],
) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
passed_kwargs: dict[str, Any] = {
k: v for k, v in kwargs.items() if v is not None
@@ -116,9 +116,9 @@ def as_decorator(
)
-def ai_speech(
+def speech(
fn: Optional[Callable[P, Any]] = None,
- **kwargs: Unpack[AISpeechKwargs],
+ **kwargs: Unpack[SpeechKwargs],
) -> Union[
Callable[
[Callable[P, Any]],
@@ -129,8 +129,8 @@ def ai_speech(
def wrapper(
func: Callable[P, Any], *args_: P.args, **kwargs_: P.kwargs
) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
- f = AISpeech[P].as_decorator(
- func, **AISpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
+ f = Speech[P].as_decorator(
+ func, **SpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)
return f(*args_, **kwargs_)
diff --git a/src/marvin/components/text.py b/src/marvin/components/text.py
new file mode 100644
index 000000000..558be42d7
--- /dev/null
+++ b/src/marvin/components/text.py
@@ -0,0 +1,49 @@
+from typing import Optional, TypeVar
+
+import marvin
+
+T = TypeVar("T")
+
+
+def cast(text: str, _type: type[T], instructions: Optional[str] = None) -> T:
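+    """Cast unstructured `text` to an instance of `_type`.
+
+    Example (illustrative; actual output depends on the model):
+        >>> marvin.cast("one", int)
+        1
+    """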
+ return marvin.model(_type)(text, instructions=instructions)
+
+
+def extract(text: str, _type: type[T], instructions: Optional[str] = None) -> list[T]:
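+    """Extract a list of `_type` entities from `text`.
+
+    Example (illustrative; actual output depends on the model):
+        >>> marvin.extract("one, TWO, three", int)
+        [1, 2, 3]
+    """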
+ @marvin.fn
+ def _extract(text: str) -> list[_type]:
+ msg = "Extract a list of objects from the text, using inference if necessary."
+ if instructions:
+ msg += f' Follow these instructions for extraction: "{instructions}"'
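+        # the returned string is rendered into the prompt as additional context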
+ return msg
+
+ return _extract(text)
+
+
+def classify(text: str, _type: type[T], instructions: Optional[str] = None) -> T:
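+    """Classify `text` as one of the labels described by `_type`.
+
+    Example (illustrative; actual output depends on the model):
+        >>> marvin.classify("This is great!", Literal["Positive", "Negative"])
+        'Positive'
+    """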
+ @marvin.components.classifier
+ def _classify(text: str) -> _type:
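+        # returning a string injects it into the classifier prompt as extra context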
+ if instructions:
+ return f'Follow these instructions for classification: "{instructions}"'
+
+ return _classify(text)
diff --git a/src/marvin/prompts/classifiers.py b/src/marvin/prompts/classifiers.py
new file mode 100644
index 000000000..b9744141d
--- /dev/null
+++ b/src/marvin/prompts/classifiers.py
@@ -0,0 +1,31 @@
+import inspect
+
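+# Template variables such as _doc, _options, _arguments, and _return_value are
+# injected by the prompt machinery when the decorated function is called.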
+CLASSIFIER_PROMPT = inspect.cleandoc(
+ """
+ ## Expert Classifier
+
+ **Objective**: You are an expert classifier that always chooses correctly.
+
+ ### Context
+ {{ _doc }}
+ {{_return_value | default("", true)}}
+
+ ### Response Format
+ You must classify the user provided data into one of the following classes:
+ {% for option in _options %}
+ - Class {{ loop.index0 }} (value: {{ option }})
+ {% endfor %}
+
+
+ ASSISTANT: ### Data
+ The user provided the following data:
+ {%for (arg, value) in _arguments.items()%}
+ - {{ arg }}: {{ value }}
+ {% endfor %}
+
+
+    ASSISTANT: The most likely class label for the data and context provided above is Class
+ """
+)
diff --git a/src/marvin/prompts/functions.py b/src/marvin/prompts/functions.py
new file mode 100644
index 000000000..8f12dbdee
--- /dev/null
+++ b/src/marvin/prompts/functions.py
@@ -0,0 +1,31 @@
+import inspect
+
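+# _signature, _doc, _arguments, and _return_value are supplied by
+# fn_to_messages in marvin.components.prompt.fn.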
+FUNCTION_PROMPT = inspect.cleandoc(
+ """
+ Your job is to generate likely outputs for a Python function with the
+ following signature and docstring:
+
+ {{_signature}}
+ {{_doc}}
+
+ The user will provide function inputs (if any) and you must respond with
+ the most likely result.
+
+ HUMAN: The function was called with the following inputs:
+ {%for (arg, value) in _arguments.items()%}
+ - {{ arg }}: {{ value }}
+ {% endfor %}
+
+ {% if _return_value %}
+ This context was also provided:
+ {{_return_value}}
+ {% endif %}
+
+
+ What is its output?
+
+ The output is
+ """
+)
diff --git a/src/marvin/prompts/models.py b/src/marvin/prompts/models.py
new file mode 100644
index 000000000..ae0701449
--- /dev/null
+++ b/src/marvin/prompts/models.py
@@ -0,0 +1,21 @@
+import inspect
+
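+# {{text}} and {{instructions}} map to the arguments of the wrapped
+# extract(text, instructions) helper in marvin.components.model.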
+MODEL_PROMPT = inspect.cleandoc(
+ """
+ The user will provide context as text that you need to parse into a structured
+ form. To validate your response, you must call the
+ `{{_response_model.function.name}}` function. Use the provided text to extract or
+ infer any parameters needed by `{{_response_model.function.name}}`, including any
+ missing data.
+
+
+ HUMAN: The text to parse: {{text}}
+
+ {% if instructions %}
+ Pay attention to these additional instructions: {{instructions}}
+ {% endif %}
+
+ """
+)
diff --git a/src/marvin/requests.py b/src/marvin/requests.py
index 95a78cbff..490b1b43e 100644
--- a/src/marvin/requests.py
+++ b/src/marvin/requests.py
@@ -1,7 +1,7 @@
-from typing import Any, Callable, Generic, Optional, TypeVar, Union
+from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union
from pydantic import BaseModel, Field, PrivateAttr
-from typing_extensions import Annotated, Literal, Self
+from typing_extensions import Annotated, Self
from marvin.settings import settings
diff --git a/src/marvin/serializers.py b/src/marvin/serializers.py
index 0c303aea7..8a7c7422f 100644
--- a/src/marvin/serializers.py
+++ b/src/marvin/serializers.py
@@ -3,6 +3,7 @@
from typing import (
Any,
Callable,
+ Literal,
Optional,
TypeVar,
Union,
@@ -13,7 +14,6 @@
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode
-from typing_extensions import Literal
from marvin import settings
from marvin.requests import Function, Grammar, Tool
diff --git a/src/marvin/settings.py b/src/marvin/settings.py
index 00170c053..b93df5f35 100644
--- a/src/marvin/settings.py
+++ b/src/marvin/settings.py
@@ -13,14 +13,21 @@
import os
from contextlib import contextmanager
from copy import deepcopy
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
-from typing_extensions import Literal
class MarvinSettings(BaseSettings):
+ model_config = SettingsConfigDict(
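+        # skip ~/.marvin/.env when MARVIN_TEST_MODE is set, so local settings
+        # cannot leak into tests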
+ env_file="" if os.getenv("MARVIN_TEST_MODE") else "~/.marvin/.env",
+ extra="allow",
+ arbitrary_types_allowed=True,
+ )
+
def __setattr__(self, name: str, value: Any) -> None:
# wrap bare strings in SecretStr if the field is annotated with SecretStr
field = self.model_fields.get(name)
@@ -39,9 +44,6 @@ def __setattr__(self, name: str, value: Any) -> None:
class ChatCompletionSettings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_llm_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
)
model: str = Field(
description="The default chat model to use.", default="gpt-3.5-turbo"
@@ -70,9 +72,6 @@ class ImageSettings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_image_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
)
model: str = Field(
@@ -98,9 +97,6 @@ class SpeechSettings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_speech_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
)
model: str = Field(
@@ -123,9 +119,6 @@ class AssistantSettings(MarvinSettings):
model_config = SettingsConfigDict(
-        env_prefix="marvin_llm",
+        env_prefix="marvin_llm_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
)
model: str = Field(
@@ -167,9 +160,6 @@ class OpenAISettings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_openai_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
)
api_key: Optional[SecretStr] = Field(
@@ -193,8 +183,9 @@ def discover_api_key(cls, v):
v = SecretStr(os.environ.get("OPENAI_API_KEY"))
if v.get_secret_value() is None:
raise ValueError(
- "OpenAI API key not found. Please either set `MARVIN_OPENAI_API_KEY` in `~/.marvin/.env`"
- " or otherwise set `OPENAI_API_KEY` in your environment."
+ "OpenAI API key not found. Please either set"
+ " `MARVIN_OPENAI_API_KEY` in `~/.marvin/.env` or otherwise set"
+ " `OPENAI_API_KEY` in your environment."
)
return v
@@ -221,9 +212,6 @@ class Settings(MarvinSettings):
model_config = SettingsConfigDict(
env_prefix="marvin_",
- env_file="~/.marvin/.env",
- extra="allow",
- arbitrary_types_allowed=True,
protected_namespaces=(),
)
diff --git a/src/marvin/tools/chroma.py b/src/marvin/tools/chroma.py
index d8af1824c..30bc23418 100644
--- a/src/marvin/tools/chroma.py
+++ b/src/marvin/tools/chroma.py
@@ -12,7 +12,7 @@
)
-from typing_extensions import Literal
+from typing import Literal
import marvin
diff --git a/src/marvin/utilities/pydantic.py b/src/marvin/utilities/pydantic.py
index 450f76713..409ad19bc 100644
--- a/src/marvin/utilities/pydantic.py
+++ b/src/marvin/utilities/pydantic.py
@@ -1,11 +1,10 @@
"""Module for Pydantic utilities."""
from types import FunctionType, GenericAlias
-from typing import Annotated, Any, Callable, Optional, Union, cast, get_origin
+from typing import Annotated, Any, Callable, Literal, Optional, Union, cast, get_origin
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic.deprecated.decorator import validate_arguments
-from typing_extensions import Literal
def cast_callable_to_model(
diff --git a/tests/components/test_cast.py b/tests/components/test_cast.py
new file mode 100644
index 000000000..49a649b74
--- /dev/null
+++ b/tests/components/test_cast.py
@@ -0,0 +1,57 @@
+import marvin
+import pytest
+from pydantic import BaseModel, Field
+
+from tests.utils import pytest_mark_class
+
+
+class Location(BaseModel):
+ city: str = Field(description="The city's proper name")
+ state: str = Field(description="2-letter abbreviation")
+
+
+@pytest_mark_class("llm")
+class TestCast:
+ class TestBuiltins:
+ def test_cast_text_to_int(self):
+ result = marvin.cast("one", int)
+ assert result == 1
+
+ def test_cast_text_to_list_of_ints(self):
+ result = marvin.cast("one, TWO, three", list[int])
+ assert result == [1, 2, 3]
+
+ def test_cast_text_to_list_of_ints_2(self):
+ result = marvin.cast("4 and 5 then 6", list[int])
+ assert result == [4, 5, 6]
+
+ def test_cast_text_to_list_of_floats(self):
+ result = marvin.cast("1.1, 2.2, 3.3", list[float])
+ assert result == [1.1, 2.2, 3.3]
+
+ def test_cast_text_to_bool(self):
+ result = marvin.cast("nope", bool)
+ assert result is False
+
+ def test_cast_text_to_bool_with_true(self):
+ result = marvin.cast("yes", bool)
+ assert result is True
+
+ class TestPydantic:
+ @pytest.mark.parametrize("text", ["New York, NY", "NYC", "the big apple"])
+ def test_cast_text_to_location(self, text, gpt_4):
+ result = marvin.cast(f"I live in {text}", Location)
+ assert result == Location(city="New York", state="NY")
+
+ class TestInstructions:
+ def test_cast_text_with_significant_instructions(self):
+ result = marvin.cast("one", int, instructions="return the number 4")
+ assert result == 4
+
+ def test_cast_text_with_subtle_instructions(self, gpt_4):
+ result = marvin.cast(
+ "My name is marvin",
+ str,
+ instructions="makes names uppercase",
+ )
+ assert result == "My name is MARVIN"
diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_classifiers.py
similarity index 72%
rename from tests/components/test_ai_classifier.py
rename to tests/components/test_classifiers.py
index 23f117f34..a3a2f6f4e 100644
--- a/tests/components/test_ai_classifier.py
+++ b/tests/components/test_classifiers.py
@@ -1,8 +1,8 @@
from enum import Enum
+from typing import Literal
import pytest
-from marvin import ai_classifier
-from typing_extensions import Literal
+from marvin.components import classifier
from tests.utils import pytest_mark_class
@@ -17,10 +17,10 @@ class GitHubIssueTag(Enum):
@pytest_mark_class("llm")
-class TestAIClassifer:
+class TestClassifier:
class TestLiteral:
- def test_ai_classifier_literal_return_type(self):
- @ai_classifier
+ def test_classifier_literal_return_type(self):
+ @classifier
def sentiment(text: str) -> Sentiment:
"""Classify sentiment"""
@@ -29,8 +29,8 @@ def sentiment(text: str) -> Sentiment:
assert result == "Positive"
@pytest.mark.flaky(reruns=3)
- def test_ai_classifier_literal_return_type_with_docstring(self):
- @ai_classifier
+ def test_classifier_literal_return_type_with_docstring(self):
+ @classifier
def sentiment(text: str) -> Sentiment:
"""Classify sentiment. Keep in mind it's opposite day"""
@@ -39,8 +39,8 @@ def sentiment(text: str) -> Sentiment:
assert result == "Negative"
class TestEnum:
- def test_ai_classifier_enum_return_type(self):
- @ai_classifier
+ def test_classifier_enum_return_type(self):
+ @classifier
def labeler(text: str) -> GitHubIssueTag:
"""Classify GitHub issue tags"""
diff --git a/tests/components/test_classify.py b/tests/components/test_classify.py
new file mode 100644
index 000000000..ad11936ef
--- /dev/null
+++ b/tests/components/test_classify.py
@@ -0,0 +1,51 @@
+from enum import Enum
+from typing import Literal
+
+import marvin
+
+from tests.utils import pytest_mark_class
+
+Sentiment = Literal["Positive", "Negative"]
+
+
+class GitHubIssueTag(Enum):
+ BUG = "bug"
+ FEATURE = "feature"
+ ENHANCEMENT = "enhancement"
+ DOCS = "docs"
+
+
+@pytest_mark_class("llm")
+class TestClassify:
+ class TestLiteral:
+ def test_classify_sentiment(self):
+ result = marvin.classify("This is a great feature!", Sentiment)
+ assert result == "Positive"
+
+ def test_classify_negative_sentiment(self):
+ result = marvin.classify("This feature is terrible!", Sentiment)
+ assert result == "Negative"
+
+ class TestEnum:
+ def test_classify_bug_tag(self):
+ result = marvin.classify("This is a bug", GitHubIssueTag)
+ assert result == GitHubIssueTag.BUG
+
+ def test_classify_feature_tag(self):
+ result = marvin.classify("This is a great feature!", GitHubIssueTag)
+ assert result == GitHubIssueTag.FEATURE
+
+ def test_classify_enhancement_tag(self):
+ result = marvin.classify("This is an enhancement", GitHubIssueTag)
+ assert result == GitHubIssueTag.ENHANCEMENT
+
+ def test_classify_docs_tag(self):
+ result = marvin.classify("This is a documentation update", GitHubIssueTag)
+ assert result == GitHubIssueTag.DOCS
+
+ class TestInstructions:
+ def test_classify_positive_sentiment_with_instructions(self):
+ result = marvin.classify(
+ "This is a great feature!", Sentiment, instructions="It's opposite day."
+ )
+ assert result == "Negative"
diff --git a/tests/components/test_extract.py b/tests/components/test_extract.py
new file mode 100644
index 000000000..5185d58a9
--- /dev/null
+++ b/tests/components/test_extract.py
@@ -0,0 +1,63 @@
+import marvin
+import pytest
+from pydantic import BaseModel, Field
+
+from tests.utils import pytest_mark_class
+
+
+class Location(BaseModel):
+ city: str = Field(description="The city's proper name")
+ state: str = Field(description="2-letter abbreviation")
+
+
+@pytest_mark_class("llm")
+class TestExtract:
+ class TestBuiltins:
+ def test_extract_numbers(self):
+ result = marvin.extract("one, TWO, three", int)
+ assert result == [1, 2, 3]
+
+ def test_extract_complex_numbers(self, gpt_4):
+ result = marvin.extract(
+ "I paid $10 for 3 coffees and they gave me back a dollar and 25 cents",
+ float,
+ )
+ assert result == [10.0, 3.0, 1.25]
+
+ def test_extract_money(self):
+ result = marvin.extract(
+ "I paid $10 for 3 coffees and they gave me back a dollar and 25 cents",
+ float,
+ instructions="money",
+ )
+ assert result == [10.0, 1.25]
+
+ class TestPydantic:
+ def test_extract_location(self):
+ result = marvin.extract("I live in New York, NY", Location)
+ assert result == [Location(city="New York", state="NY")]
+
+ def test_extract_multiple_locations(self):
+ result = marvin.extract(
+ "I live in New York, NY and work in San Francisco, CA", Location
+ )
+ assert result == [
+ Location(city="New York", state="NY"),
+ Location(city="San Francisco", state="CA"),
+ ]
+
+ def test_extract_multiple_locations_by_nickname(self, gpt_4):
+ result = marvin.extract("I live in the big apple and work in SF", Location)
+ assert result == [
+ Location(city="New York", state="NY"),
+ Location(city="San Francisco", state="CA"),
+ ]
+
+ @pytest.mark.xfail(reason="tuples aren't working right now")
+ def test_extract_complex_pattern(self, gpt_4):
+ result = marvin.extract(
+ "John lives in Boston, Mary lives in NYC, and I live in SF",
+ tuple[str, Location],
+ instructions="pair names and locations",
+ )
+ assert result == []
diff --git a/tests/components/test_ai_functions.py b/tests/components/test_functions.py
similarity index 94%
rename from tests/components/test_ai_functions.py
rename to tests/components/test_functions.py
index 79845ee08..794510b62 100644
--- a/tests/components/test_ai_functions.py
+++ b/tests/components/test_functions.py
@@ -3,24 +3,24 @@
import marvin
import pytest
-from marvin import ai_fn
+from marvin import fn
from pydantic import BaseModel
from tests.utils import pytest_mark_class
-@ai_fn
+@fn
def list_fruit(n: int = 2) -> list[str]:
"""Returns a list of `n` fruit"""
-@ai_fn
+@fn
def list_fruit_color(n: int, color: str = None) -> list[str]:
"""Returns a list of `n` fruit that all have the provided `color`"""
@pytest_mark_class("llm")
-class TestAIFunctions:
+class TestFunctions:
class TestBasics:
def test_list_fruit(self):
result = list_fruit()
@@ -31,7 +31,7 @@ def test_list_fruit_argument(self):
assert len(result) == 5
async def test_list_fruit_async(self):
- @ai_fn
+ @fn
async def list_fruit(n: int) -> list[str]:
"""Returns a list of `n` fruit"""
@@ -42,7 +42,7 @@ async def list_fruit(n: int) -> list[str]:
class TestAnnotations:
def test_no_annotations(self):
- @ai_fn
+ @fn
def f(x):
"""returns x + 1"""
@@ -50,7 +50,7 @@ def f(x):
assert result == "4"
def test_arg_annotations(self):
- @ai_fn
+ @fn
def f(x: int):
"""returns x + 1"""
@@ -58,7 +58,7 @@ def f(x: int):
assert result == "4"
def test_return_annotations(self):
- @ai_fn
+ @fn
def f(x) -> int:
"""returns x + 1"""
@@ -66,7 +66,7 @@ def f(x) -> int:
assert result == 4
def test_list_fruit_with_generic_type_hints(self):
- @ai_fn
+ @fn
def list_fruit(n: int) -> List[str]:
"""Returns a list of `n` fruit"""
@@ -78,7 +78,7 @@ class Fruit(BaseModel):
name: str
color: str
- @ai_fn
+ @fn
def get_fruit(description: str) -> Fruit:
"""Returns a fruit with the provided description"""
@@ -88,7 +88,7 @@ def get_fruit(description: str) -> Fruit:
@pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)])
def test_bool_return_annotation(self, name, expected):
- @ai_fn
+ @fn
def is_fruit(name: str) -> bool:
"""Returns True if the provided name is a fruit"""
@@ -99,7 +99,7 @@ def is_fruit(name: str) -> bool:
reason="3.5 turbo doesn't do well with unknown schemas",
)
def test_plain_dict_return_type(self):
- @ai_fn
+ @fn
def describe_fruit(description: str) -> dict:
"""guess the fruit and return the name and color"""
@@ -112,7 +112,7 @@ def describe_fruit(description: str) -> dict:
reason="3.5 turbo doesn't do well with unknown schemas",
)
def test_annotated_dict_return_type(self):
- @ai_fn
+ @fn
def describe_fruit(description: str) -> dict[str, str]:
"""guess the fruit and return the name and color"""
@@ -125,7 +125,7 @@ def describe_fruit(description: str) -> dict[str, str]:
reason="3.5 turbo doesn't do well with unknown schemas",
)
def test_generic_dict_return_type(self):
- @ai_fn
+ @fn
def describe_fruit(description: str) -> Dict[str, str]:
"""guess the fruit and return the name and color"""
@@ -140,7 +140,7 @@ class Fruit(TypedDict):
name: str
color: str
- @ai_fn
+ @fn
def describe_fruit(description: str) -> Fruit:
"""guess the fruit and return the name and color"""
@@ -149,21 +149,21 @@ def describe_fruit(description: str) -> Fruit:
assert fruit["color"].lower() == "yellow"
def test_int_return_type(self):
- @ai_fn
+ @fn
def get_fruit(name: str) -> int:
"""Returns the number of letters in the alluded fruit name"""
assert get_fruit("banana") == 6
def test_float_return_type(self):
- @ai_fn
+ @fn
def get_pi(n: int) -> float:
"""Return the first n digits of pi"""
assert get_pi(5) == 3.14159
def test_tuple_return_type(self):
- @ai_fn
+ @fn
def get_fruit(name: str) -> tuple:
"""Returns a tuple of fruit"""
@@ -174,14 +174,14 @@ def get_fruit(name: str) -> tuple:
)
def test_set_return_type(self):
- @ai_fn
+ @fn
def get_fruit_letters(name: str) -> set:
"""Returns the letters in the provided fruit name"""
assert get_fruit_letters("banana") == {"a", "b", "n"}
def test_frozenset_return_type(self):
- @ai_fn
+ @fn
def get_fruit_letters(name: str) -> frozenset:
"""Returns the letters in the provided fruit name"""
@@ -191,7 +191,7 @@ def get_fruit_letters(name: str) -> frozenset:
@pytest_mark_class("llm")
-class TestAIFunctionsMap:
+class TestFunctionsMap:
def test_map(self):
result = list_fruit.map([2, 3])
assert len(result) == 2
diff --git a/tests/components/test_ai_model.py b/tests/components/test_models.py
similarity index 93%
rename from tests/components/test_ai_model.py
rename to tests/components/test_models.py
index f70d9a282..58e9ad878 100644
--- a/tests/components/test_ai_model.py
+++ b/tests/components/test_models.py
@@ -1,17 +1,16 @@
-from typing import List, Optional
+from typing import List, Literal, Optional
import pytest
-from marvin import ai_model
+from marvin import model
from pydantic import BaseModel, Field
-from typing_extensions import Literal
from tests.utils import pytest_mark_class
@pytest_mark_class("llm")
-class TestAIModels:
+class TestModels:
def test_arithmetic(self):
- @ai_model
+ @model
class Arithmetic(BaseModel):
sum: float = Field(
..., description="The resolved sum of provided arguments"
@@ -23,7 +22,7 @@ class Arithmetic(BaseModel):
assert x.is_odd
def test_geospatial(self):
- @ai_model
+ @model
class Location(BaseModel):
latitude: float
longitude: float
@@ -53,7 +52,7 @@ class Neighborhood(BaseModel):
name: str
city: City
- @ai_model
+ @model
class RentalHistory(BaseModel):
neighborhood: List[Neighborhood]
@@ -70,7 +69,7 @@ class Experience(BaseModel):
years_of_experience: int
supporting_phrase: Optional[str]
- @ai_model
+ @model
class Resume(BaseModel):
"""Details about a person's work experience."""
@@ -92,7 +91,7 @@ class Resume(BaseModel):
assert len(x.technologies) == 2
def test_literal(self):
- @ai_model
+ @model
class LLMConference(BaseModel):
speakers: list[
Literal["Adam", "Nate", "Jeremiah", "Marvin", "Billy Bob Thornton"]
@@ -123,7 +122,7 @@ class Candidate(BaseModel):
campaign_slogan: str
birthplace: Location
- @ai_model
+ @model
class Election(BaseModel):
candidates: List[Candidate]
winner: Candidate
@@ -150,7 +149,7 @@ class Election(BaseModel):
@pytest.mark.skip(reason="old behavior, may revisit")
def test_correct_class_is_returned(self):
- @ai_model
+ @model
class Fruit(BaseModel):
color: str
name: str
@@ -164,7 +163,7 @@ class Fruit(BaseModel):
@pytest_mark_class("llm")
class TestInstructions:
def test_instructions_error(self):
- @ai_model
+ @model
class Test(BaseModel):
text: str
@@ -176,7 +175,7 @@ class Test(BaseModel):
Test("Hello!", model=None)
def test_instructions(self):
- @ai_model
+ @model
class Text(BaseModel):
text: str
@@ -184,7 +183,7 @@ class Text(BaseModel):
assert t1.text == "Hello"
# this model is identical except it has an instruction
- @ai_model(instructions="first translate the text to French")
+ @model(instructions="first translate the text to French")
class Text(BaseModel):
text: str
@@ -192,7 +191,7 @@ class Text(BaseModel):
assert t2.text == "Bonjour"
def test_follow_instance_instructions(self):
- @ai_model
+ @model
class Test(BaseModel):
text: str
@@ -200,7 +199,7 @@ class Test(BaseModel):
assert t1.text == "Hello"
# this model is identical except it has an instruction
- @ai_model
+ @model
class Test(BaseModel):
text: str
@@ -208,7 +207,7 @@ class Test(BaseModel):
assert t2.text == "Bonjour"
def test_follow_global_and_instance_instructions(self):
- @ai_model(instructions="Always set color_1 to 'red'")
+ @model(instructions="Always set color_1 to 'red'")
class Test(BaseModel):
color_1: str
color_2: str
@@ -217,7 +216,7 @@ class Test(BaseModel):
assert t1 == Test(color_1="red", color_2="blue")
def test_follow_docstring_and_global_and_instance_instructions(self):
- @ai_model(instructions="Always set color_1 to 'red'")
+ @model(instructions="Always set color_1 to 'red'")
class Test(BaseModel):
"""Always set color_3 to 'orange'"""
@@ -230,7 +229,7 @@ class Test(BaseModel):
def test_follow_multiple_instructions(self):
# ensure that instructions don't bleed to other invocations
- @ai_model
+ @model
class Translation(BaseModel):
"""Translates from one language to another language"""
@@ -249,9 +248,9 @@ class Translation(BaseModel):
@pytest_mark_class("llm")
-class TestAIModelMapping:
+class TestModelMapping:
def test_arithmetic(self):
- @ai_model
+ @model
class Arithmetic(BaseModel):
sum: float
@@ -262,7 +261,7 @@ class Arithmetic(BaseModel):
@pytest.mark.skip(reason="TODO: flaky on 3.5")
def test_fix_misspellings(self):
- @ai_model
+ @model
class City(BaseModel):
"""Standardize misspelled or informal city names"""
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index e69de29bb..96394ca9a 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -0,0 +1 @@
+from .llms import *
diff --git a/tests/fixtures/llms.py b/tests/fixtures/llms.py
new file mode 100644
index 000000000..345fbf9ee
--- /dev/null
+++ b/tests/fixtures/llms.py
@@ -0,0 +1,11 @@
+import pytest
+from marvin.settings import temporary_settings
+
+
+@pytest.fixture
+def gpt_4():
+ """
+ Uses GPT 4 for the duration of the test
+ """
+ with temporary_settings(openai__chat__completions__model="gpt-4-1106-preview"):
+ yield