9 changes: 4 additions & 5 deletions README.md
@@ -35,19 +35,18 @@ We differentiate between observers and stores. Observers wrap generative AI APIs
To get started, run the code below. It sends requests to an HF serverless endpoint and logs the interactions into a Hub dataset, using the default store `DatasetsStore`. The dataset will be pushed to your personal workspace (http://hf.co/{your_username}). To learn how to configure stores, see the next section.

 ```python
-from observers.observers import wrap_openai
-from observers.stores import DuckDBStore
 from openai import OpenAI
 
-store = DuckDBStore()
+from observers import wrap_openai
 
 openai_client = OpenAI()
-client = wrap_openai(openai_client, store=store)
+client = wrap_openai(openai_client)
 
 response = client.chat.completions.create(
     model="gpt-4o",
     messages=[{"role": "user", "content": "Tell me a joke."}],
 )
 print(response)
 ```
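
If you would rather not rely on the default store, you can still pass one explicitly. A minimal sketch, reusing the store wiring from the previous version of this example (the `DuckDBStore` import path comes from the removed lines above):

```python
from openai import OpenAI

from observers import wrap_openai
from observers.stores import DuckDBStore  # import path from the pre-change example

# Pass a store explicitly instead of relying on wrap_openai's default.
store = DuckDBStore()
client = wrap_openai(OpenAI(), store=store)
```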

## Observers
15 changes: 7 additions & 8 deletions examples/models/openai_example.py
@@ -1,17 +1,16 @@
-import os
-
 from openai import OpenAI
 
 from observers import wrap_openai
 
-openai_client = OpenAI(
-    base_url="https://api-inference.huggingface.co/v1/", api_key=os.getenv("HF_TOKEN")
-)
+openai_client = OpenAI()
 
 client = wrap_openai(openai_client)
 
 response = client.chat.completions.create(
-    model="Qwen/Qwen2.5-72B-Instruct",
-    messages=[{"role": "user", "content": "Tell me a joke."}],
+    model="gpt-4o",
+    messages=[{"role": "user", "content": "Tell me a joke in the voice of a pirate."}],
     temperature=0.5,
 )
-print(response)
+print(response.choices[0].message.content)
290 changes: 287 additions & 3 deletions pdm.lock

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions src/observers/base.py
@@ -2,15 +2,42 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
+from typing_extensions import Literal
 
 if TYPE_CHECKING:
     from argilla import Argilla
 
 
+@dataclass
+class Function:
+    """Function tool call information"""
+
+    name: str
+    arguments: str
+
+
+@dataclass
+class ToolCall:
+    """Tool call information"""
+
+    id: str
+    type: Literal["function"]
+    function: Function
+
+
+@dataclass
+class Message:
+    role: Literal["system", "user", "assistant", "function"]
+    content: str
+    tool_calls: Optional[List[ToolCall]] = None
+    """The tool calls generated by the model, such as function calls."""
+
+    function_call: Optional[Function] = None
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
 
 
 @dataclass
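
To make the nesting concrete, here is a small illustrative instance of the new dataclasses (field values are invented; the import path assumes they are exposed from `observers.base`, matching the file path above):

```python
from observers.base import Function, Message, ToolCall  # assumed import path

# An assistant turn that requests a function call instead of returning text.
message = Message(
    role="assistant",
    content="",
    tool_calls=[
        ToolCall(
            id="call_001",  # illustrative id
            type="function",
            function=Function(name="get_weather", arguments='{"city": "Paris"}'),
        )
    ],
)
```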
2 changes: 1 addition & 1 deletion src/observers/models/aisuite.py
@@ -46,7 +46,7 @@ def wrap_aisuite(
     return ChatCompletionObserver(
         client=client,
         create=client.chat.completions.create,
-        format_input=lambda inputs, **kwargs: {"messages": inputs, **kwargs},
+        format_input=lambda messages, **kwargs: kwargs | {"messages": messages},
         parse_response=AisuiteRecord.from_response,
         store=store,
         tags=tags,
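
The rewritten `format_input` lambda merges `kwargs` first, so the explicit `messages` argument wins if a `messages` key is also present in `kwargs` (with the old `{"messages": inputs, **kwargs}` form, the `kwargs` copy won). A standalone sketch of the `|` merge semantics (Python 3.9+):

```python
kwargs = {"model": "gpt-4o", "messages": [{"role": "user", "content": "stale"}]}
messages = [{"role": "user", "content": "fresh"}]

# dict | dict keeps the right-hand value on key collisions,
# so the explicit messages argument takes precedence.
payload = kwargs | {"messages": messages}
assert payload["messages"][0]["content"] == "fresh"
```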
88 changes: 69 additions & 19 deletions src/observers/models/base.py
@@ -22,6 +22,7 @@ class ChatCompletionRecord(Record):
 
     model: str = None
     timestamp: str = field(default_factory=lambda: datetime.datetime.now().isoformat())
+    arguments: Optional[Dict[str, Any]] = None
 
     messages: List[Message] = None
     assistant_message: Optional[str] = None
@@ -55,7 +56,7 @@ def table_columns(self):
             "properties",
             "error",
             "raw_response",
-            "synced_at",
+            "arguments",
         ]
 
     @property
@@ -65,7 +66,7 @@ def duckdb_schema(self):
             id VARCHAR PRIMARY KEY,
             model VARCHAR,
             timestamp TIMESTAMP,
-            messages STRUCT(role VARCHAR, content VARCHAR)[],
+            messages JSON,
             assistant_message TEXT,
             completion_tokens INTEGER,
             prompt_tokens INTEGER,
@@ -77,7 +78,7 @@ def duckdb_schema(self):
             properties JSON,
             error VARCHAR,
             raw_response JSON,
-            synced_at TIMESTAMP
+            arguments JSON,
         )
         """
@@ -160,7 +161,14 @@ def table_name(self):
 
     @property
     def json_fields(self):
-        return ["tool_calls", "function_call", "tags", "properties", "raw_response"]
+        return [
+            "tool_calls",
+            "function_call",
+            "tags",
+            "properties",
+            "raw_response",
+            "arguments",
+        ]
 
     @property
     def image_fields(self):
@@ -223,13 +231,17 @@ def chat(self) -> Self:
     def completions(self) -> Self:
         return self
 
-    def _log_record(self, response, error=None, model=None):
+    def _log_record(
+        self, response, error=None, model=None, messages=None, arguments=None
+    ):
         record = self.parse_response(
             response,
             error=error,
             model=model,
+            messages=messages,
             tags=self.tags,
             properties=self.properties,
+            arguments=arguments,
         )
         if random.random() < self.logging_rate:
             self.store.add(record)
@@ -255,30 +267,47 @@ def create(
         """
         response = None
         kwargs = self.handle_kwargs(kwargs)
+        excluded_args = {"model", "messages", "tags", "properties"}
+        arguments = {k: v for k, v in kwargs.items() if k not in excluded_args}
         model = kwargs.get("model")
         input_data = self.format_input(messages, **kwargs)
 
         if kwargs.get("stream", False):
 
             def stream_responses():
-                response = []
+                response_buffer = []
                 try:
                     for chunk in self.create_fn(**input_data):
                         yield chunk
-                        response.append(chunk)
-                    self._log_record(response, model=model)
+                        response_buffer.append(chunk)
+                    self._log_record(
+                        response_buffer,
+                        model=model,
+                        messages=messages,
+                        arguments=arguments,
+                    )
                 except Exception as e:
-                    self._log_record(response, error=e, model=model)
+                    self._log_record(
+                        response_buffer,
+                        error=e,
+                        model=model,
+                        messages=messages,
+                        arguments=arguments,
+                    )
                     raise
 
             return stream_responses()
 
         try:
             response = self.create_fn(**input_data)
-            self._log_record(response, model=model)
+            self._log_record(
+                response, model=model, messages=messages, arguments=arguments
+            )
             return response
         except Exception as e:
-            self._log_record(response, error=e, model=model)
+            self._log_record(
+                response, error=e, model=model, messages=messages, arguments=arguments
+            )
             raise
 
     def handle_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
@@ -318,13 +347,17 @@ class AsyncChatCompletionObserver(ChatCompletionObserver):
         The logging rate to use for logging, defaults to 1
     """
 
-    async def _log_record_async(self, response, error=None, model=None):
+    async def _log_record_async(
+        self, response, error=None, model=None, messages=None, arguments=None
+    ):
         record = self.parse_response(
             response,
             error=error,
             model=model,
+            messages=messages,
             tags=self.tags,
             properties=self.properties,
+            arguments=arguments,
        )
         if random.random() < self.logging_rate:
             await self.store.add_async(record)
@@ -346,30 +379,47 @@ async def create(
         """
         response = None
         kwargs = self.handle_kwargs(kwargs)
-        input_data = self.format_input(messages, **kwargs)
+        excluded_args = {"model", "messages", "tags", "properties"}
+        arguments = {k: v for k, v in kwargs.items() if k not in excluded_args}
         model = kwargs.get("model")
+        input_data = self.format_input(messages, **kwargs)
 
         if kwargs.get("stream", False):
 
             async def stream_responses():
-                response = []
+                response_buffer = []
                 try:
                     async for chunk in await self.create_fn(**input_data):
                         yield chunk
-                        response.append(chunk)
-                    await self._log_record_async(response, model=model)
+                        response_buffer.append(chunk)
+                    await self._log_record_async(
+                        response_buffer,
+                        model=model,
+                        messages=messages,
+                        arguments=arguments,
+                    )
                 except Exception as e:
-                    await self._log_record_async(response, error=e, model=model)
+                    await self._log_record_async(
+                        response_buffer,
+                        error=e,
+                        model=model,
+                        messages=messages,
+                        arguments=arguments,
+                    )
                     raise
 
             return stream_responses()
 
         try:
             response = await self.create_fn(**input_data)
-            await self._log_record_async(response, model=model)
+            await self._log_record_async(
+                response, model=model, messages=messages, arguments=arguments
+            )
             return response
         except Exception as e:
-            await self._log_record_async(response, error=e, model=model)
+            await self._log_record_async(
+                response, error=e, model=model, messages=messages, arguments=arguments
+            )
             raise
 
     async def __aenter__(self) -> "AsyncChatCompletionObserver":
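
For callers, the streaming path above means chunks are yielded as they arrive and a single record is written once the stream is exhausted (or an error is raised). A sketch of what that looks like in practice (model and prompt are illustrative):

```python
from openai import OpenAI

from observers import wrap_openai

client = wrap_openai(OpenAI())

# With stream=True, the observer buffers each chunk in response_buffer and
# only logs after the last chunk (or when an exception is raised).
stream = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
```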
14 changes: 9 additions & 5 deletions src/observers/models/openai.py
@@ -1,6 +1,6 @@
 import uuid
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
+from observers.stores.duckdb import DuckDBStore
 from openai import AsyncOpenAI, OpenAI
 from typing_extensions import Self
 
@@ -14,7 +14,6 @@
     from openai.types.chat import ChatCompletion, ChatCompletionChunk
 
     from observers.stores.datasets import DatasetsStore
-    from observers.stores.duckdb import DuckDBStore
 
 
 class OpenAIRecord(ChatCompletionRecord):
@@ -25,11 +24,14 @@ def from_response(
         cls,
         response: Union[List["ChatCompletionChunk"], "ChatCompletion"] = None,
         error=None,
+        messages=None,
         **kwargs,
     ) -> Self:
         """Create a response record from an API response or error"""
         if not response:
-            return cls(finish_reason="error", error=str(error), **kwargs)
+            return cls(
+                finish_reason="error", error=str(error), messages=messages, **kwargs
+            )
 
         # Handle streaming responses
         if isinstance(response, list):
@@ -56,6 +58,7 @@ def from_response(
 
         return cls(
             id=first_dump.get("id") or str(uuid.uuid4()),
+            messages=messages,
             completion_tokens=completion_tokens,
             prompt_tokens=prompt_tokens,
             total_tokens=total_tokens,
@@ -73,6 +76,7 @@ def from_response(
         usage = response_dump.get("usage", {}) or {}
         return cls(
             id=response.id or str(uuid.uuid4()),
+            messages=messages,
             completion_tokens=usage.get("completion_tokens"),
             prompt_tokens=usage.get("prompt_tokens"),
             total_tokens=usage.get("total_tokens"),
@@ -87,7 +91,7 @@
 
 def wrap_openai(
     client: Union["OpenAI", "AsyncOpenAI"],
-    store: Optional[Union["DuckDBStore", "DatasetsStore"]] = None,
+    store: Optional[Union["DuckDBStore", "DatasetsStore"]] = DuckDBStore(),
     tags: Optional[List[str]] = None,
     properties: Optional[Dict[str, Any]] = None,
     logging_rate: Optional[float] = 1,
@@ -114,7 +118,7 @@ def wrap_openai(
     observer_args = {
         "client": client,
         "create": client.chat.completions.create,
-        "format_input": lambda inputs, **kwargs: {"messages": inputs, **kwargs},
+        "format_input": lambda messages, **kwargs: kwargs | {"messages": messages},
         "parse_response": OpenAIRecord.from_response,
         "store": store,
         "tags": tags,
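
Note that a default argument is evaluated once at import time, so every `wrap_openai()` call that omits `store` shares the same `DuckDBStore` instance. A sketch of passing the options explicitly instead (parameter names come from the signature above; values are illustrative):

```python
from openai import OpenAI

from observers import wrap_openai
from observers.stores.duckdb import DuckDBStore

client = wrap_openai(
    OpenAI(),
    store=DuckDBStore(),           # explicit store instead of the shared default
    tags=["demo"],                 # free-form labels stored with each record
    properties={"run": "demo-1"},  # arbitrary metadata stored with each record
    logging_rate=0.5,              # log roughly half of all calls (random sampling)
)
```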
1 change: 0 additions & 1 deletion src/observers/stores/datasets.py
@@ -189,7 +189,6 @@ def add(self, record: "Record"):
         with self._scheduler.lock:
             with (self._scheduler.folder_path / self._filename).open("a") as f:
                 record_dict = asdict(record)
-                record_dict["synced_at"] = None
 
                 # Handle JSON fields
                 for json_field in record.json_fields:
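
The truncated loop above serializes a record's JSON fields before the row is appended. A rough, self-contained sketch of that step (the exact serialization is an assumption; only `json_fields` and `asdict` appear in the diff):

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeRecord:
    """Stand-in for observers' Record, just for this sketch."""

    tags: Optional[List[str]] = None
    properties: Optional[Dict[str, Any]] = None
    json_fields: List[str] = field(default_factory=lambda: ["tags", "properties"])


record = FakeRecord(tags=["demo"], properties={"run": 1})
record_dict = asdict(record)

# Presumably each JSON field is dumped to a string before writing;
# this mirrors the loop in the hunk above.
for json_field in record.json_fields:
    if record_dict.get(json_field) is not None:
        record_dict[json_field] = json.dumps(record_dict[json_field])

print(record_dict)
```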