Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/marvin/_mappings/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Callable,
TypeVar,
Union,
cast,
)

from pydantic import BaseModel, TypeAdapter, ValidationError
Expand Down Expand Up @@ -41,9 +40,6 @@ def chat_completion_to_model(
data: dict[str, Any] = {}
data[field_name] = json.loads(tool_arguments[0])
return response_model.model_validate_json(json.dumps(data))
else:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was unreachable

data: dict[str, Any] = json.loads(tool_arguments[0])
return cast(T, data)


def chat_completion_to_type(response_type: U, completion: "ChatCompletion") -> "U":
Expand Down
1 change: 1 addition & 0 deletions src/marvin/_mappings/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def cast_type_to_model(

return create_model(
model_name,
__doc__=model_description,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_description was unused

__config__=None,
__base__=None,
__module__=__name__,
Expand Down
13 changes: 10 additions & 3 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import NotRequired, ParamSpec, Self, Unpack

import marvin
from marvin._mappings.chat_completion import chat_completion_to_model
from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
from marvin.utilities.jinja import BaseEnvironment
from marvin.utilities.logging import get_logger

if TYPE_CHECKING:
from openai.types.chat import ChatCompletion
Expand All @@ -44,17 +46,17 @@ class AIFunctionKwargs(TypedDict):


class AIFunctionKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
Copy link
Collaborator Author

@zzstoatzz zzstoatzz Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid warning about infringing on sacred model_* namespace

environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = None
model_name: str = "FormatResponse"
model_description: str = "Formats the response."
field_name: str = "data"
field_description: str = "The data to format."
model: Optional[str] = None
model: str = marvin.settings.openai.chat.completions.model
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None
temperature: Optional[float] = None
temperature: Optional[float] = marvin.settings.openai.chat.completions.temperature
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was getting 400s without these



class AIFunction(
Expand Down Expand Up @@ -89,6 +91,10 @@ class AIFunction(
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

@property
def logger(self):
return get_logger(self.__class__.__name__)

def __call__(
self, *args: P.args, **kwargs: P.kwargs
) -> Union[T, Coroutine[Any, Any, T]]:
Expand All @@ -101,6 +107,7 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> T:
response: ChatCompletion = MarvinClient(client=self.client).chat(
**prompt.serialize()
)
self.logger.debug_kv("Calling", f"{self.fn.__name__}({args}, {kwargs})", "blue")
return getattr(
chat_completion_to_model(model, response, field_name=self.field_name),
self.field_name,
Expand Down
5 changes: 5 additions & 0 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

class MarvinSettings(BaseSettings):
def __setattr__(self, name: str, value: Any) -> None:
# wrap bare strings in SecretStr if the field is annotated with SecretStr
field = self.model_fields.get(name)
if field:
annotation = field.annotation
Expand All @@ -46,6 +47,10 @@ class ChatCompletionSettings(MarvinSettings):
description="The default chat model to use.", default="gpt-3.5-turbo"
)

temperature: float = Field(
description="The default temperature to use.", default=0.1
)

@property
def encoder(self):
import tiktoken
Expand Down