Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from importlib import import_module
from logging import getLogger
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union, Set
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union, Set

from chatsky.core.ctx_utils import ContextMainInfo
from chatsky.utils.decorations import classproperty
from chatsky.utils.logging import collapse_num_list
from .protocol import PROTOCOLS

if TYPE_CHECKING:
from chatsky.core.context import Context

_SUBSCRIPT_TYPE = Union[Literal["__all__"], int, Set[int]]
_SUBSCRIPT_DICT = Dict[Literal["labels", "requests", "responses"], _SUBSCRIPT_TYPE]

Expand Down Expand Up @@ -169,7 +172,7 @@ async def connect(self) -> None:
self.connected = True

@abstractmethod
async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
raise NotImplementedError

@_lock
Expand All @@ -184,13 +187,13 @@ async def load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
logger.debug(f"Loading main info for {ctx_id}...")
result = await self._load_main_info(ctx_id)
logger.debug(f"Main info loaded for {ctx_id}")
return result
return ContextMainInfo.model_validate(result) if result is not None else None

@abstractmethod
async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
raise NotImplementedError
Expand All @@ -199,7 +202,7 @@ async def _update_context(
async def update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo] = None,
ctx_info: Optional[Union[ContextMainInfo, Context]] = None,
field_info: Optional[List[Tuple[str, List[Tuple[int, bytes]], List[int]]]] = None,
) -> None:
"""
Expand All @@ -225,7 +228,8 @@ async def update_context(
else:
field_info += [(k, None) for k in deleted]
logger.debug(f"\tDeleting fields for {field}: {collapse_num_list(deleted)}...")
await self._update_context(ctx_id, ctx_info, list(joined_field_info.items()))
ctx_info_dump = ContextMainInfo.model_dump(ctx_info, mode="python") if ctx_info is not None else None
await self._update_context(ctx_id, ctx_info_dump, list(joined_field_info.items()))
logger.debug(f"Context updated for {ctx_id}")

@abstractmethod
Expand Down
9 changes: 4 additions & 5 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from abc import ABC, abstractmethod
from pickle import loads, dumps
from shelve import DbfilenameShelf
from typing import List, Set, Tuple, Dict, Optional
from typing import Any, List, Set, Tuple, Dict, Optional

from pydantic import BaseModel, Field

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT

try:
Expand All @@ -33,7 +32,7 @@ class SerializableStorage(BaseModel):
One element of this class will be used to store all the contexts, read and written to file on every turn.
"""

main: Dict[str, ContextMainInfo] = Field(default_factory=dict)
main: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
turns: List[Tuple[str, str, int, Optional[bytes]]] = Field(default_factory=list)


Expand Down Expand Up @@ -67,13 +66,13 @@ async def _load(self) -> SerializableStorage:
async def _connect(self):
await self._load()

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
return (await self._load()).main.get(ctx_id, None)

async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
storage = await self._load()
Expand Down
7 changes: 3 additions & 4 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
The Memory module provides an in-RAM version of the :py:class:`.DBContextStorage` class.
"""

from typing import List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig


Expand Down Expand Up @@ -44,13 +43,13 @@ def __init__(
async def _connect(self):
pass

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
return self._main_storage.get(ctx_id, None)

async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
if ctx_info is not None:
Expand Down
18 changes: 6 additions & 12 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

mongo_available = False

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion

Expand Down Expand Up @@ -76,21 +75,17 @@ async def _connect(self):
),
)

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
result = await self.main_table.find_one(
{NameConfig._id_column: ctx_id},
NameConfig.get_context_main_fields,
)
return (
ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields})
if result is not None
else None
)
return {f: result[f] for f in NameConfig.get_context_main_fields} if result is not None else None

async def _inner_update_context(
self,
ctx_id: str,
ctx_info_dump: Optional[Dict],
ctx_info_dump: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
session: Optional[AsyncIOMotorClientSession],
) -> None:
Expand Down Expand Up @@ -123,16 +118,15 @@ async def _inner_update_context(
async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
ctx_info_dump = ctx_info.model_dump(mode="python") if ctx_info is not None else None
if self._transactions_enabled:
async with await self._mongo.start_session() as session:
async with session.start_transaction():
await self._inner_update_context(ctx_id, ctx_info_dump, field_info, session)
await self._inner_update_context(ctx_id, ctx_info, field_info, session)
else:
await self._inner_update_context(ctx_id, ctx_info_dump, field_info, None)
await self._inner_update_context(ctx_id, ctx_info, field_info, None)

async def _delete_context(self, ctx_id: str) -> None:
await gather(
Expand Down
14 changes: 5 additions & 9 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from asyncio import gather
from typing import List, Set, Tuple, Optional
from typing import Any, Dict, List, Set, Tuple, Optional

try:
from redis.asyncio import Redis
Expand All @@ -23,7 +23,6 @@
except ImportError:
redis_available = False

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion

Expand Down Expand Up @@ -79,28 +78,25 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]:
def _bytes_to_keys(keys: List[bytes]) -> List[int]:
return [int(f.decode("utf-8")) for f in keys]

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
retrieved_fields = await gather(
*[self.database.hget(f"{self._main_key}:{ctx_id}", f) for f in NameConfig.get_context_main_fields]
)
return ContextMainInfo.model_validate(
{f: v for f, v in zip(NameConfig.get_context_main_fields, retrieved_fields)}
)
return {f: v for f, v in zip(NameConfig.get_context_main_fields, retrieved_fields)}
else:
return None

async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
update_main, update_values, delete_keys = list(), list(), list()
if ctx_info is not None:
ctx_info_dump = ctx_info.model_dump(mode="python")
update_main = [
(f, ctx_info_dump[f] if isinstance(ctx_info_dump[f], bytes) else str(ctx_info_dump[f]))
(f, ctx_info[f] if isinstance(ctx_info[f], bytes) else str(ctx_info[f]))
for f in NameConfig.get_context_main_fields
]
for field_name, items in field_info:
Expand Down
16 changes: 5 additions & 11 deletions chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations
from asyncio import gather
from importlib import import_module
from typing import Callable, Collection, List, Optional, Set, Tuple
from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple
from logging import getLogger

try:
Expand Down Expand Up @@ -69,7 +69,6 @@
except (ImportError, ModuleNotFoundError):
sqlite_available = False

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion

Expand Down Expand Up @@ -204,32 +203,27 @@ def _check_availability(self):
install_suggestion = get_protocol_install_suggestion("sqlite")
raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion)

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
stmt = select(self.main_table).where(self.main_table.c[NameConfig._id_column] == ctx_id)
async with self.engine.begin() as conn:
result = (await conn.execute(stmt)).fetchone()
return (
None
if result is None
else ContextMainInfo.model_validate(
{f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields)}
)
None if result is None else {f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields)}
)

async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
main_update_stmt, turns_update_stmts = None, list()
if ctx_info is not None:
ctx_info_dump = ctx_info.model_dump(mode="python")
main_insert_stmt = self._INSERT_CALLABLE(self.main_table).values(
{
NameConfig._id_column: ctx_id,
}
| {f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields}
| {f: ctx_info[f] for f in NameConfig.get_context_main_fields}
)
main_update_stmt = _get_upsert_stmt(
self.dialect,
Expand Down
16 changes: 6 additions & 10 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from asyncio import gather
from os.path import join
from typing import Awaitable, Callable, Set, Tuple, List, Optional
from typing import Any, Awaitable, Callable, Dict, Set, Tuple, List, Optional
from urllib.parse import urlsplit

try:
Expand All @@ -31,7 +31,6 @@
except ImportError:
ydb_available = False

from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion

Expand Down Expand Up @@ -136,8 +135,8 @@ async def callee(session: Session) -> None:

await self.pool.retry_operation(callee)

async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
async def callee(session: Session) -> Optional[ContextMainInfo]:
async def _load_main_info(self, ctx_id: str) -> Optional[Dict[str, Any]]:
async def callee(session: Session) -> Optional[Dict[str, Any]]:
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${NameConfig._id_column} AS Utf8;
Expand All @@ -153,9 +152,7 @@ async def callee(session: Session) -> Optional[ContextMainInfo]:
commit_tx=True,
)
return (
ContextMainInfo.model_validate(
{f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields}
)
{f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields}
if len(result_sets[0].rows) > 0
else None
)
Expand All @@ -165,13 +162,12 @@ async def callee(session: Session) -> Optional[ContextMainInfo]:
async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
ctx_info: Optional[Dict[str, Any]],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
async def callee(session: Session) -> None:
transaction = await session.transaction(SerializableReadWrite()).begin()
if ctx_info is not None:
ctx_info_dump = ctx_info.model_dump(mode="python")
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${NameConfig._id_column} AS Utf8;
Expand All @@ -188,7 +184,7 @@ async def callee(session: Session) -> None:
{
f"${NameConfig._id_column}": ctx_id,
}
| {f"${f}": ctx_info_dump[f] for f in NameConfig.get_context_main_fields},
| {f"${f}": ctx_info[f] for f in NameConfig.get_context_main_fields},
)
for field_name, items in field_info:
declare, prepare, values = list(), dict(), list()
Expand Down
Loading
Loading