Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mellea/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
from collections.abc import Sequence
from typing import TypeVar

import pydantic
Expand Down Expand Up @@ -36,12 +37,13 @@ def __init__(
@abc.abstractmethod
async def generate_from_context(
self,
action: Component | CBlock,
action: Component | CBlock | None,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
labels: Sequence[str] | None = None,
) -> tuple[ModelOutputThunk, Context]:
"""Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request.

Expand All @@ -51,6 +53,7 @@ async def generate_from_context(
format: A response format to be used for structured outputs / constrained decoding.
model_options: Any model options to upsert into the defaults for this call.
tool_calls: If `True`, then tool calls are extracted from the `action` `Component`. Assumption: if tool_calls is enabled, then the action `Component` has a TemplateRepresentation.
labels: The labels under which to execute this action.

Returns:
a tuple of (ModelOutputThunk, Context) where the Context is the new context after the generation has been completed.
Expand Down
26 changes: 20 additions & 6 deletions mellea/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import inspect
import json
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Sequence
from copy import deepcopy
from enum import Enum
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -286,12 +286,13 @@ def _make_backend_specific_and_remove(

async def generate_from_context(
self,
action: Component | CBlock,
action: Component | CBlock | None,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
labels: Sequence[str] | None = None,
):
"""See `generate_from_chat_context`."""
assert ctx.is_chat_context, NotImplementedError(
Expand All @@ -303,18 +304,27 @@ async def generate_from_context(
_format=format,
model_options=model_options,
tool_calls=tool_calls,
labels=labels,
)

# only add action to context if provided
if action is not None:
ctx = ctx.add(action, labels=labels)

# return
return mot, ctx.add(mot, labels=labels)

async def generate_from_chat_context(
self,
action: Component | CBlock,
action: Component | CBlock | None,
ctx: Context,
*,
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
) -> tuple[ModelOutputThunk, Context]:
labels: Sequence[str] | None = None,
) -> ModelOutputThunk:
"""Generates a new completion from the provided Context using this backend's `Formatter`."""
# Requirements can be automatically rerouted to a requirement adapter.
if isinstance(action, Requirement):
Expand Down Expand Up @@ -366,6 +376,7 @@ async def generate_from_chat_context(
_format=_format,
model_options=model_options,
tool_calls=tool_calls,
labels=labels,
)
return mot, ctx.add(action).add(mot)

Expand Down Expand Up @@ -564,18 +575,19 @@ def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]:

async def _generate_from_chat_context_standard(
self,
action: Component | CBlock,
action: Component | CBlock | None,
ctx: Context,
*,
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
labels: Sequence[str] | None = None,
) -> ModelOutputThunk:
model_opts = self._simplify_and_merge(
model_options, is_chat_context=ctx.is_chat_context
)
linearized_context = ctx.view_for_generation()
linearized_context = ctx.view_for_generation(labels=labels)
assert linearized_context is not None, (
"Cannot generate from a non-linear context in a FormatterBackend."
)
Expand All @@ -587,6 +599,8 @@ async def _generate_from_chat_context_standard(
raise Exception(
"The OpenAI backend does not support currently support activated LoRAs."
)
case None:
action = ctx.node_data # action defaults to node_data if None provided
case _:
messages.extend(self.formatter.to_chat_messages([action]))
conversation: list[dict] = []
Expand Down
70 changes: 50 additions & 20 deletions mellea/stdlib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import binascii
import datetime
import enum
from collections.abc import Callable, Coroutine, Iterable, Mapping
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
from copy import copy, deepcopy
from dataclasses import dataclass
from io import BytesIO
Expand Down Expand Up @@ -436,18 +436,23 @@ class Context(abc.ABC):
_data: Component | CBlock | None
_is_root: bool
_is_chat_context: bool = True
_labels: set[str] | None = None

def __init__(self):
"""Constructs a new root context with no content."""
self._previous = None
self._data = None
self._is_root = True
self._labels = set()

# factory functions below this line.

@classmethod
def from_previous(
cls: type[ContextT], previous: Context, data: Component | CBlock
cls: type[ContextT],
previous: Context,
data: Component | CBlock,
labels: Sequence[str] | None = None,
) -> ContextT:
"""Constructs a new context from an existing context."""
assert isinstance(previous, Context), (
Expand All @@ -460,6 +465,7 @@ def from_previous(
x._data = data
x._is_root = False
x._is_chat_context = previous._is_chat_context
x._labels = set(labels) if labels is not None else previous._labels
return x

@classmethod
Expand Down Expand Up @@ -495,27 +501,35 @@ def is_chat_context(self) -> bool:
"""Returns whether this context is a chat context."""
return self._is_chat_context

@property
def labels(self) -> set[str]:
"""Returns the set of labels attached to this context node."""
# NOTE(review): `_labels` is declared as `set[str] | None`, while this property promises `set[str]` — confirm it can never be None here.
return self._labels

# User functions below this line.

def as_list(self, last_n_components: int | None = None) -> list[Component | CBlock]:
def as_list(
self, last_n_components: int | None = None, labels: Sequence[str] | None = None
) -> list[Component | CBlock]:
"""Returns a list of the last n components in the context sorted from FIRST TO LAST.

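If `last_n_components` is `None`, then all components are returned. If `labels` is given, only components from context nodes that carry at least one of those labels are included.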
"""
context_list: list[Component | CBlock] = []
current_context: Context = self

last_n_count = 0
while not current_context.is_root_node and (
last_n_components is None or last_n_count < last_n_components
last_n_components is None or len(context_list) < last_n_components
):
data = current_context.node_data
assert data is not None, "Data cannot be None (except for root context)."
assert data not in context_list, (
"There might be a cycle in the context tree. That is not allowed."
)
context_list.append(data)
last_n_count += 1

# append only if no label is specified or label is within allowed label set
if labels is None or current_context.labels.intersection(labels):
context_list.append(data)

current_context = current_context.previous_node # type: ignore
assert current_context is not None, (
Expand All @@ -532,19 +546,23 @@ def actions_for_available_tools(self) -> list[Component | CBlock] | None:
"""
return self.view_for_generation()

def last_output(self, check_last_n_components: int = 3) -> ModelOutputThunk | None:
def last_output(
self, check_last_n_components: int = 3, labels: Sequence[str] | None = None
) -> ModelOutputThunk | None:
"""The last output thunk of the context."""
for c in self.as_list(last_n_components=check_last_n_components)[::-1]:
for c in self.as_list(last_n_components=check_last_n_components, labels=labels)[
::-1
]:
if isinstance(c, ModelOutputThunk):
return c
return None

def last_turn(self):
def last_turn(self, labels: Sequence[str] | None = None):
"""The last input/output turn of the context.

This can be partial. If the last event is an input, then the output is None.
"""
history = self.as_list(last_n_components=2)
history = self.as_list(last_n_components=2, labels=labels)

if len(history) == 0:
return None
Expand All @@ -563,13 +581,17 @@ def last_turn(self):
# Abstract methods below this line.

@abc.abstractmethod
def add(self, c: Component | CBlock) -> Context:
def add(
self, c: Component | CBlock, labels: Sequence[str] | None = None
) -> Context:
"""Returns a new context obtained by adding `c` to this context."""
# something along ....from_previous(self, c)
...

@abc.abstractmethod
def view_for_generation(self) -> list[Component | CBlock] | None:
def view_for_generation(
self, labels: Sequence[str] | None = None
) -> list[Component | CBlock] | None:
"""Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
...

Expand All @@ -582,25 +604,33 @@ def __init__(self, *, window_size: int | None = None):
super().__init__()
self._window_size = window_size

def add(self, c: Component | CBlock) -> ChatContext:
def add(
self, c: Component | CBlock, labels: Sequence[str] | None = None
) -> ChatContext:
"""Add a new component/cblock to the context. Returns the new context."""
new = ChatContext.from_previous(self, c)
new = ChatContext.from_previous(self, c, labels=labels)
new._window_size = self._window_size
return new

def view_for_generation(self) -> list[Component | CBlock] | None:
def view_for_generation(
self, labels: Sequence[str] | None = None
) -> list[Component | CBlock] | None:
"""Returns the context in a linearized form. Uses the window_size set during initialization."""
return self.as_list(self._window_size)
return self.as_list(self._window_size, labels=labels)


class SimpleContext(Context):
"""A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved.."""

def add(self, c: Component | CBlock) -> SimpleContext:
def add(
self, c: Component | CBlock, labels: Sequence[str] | None = None
) -> SimpleContext:
"""Add a new component/cblock to the context. Returns the new context."""
return SimpleContext.from_previous(self, c)
return SimpleContext.from_previous(self, c, labels=labels)

def view_for_generation(self) -> list[Component | CBlock] | None:
def view_for_generation(
self, labels: Sequence[str] | None = None
) -> list[Component | CBlock] | None:
"""Returns an empty list."""
return []

Expand Down
Loading