Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import logging
import asyncio
from uuid import UUID, uuid4
from typing import Any, Optional, Union, Dict, TYPE_CHECKING
from typing import Optional, Union, Dict, TYPE_CHECKING

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue

from chatsky.core.message import Message, MessageInitTypes
from chatsky.slots.slots import SlotManager
Expand Down Expand Up @@ -87,7 +87,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True):
Instance of the pipeline that manages this context.
Can be used to obtain run configuration such as script or fallback label.
"""
stats: Dict[str, Any] = Field(default_factory=dict)
stats: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict)
"Enables complex stats collection across multiple turns."
slot_manager: SlotManager = Field(default_factory=SlotManager)
"Stores extracted slots."
Expand Down Expand Up @@ -133,7 +133,7 @@ class Context(BaseModel):
First response is stored at key ``1``.
IDs go up by ``1`` after that.
"""
misc: Dict[str, Any] = Field(default_factory=dict)
misc: Dict[str, Union[BaseModel, JsonValue]] = Field(default_factory=dict)
"""
``misc`` stores any custom data. The framework doesn't use this dictionary,
so storage of any data won't reflect on the work of the internal Chatsky functions.
Expand Down
68 changes: 9 additions & 59 deletions chatsky/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,21 @@
"""

from __future__ import annotations
from typing import Literal, Optional, List, Union, Dict, Any, TYPE_CHECKING
from typing import Literal, Optional, List, Union, Dict, TYPE_CHECKING
from typing_extensions import TypeAlias, Annotated
from pathlib import Path
from urllib.request import urlopen
import uuid
import abc

from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic import BaseModel, Field, FilePath, HttpUrl, JsonValue, model_validator
from pydantic_core import Url

from chatsky.utils.devel import (
json_pickle_validator,
json_pickle_serializer,
pickle_serializer,
pickle_validator,
JSONSerializableExtras,
)

if TYPE_CHECKING:
from chatsky.messengers.common.interface import MessengerInterfaceWithAttachments


class DataModel(JSONSerializableExtras):
class DataModel(BaseModel, extra="allow"):
"""
This class is a Pydantic BaseModel that can have any type and number of extras.
"""
Expand Down Expand Up @@ -290,9 +282,9 @@ class level variables to store message information.
]
]
] = None
annotations: Optional[Dict[str, Any]] = None
misc: Optional[Dict[str, Any]] = None
original_message: Optional[Any] = None
annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None
misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None
Copy link
Copy Markdown
Member

@RLKRo RLKRo Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change type annotation (for Union[BaseModel, JsonValue]) to allow deeper BaseModel usage (e.g. a dictionary or a list with BaseModel values).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PydanticValue: TypeAlias = Union[
    List["PydanticValue"],
    Dict[str, "PydanticValue"],
    BaseModel,
    str,
    bool,
    int,
    float,
    None,
]

original_message: Optional[Union[BaseModel, JsonValue]] = None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge dev to have #398 here.


def __init__( # this allows initializing Message with string as positional argument
self,
Expand All @@ -318,9 +310,9 @@ def __init__( # this allows initializing Message with string as positional argu
]
]
] = None,
annotations: Optional[Dict[str, Any]] = None,
misc: Optional[Dict[str, Any]] = None,
original_message: Optional[Any] = None,
annotations: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None,
misc: Optional[Dict[str, Union[BaseModel, JsonValue]]] = None,
original_message: Optional[Union[BaseModel, JsonValue]] = None,
**kwargs,
):
super().__init__(
Expand All @@ -332,48 +324,6 @@ def __init__( # this allows initializing Message with string as positional argu
**kwargs,
)

@field_serializer("annotations", "misc", when_used="json")
def pickle_serialize_dicts(self, value):
"""
Serialize values that are not json-serializable via pickle.
Allows storing arbitrary data in misc/annotations when using context storages.
"""
if isinstance(value, dict):
return json_pickle_serializer(value)
return value

@field_validator("annotations", "misc", mode="before")
@classmethod
def pickle_validate_dicts(cls, value):
"""Restore values serialized with :py:meth:`pickle_serialize_dicts`."""
if isinstance(value, dict):
return json_pickle_validator(value)
return value

@field_serializer("original_message", when_used="json")
def pickle_serialize_original_message(self, value):
"""
Cast :py:attr:`original_message` to string via pickle.
Allows storing arbitrary data in this field when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("original_message", mode="before")
@classmethod
def pickle_validate_original_message(cls, value):
"""
Restore :py:attr:`original_message` after being processed with
:py:meth:`pickle_serialize_original_message`.
"""
if value is not None:
return pickle_validator(value)
return value

def __str__(self) -> str:
return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()])

@model_validator(mode="before")
@classmethod
def validate_from_str(cls, data):
Expand Down
2 changes: 1 addition & 1 deletion chatsky/messengers/telegram/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ async def _on_event(self, update: Update, _: Any, create_message: Callable[[Upda
data_available = update.message is not None or update.callback_query is not None
if update.effective_chat is not None and data_available:
message = create_message(update)
message.original_message = update
message.original_message = update.to_dict(recursive=True)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to apply to extra fields in Attachments.
Add a validator for Attachment and Message extras that modifies the extra field via to_dict if the field is of the TelegramObject value.

AFAIK if the extra field value is a dictionary from to_dict it should still work for the tg bot methods.

resp = await self._pipeline_runner(message, update.effective_chat.id)
if resp.last_response is not None:
await self.cast_message_to_telegram_and_send(
Expand Down
38 changes: 8 additions & 30 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
import asyncio
import re
from abc import ABC, abstractmethod
from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing import Callable, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing_extensions import TypeAlias, Annotated
import logging
from functools import reduce
from string import Formatter

from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator
from pydantic import BaseModel, JsonValue, model_validator, Field

from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async
from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator

if TYPE_CHECKING:
from chatsky.core import Context, Message
Expand Down Expand Up @@ -117,29 +116,8 @@ class ExtractedValueSlot(ExtractedSlot):
"""Value extracted from :py:class:`~.ValueSlot`."""

is_slot_extracted: bool
extracted_value: Any
default_value: Any = None

@field_serializer("extracted_value", "default_value", when_used="json")
def pickle_serialize_values(self, value):
"""
Cast values to string via pickle.
Allows storing arbitrary data in these fields when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("extracted_value", "default_value", mode="before")
@classmethod
def pickle_validate_values(cls, value):
"""
Restore values after being processed with
:py:meth:`pickle_serialize_values`.
"""
if value is not None:
return pickle_validator(value)
return value
extracted_value: Union[BaseModel, JsonValue]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slots store exceptions on failure (which are not serializable).
I think we should store exception representation instead.

default_value: Optional[Union[BaseModel, JsonValue]] = None

@property
def __slot_extracted__(self) -> bool:
Expand Down Expand Up @@ -219,10 +197,10 @@ class ValueSlot(BaseSlot, frozen=True):
Subclass it, if you want to declare your own slot type.
"""

default_value: Any = None
default_value: Union[BaseModel, JsonValue] = None

@abstractmethod
async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]:
"""
Return value extracted from context.

Expand Down Expand Up @@ -328,9 +306,9 @@ class FunctionSlot(ValueSlot, frozen=True):
Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message.
"""

func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]]
func: Callable[[Message], Union[Awaitable[Union[Union[BaseModel, JsonValue], SlotNotExtracted]], Union[BaseModel, JsonValue], SlotNotExtracted]]

async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]:
async def extract_value(self, ctx: Context) -> Union[Union[BaseModel, JsonValue], SlotNotExtracted]:
return await wrap_sync_function_in_async(self.func, ctx.last_request)


Expand Down
154 changes: 0 additions & 154 deletions chatsky/utils/devel/json_serialization.py

This file was deleted.

Loading