diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 4a56665a..39ae6d05 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +from collections.abc import Sequence from typing import TypeVar import pydantic @@ -36,12 +37,13 @@ def __init__( @abc.abstractmethod async def generate_from_context( self, - action: Component | CBlock, + action: Component | CBlock | None, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context]: """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. @@ -51,6 +53,7 @@ async def generate_from_context( format: A response format to used for structured outputs / constrained decoding. model_options: Any model options to upsert into the defaults for this call. tool_calls: If `True`, then tool calls are extracts from the `action` `Component`. Assumption: if tool_calls is enabled, then the action `Component` has a TemplateRepresentation + labels: The labels under which to execute this action. Returns: a tuple of (ModelOutputThunk, Context) where the Context is the new context after the generation has been completed. 
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index ba825753..bfe927c8 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -6,7 +6,7 @@ import functools import inspect import json -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Sequence from copy import deepcopy from enum import Enum from typing import TYPE_CHECKING, Any, cast @@ -286,12 +286,13 @@ def _make_backend_specific_and_remove( async def generate_from_context( self, - action: Component | CBlock, + action: Component | CBlock | None, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ): """See `generate_from_chat_context`.""" assert ctx.is_chat_context, NotImplementedError( @@ -303,18 +304,27 @@ async def generate_from_context( _format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) + # only add action to context if provided + if action is not None: + ctx = ctx.add(action, labels=labels) + + # return + return mot, ctx.add(mot, labels=labels) + async def generate_from_chat_context( self, - action: Component | CBlock, + action: Component | CBlock | None, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, - ) -> tuple[ModelOutputThunk, Context]: + labels: Sequence[str] | None = None, + ) -> ModelOutputThunk: """Generates a new completion from the provided Context using this backend's `Formatter`.""" # Requirements can be automatically rerouted to a requirement adapter. 
if isinstance(action, Requirement): @@ -366,6 +376,7 @@ async def generate_from_chat_context( _format=_format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) return mot, ctx.add(action).add(mot) @@ -564,18 +575,19 @@ def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: async def _generate_from_chat_context_standard( self, - action: Component | CBlock, + action: Component | CBlock | None, ctx: Context, *, _format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> ModelOutputThunk: model_opts = self._simplify_and_merge( model_options, is_chat_context=ctx.is_chat_context ) - linearized_context = ctx.view_for_generation() + linearized_context = ctx.view_for_generation(labels=labels) assert linearized_context is not None, ( "Cannot generate from a non-linear context in a FormatterBackend." ) @@ -587,6 +599,8 @@ async def _generate_from_chat_context_standard( raise Exception( "The OpenAI backend does not support currently support activated LoRAs." 
) + case None: + action = ctx.node_data # action defaults to node_data if None provided case _: messages.extend(self.formatter.to_chat_messages([action])) conversation: list[dict] = [] diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 111d44f6..b4935660 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -8,7 +8,7 @@ import binascii import datetime import enum -from collections.abc import Callable, Coroutine, Iterable, Mapping +from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO @@ -436,18 +436,23 @@ class Context(abc.ABC): _data: Component | CBlock | None _is_root: bool _is_chat_context: bool = True + _labels: set[str] | None = None def __init__(self): """Constructs a new root context with no content.""" self._previous = None self._data = None self._is_root = True + self._labels = set() # factory functions below this line. @classmethod def from_previous( - cls: type[ContextT], previous: Context, data: Component | CBlock + cls: type[ContextT], + previous: Context, + data: Component | CBlock, + labels: Sequence[str] | None = None, ) -> ContextT: """Constructs a new context from an existing context.""" assert isinstance(previous, Context), ( @@ -460,6 +465,7 @@ def from_previous( x._data = data x._is_root = False x._is_chat_context = previous._is_chat_context + x._labels = set(labels) if labels is not None else previous._labels return x @classmethod @@ -495,9 +501,16 @@ def is_chat_context(self) -> bool: """Returns whether this context is a chat context.""" return self._is_chat_context + @property + def labels(self) -> set[str]: + """Returns the set of labels for this context node.""" + return self._labels + # User functions below this line. 
- def as_list(self, last_n_components: int | None = None) -> list[Component | CBlock]: + def as_list( + self, last_n_components: int | None = None, labels: Sequence[str] | None = None + ) -> list[Component | CBlock]: """Returns a list of the last n components in the context sorted from FIRST TO LAST. If `last_n_components` is `None`, then all components are returned. @@ -505,17 +518,18 @@ def as_list(self, last_n_components: int | None = None) -> list[Component | CBlo context_list: list[Component | CBlock] = [] current_context: Context = self - last_n_count = 0 while not current_context.is_root_node and ( - last_n_components is None or last_n_count < last_n_components + last_n_components is None or len(context_list) < last_n_components ): data = current_context.node_data assert data is not None, "Data cannot be None (except for root context)." assert data not in context_list, ( "There might be a cycle in the context tree. That is not allowed." ) - context_list.append(data) - last_n_count += 1 + + # append only if no label is specified or label is within allowed label set + if labels is None or current_context.labels.intersection(labels): + context_list.append(data) current_context = current_context.previous_node # type: ignore assert current_context is not None, ( @@ -532,19 +546,23 @@ def actions_for_available_tools(self) -> list[Component | CBlock] | None: """ return self.view_for_generation() - def last_output(self, check_last_n_components: int = 3) -> ModelOutputThunk | None: + def last_output( + self, check_last_n_components: int = 3, labels: Sequence[str] | None = None + ) -> ModelOutputThunk | None: """The last output thunk of the context.""" - for c in self.as_list(last_n_components=check_last_n_components)[::-1]: + for c in self.as_list(last_n_components=check_last_n_components, labels=labels)[ + ::-1 + ]: if isinstance(c, ModelOutputThunk): return c return None - def last_turn(self): + def last_turn(self, labels: Sequence[str] | None = None): """The last 
input/output turn of the context. This can be partial. If the last event is an input, then the output is None. """ - history = self.as_list(last_n_components=2) + history = self.as_list(last_n_components=2, labels=labels) if len(history) == 0: return None @@ -563,13 +581,17 @@ def last_turn(self): # Abstract methods below this line. @abc.abstractmethod - def add(self, c: Component | CBlock) -> Context: + def add( + self, c: Component | CBlock, labels: Sequence[str] | None = None + ) -> Context: """Returns a new context obtained by adding `c` to this context.""" # something along ....from_previous(self, c) ... @abc.abstractmethod - def view_for_generation(self) -> list[Component | CBlock] | None: + def view_for_generation( + self, labels: Sequence[str] | None = None + ) -> list[Component | CBlock] | None: """Provides a linear list of context components to use for generation, or None if that is not possible to construct.""" ... @@ -582,25 +604,33 @@ def __init__(self, *, window_size: int | None = None): super().__init__() self._window_size = window_size - def add(self, c: Component | CBlock) -> ChatContext: + def add( + self, c: Component | CBlock, labels: Sequence[str] | None = None + ) -> ChatContext: """Add a new component/cblock to the context. Returns the new context.""" - new = ChatContext.from_previous(self, c) + new = ChatContext.from_previous(self, c, labels=labels) new._window_size = self._window_size return new - def view_for_generation(self) -> list[Component | CBlock] | None: + def view_for_generation( + self, labels: Sequence[str] | None = None + ) -> list[Component | CBlock] | None: """Returns the context in a linearized form. Uses the window_size set during initialization.""" - return self.as_list(self._window_size) + return self.as_list(self._window_size, labels=labels) class SimpleContext(Context): """A `SimpleContext` is a context in which each interaction is a separate and independent turn. 
The history of all previous turns is NOT saved..""" - def add(self, c: Component | CBlock) -> SimpleContext: + def add( + self, c: Component | CBlock, labels: Sequence[str] | None = None + ) -> SimpleContext: """Add a new component/cblock to the context. Returns the new context.""" - return SimpleContext.from_previous(self, c) + return SimpleContext.from_previous(self, c, labels=labels) - def view_for_generation(self) -> list[Component | CBlock] | None: + def view_for_generation( + self, labels: Sequence[str] | None = None + ) -> list[Component | CBlock] | None: """Returns an empty list.""" return [] diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index e758be04..51bcc0ff 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from collections.abc import Coroutine +from collections.abc import Coroutine, Sequence from typing import Any, Literal, overload from PIL import Image as PILImage @@ -46,6 +46,7 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context]: ... @@ -61,6 +62,7 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: ... @@ -75,6 +77,7 @@ def act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. @@ -88,6 +91,7 @@ def act( format: if set, the BaseModel to use for constrained decoding. model_options: additional model options, which will upsert into the model/backend's defaults. tool_calls: if true, tool calling is enabled. 
+ labels: if provided, restrict generation to context nodes with matching labels. Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -104,6 +108,7 @@ def act( model_options=model_options, tool_calls=tool_calls, silence_context_type_warning=True, # We can safely silence this here since it's in a sync function. + labels=labels, ) # type: ignore[call-overload] # Mypy doesn't like the bool for return_sampling_results. ) @@ -129,6 +134,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context]: ... @@ -150,6 +156,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: ... @@ -170,6 +177,7 @@ def instruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Generates from an instruction. @@ -189,6 +197,7 @@ def instruct( model_options: Additional model options, which will upsert into the model/backend's defaults. tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. + labels: if provided, restrict generation to context nodes with matching labels. Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
@@ -221,6 +230,7 @@ def instruct( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) # type: ignore[call-overload] @@ -235,6 +245,7 @@ def chat( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[Message, Context]: """Sends a simple chat message and returns the response. Adds both messages to the Context.""" if user_variables is not None: @@ -254,6 +265,7 @@ def chat( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) parsed_assistant_message = result.parsed_repr assert isinstance(parsed_assistant_message, Message) @@ -429,6 +441,7 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context]: ... @@ -445,11 +458,12 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: ... async def aact( - action: Component, + action: Component | None, context: Context, backend: Backend, *, @@ -460,6 +474,7 @@ async def aact( model_options: dict | None = None, tool_calls: bool = False, silence_context_type_warning: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. @@ -474,6 +489,7 @@ async def aact( model_options: additional model options, which will upsert into the model/backend's defaults. tool_calls: if true, tool calling is enabled. silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext + labels: if provided, restrict generation to context nodes with matching labels. 
Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -505,6 +521,7 @@ async def aact( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) await result.avalue() @@ -526,6 +543,7 @@ async def aact( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) assert sampling_result.sample_generations is not None @@ -567,6 +585,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context]: ... @@ -588,6 +607,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: ... @@ -608,6 +628,7 @@ async def ainstruct( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Generates from an instruction. @@ -627,6 +648,7 @@ async def ainstruct( model_options: Additional model options, which will upsert into the model/backend's defaults. tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. + labels: if provided, restrict generation to context nodes with matching labels. Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
@@ -659,6 +681,7 @@ async def ainstruct( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) # type: ignore[call-overload] @@ -673,6 +696,7 @@ async def achat( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> tuple[Message, Context]: """Sends a simple chat message and returns the response. Adds both messages to the Context.""" if user_variables is not None: @@ -692,6 +716,7 @@ async def achat( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) parsed_assistant_message = result.parsed_repr assert isinstance(parsed_assistant_message, Message) diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 06401ec4..f403d190 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -1,6 +1,7 @@ """Base Sampling Strategies.""" import abc +from collections.abc import Sequence from copy import deepcopy import tqdm @@ -94,6 +95,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, + labels: Sequence[str] | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -107,6 +109,7 @@ async def sample( model_options: model options to pass to the backend during generation / validation. tool_calls: True if tool calls should be used during this sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + labels: if provided, restrict generation to context nodes with matching set of labels Returns: SamplingResult: A result object indicating the success or failure of the sampling process. 
@@ -157,6 +160,7 @@ async def sample( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) await result.avalue() diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index be082159..be7a1ee8 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -1,5 +1,6 @@ """Best of N Sampling Strategy.""" +from collections.abc import Sequence from copy import deepcopy import tqdm @@ -29,6 +30,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, + labels: Sequence[str] | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -42,6 +44,7 @@ async def sample( model_options: model options to pass to the backend during generation / validation. tool_calls: True if tool calls should be used during this sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + labels: if provided, restrict generation to context nodes with matching labels. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. 
@@ -114,6 +117,7 @@ async def sample( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) sampled_results.append(result) sampled_actions.append(next_action) diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 692fae66..760b5aee 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -1,5 +1,6 @@ """Sampling Strategies for budget forcing generation.""" +from collections.abc import Sequence from copy import deepcopy import tqdm @@ -82,6 +83,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, + labels: Sequence[str] | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -95,6 +97,7 @@ async def sample( model_options: model options to pass to the backend during generation / validation. tool_calls: True if tool calls should be used during this sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + labels: if provided, restrict generation to context nodes with matching labels. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. diff --git a/mellea/stdlib/sampling/majority_voting.py b/mellea/stdlib/sampling/majority_voting.py index 8ba99798..f4b12719 100644 --- a/mellea/stdlib/sampling/majority_voting.py +++ b/mellea/stdlib/sampling/majority_voting.py @@ -2,6 +2,7 @@ import abc import asyncio +from collections.abc import Sequence import numpy as np from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify @@ -74,6 +75,7 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, show_progress: bool = True, + labels: Sequence[str] | None = None, ) -> SamplingResult: """Samples using majority voting. 
@@ -87,6 +89,7 @@ async def sample( model_options: model options to pass to the backend during generation / validation. tool_calls: True if tool calls should be used during this sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + labels: if provided, restrict generation to context nodes with matching labels. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. @@ -104,6 +107,7 @@ async def sample( model_options=model_options, tool_calls=tool_calls, show_progress=show_progress, + labels=labels, ) ) tasks.append(task) diff --git a/mellea/stdlib/sampling/types.py b/mellea/stdlib/sampling/types.py index 391b2c89..a64713df 100644 --- a/mellea/stdlib/sampling/types.py +++ b/mellea/stdlib/sampling/types.py @@ -1,6 +1,7 @@ """Base types for sampling.""" import abc +from collections.abc import Sequence from mellea.backends import Backend, BaseModelSubclass from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk @@ -95,6 +96,7 @@ async def sample( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: """This method is the abstract method for sampling a given component. @@ -109,6 +111,7 @@ async def sample( format: output format for structured outputs. model_options: model options to pass to the backend during generation / validation. tool_calls: True if tool calls should be used during this sampling strategy. + labels: if provided, restrict generation to context nodes with matching labels. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. 
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 91d1be24..c356a691 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -4,6 +4,7 @@ import contextvars import inspect +from collections.abc import Sequence from copy import copy from typing import Any, Literal, overload @@ -237,6 +238,12 @@ def cleanup(self) -> None: if hasattr(self.backend, "close"): self.backend.close() # type: ignore + def append_to_ctx( + self, c: Component | CBlock, labels: Sequence[str] | None = None + ) -> None: + """Adds new component to current context.""" + self.ctx = self.ctx.add(c, labels=labels) + @overload def act( self, @@ -524,7 +531,7 @@ def transform( @overload async def aact( self, - action: Component, + action: Component | None, *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -532,12 +539,13 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> ModelOutputThunk: ... @overload async def aact( self, - action: Component, + action: Component | None, *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -545,11 +553,12 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> SamplingResult: ... 
async def aact( self, - action: Component, + action: Component | None, *, requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), @@ -557,6 +566,7 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + labels: Sequence[str] | None = None, ) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. @@ -568,6 +578,7 @@ async def aact( format: if set, the BaseModel to use for constrained decoding. model_options: additional model options, which will upsert into the model/backend's defaults. tool_calls: if true, tool calling is enabled. + labels: labels to filter on Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -582,6 +593,7 @@ async def aact( format=format, model_options=model_options, tool_calls=tool_calls, + labels=labels, ) # type: ignore if isinstance(r, SamplingResult): diff --git a/test/stdlib_basics/test_base_context.py b/test/stdlib_basics/test_base_context.py index 0c1a5620..46987305 100644 --- a/test/stdlib_basics/test_base_context.py +++ b/test/stdlib_basics/test_base_context.py @@ -67,5 +67,22 @@ def test_actions_for_available_tools(): assert actions[i] == for_generation[i] +def test_render_view_for_chat_context_with_labels(): + ctx = ChatContext(window_size=3) + for i in range(5): + ctx = ctx.add(CBlock(f"a {i}"), labels=[str(i // 2)]) + + # no labels + assert len(ctx.as_list()) == 5, "Context size must be 5" + assert len(ctx.view_for_generation()) == 3, "Render size must be 3" + + # with explicit labels + for labels, al_sz, vg_sz in [(None, 5, 3), ([str(0)], 2, 2), ([str(2)], 1, 1)]: + assert len(ctx.as_list(labels=labels)) == al_sz, f"Context size must be {al_sz}" + assert len(ctx.view_for_generation(labels=labels)) == vg_sz, ( + f"Render size must be {vg_sz}" + ) + + if __name__ == "__main__": 
pytest.main([__file__])