Skip to content

Commit 2b14f5a

Browse files
feat: Add Async Chat Store (#38)
* feat: Add Async Chat Store * fix tests --------- Co-authored-by: Averi Kitsch <[email protected]>
1 parent 0ef1fa5 commit 2b14f5a

File tree

2 files changed

+513
-0
lines changed

2 files changed

+513
-0
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
from typing import List, Optional
19+
20+
from llama_index.core.llms import ChatMessage
21+
from llama_index.core.storage.chat_store.base import BaseChatStore
22+
from sqlalchemy import text
23+
from sqlalchemy.ext.asyncio import AsyncEngine
24+
25+
from .engine import PostgresEngine
26+
27+
28+
class AsyncPostgresChatStore(BaseChatStore):
29+
"""Chat Store Table stored in an CloudSQL for PostgreSQL database."""
30+
31+
__create_key = object()
32+
33+
def __init__(
34+
self,
35+
key: object,
36+
engine: AsyncEngine,
37+
table_name: str,
38+
schema_name: str = "public",
39+
):
40+
"""AsyncPostgresChatStore constructor.
41+
42+
Args:
43+
key (object): Key to prevent direct constructor usage.
44+
engine (PostgresEngine): Database connection pool.
45+
table_name (str): Table name that stores the chat store.
46+
schema_name (str): The schema name where the table is located.
47+
Defaults to "public"
48+
49+
Raises:
50+
Exception: If constructor is directly called by the user.
51+
"""
52+
if key != AsyncPostgresChatStore.__create_key:
53+
raise Exception("Only create class through 'create' method!")
54+
55+
# Delegate to Pydantic's __init__
56+
super().__init__()
57+
self._engine = engine
58+
self._table_name = table_name
59+
self._schema_name = schema_name
60+
61+
@classmethod
62+
async def create(
63+
cls,
64+
engine: PostgresEngine,
65+
table_name: str,
66+
schema_name: str = "public",
67+
) -> AsyncPostgresChatStore:
68+
"""Create a new AsyncPostgresChatStore instance.
69+
70+
Args:
71+
engine (PostgresEngine): Postgres engine to use.
72+
table_name (str): Table name that stores the chat store.
73+
schema_name (str): The schema name where the table is located.
74+
Defaults to "public"
75+
76+
Raises:
77+
ValueError: If the table provided does not contain required schema.
78+
79+
Returns:
80+
AsyncPostgresChatStore: A newly created instance of AsyncPostgresChatStore.
81+
"""
82+
table_schema = await engine._aload_table_schema(table_name, schema_name)
83+
column_names = table_schema.columns.keys()
84+
85+
required_columns = ["id", "key", "message"]
86+
87+
if not (all(x in column_names for x in required_columns)):
88+
raise ValueError(
89+
f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n"
90+
f"Expected column names: {required_columns}\n"
91+
f"Provided column names: {column_names}\n"
92+
"Please create the table with the following schema:\n"
93+
f"CREATE TABLE {schema_name}.{table_name} (\n"
94+
" id SERIAL PRIMARY KEY,\n"
95+
" key VARCHAR NOT NULL,\n"
96+
" message JSON NOT NULL\n"
97+
");"
98+
)
99+
100+
return cls(cls.__create_key, engine._pool, table_name, schema_name)
101+
102+
async def __aexecute_query(self, query, params=None):
103+
async with self._engine.connect() as conn:
104+
await conn.execute(text(query), params)
105+
await conn.commit()
106+
107+
async def __afetch_query(self, query):
108+
async with self._engine.connect() as conn:
109+
result = await conn.execute(text(query))
110+
result_map = result.mappings()
111+
results = result_map.fetchall()
112+
await conn.commit()
113+
return results
114+
115+
@classmethod
116+
def class_name(cls) -> str:
117+
"""Get class name."""
118+
return "AsyncPostgresChatStore"
119+
120+
async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:
121+
"""Asynchronously sets the chat messages for a specific key.
122+
123+
Args:
124+
key (str): A unique identifier for the chat.
125+
messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert.
126+
127+
Returns:
128+
None
129+
130+
"""
131+
query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """
132+
await self.__aexecute_query(query)
133+
insert_query = f"""
134+
INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message)
135+
VALUES (:key, :message);"""
136+
137+
params = [
138+
{
139+
"key": key,
140+
"message": json.dumps(message.dict()),
141+
}
142+
for message in messages
143+
]
144+
145+
await self.__aexecute_query(insert_query, params)
146+
147+
async def aget_messages(self, key: str) -> List[ChatMessage]:
148+
"""Asynchronously retrieves the chat messages associated with a specific key.
149+
150+
Args:
151+
key (str): A unique identifier for which the messages are to be retrieved.
152+
153+
Returns:
154+
List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key.
155+
If no messages are found, an empty list is returned.
156+
"""
157+
query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;"""
158+
results = await self.__afetch_query(query)
159+
if results:
160+
return [
161+
ChatMessage.model_validate(result.get("message")) for result in results
162+
]
163+
return []
164+
165+
async def async_add_message(self, key: str, message: ChatMessage) -> None:
166+
"""Asynchronously adds a new chat message to the specified key.
167+
168+
Args:
169+
key (str): A unique identifierfor the chat to which the message is added.
170+
message (ChatMessage): The `ChatMessage` object that is to be added.
171+
172+
Returns:
173+
None
174+
"""
175+
insert_query = f"""
176+
INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message)
177+
VALUES (:key, :message);"""
178+
params = {"key": key, "message": json.dumps(message.dict())}
179+
180+
await self.__aexecute_query(insert_query, params)
181+
182+
async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]:
183+
"""Asynchronously deletes the chat messages associated with a specific key.
184+
185+
Args:
186+
key (str): A unique identifier for the chat whose messages are to be deleted.
187+
188+
Returns:
189+
Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages
190+
were associated with the key or could be deleted.
191+
"""
192+
query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """
193+
results = await self.__afetch_query(query)
194+
if results:
195+
return [
196+
ChatMessage.model_validate(result.get("message")) for result in results
197+
]
198+
return None
199+
200+
async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
201+
"""Asynchronously deletes a specific chat message by index from the messages associated with a given key.
202+
203+
Args:
204+
key (str): A unique identifier for the chat whose messages are to be deleted.
205+
idx (int): The index of the `ChatMessage` to be deleted from the list of messages.
206+
207+
Returns:
208+
Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
209+
was associated with the key or could be deleted.
210+
"""
211+
query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;"""
212+
results = await self.__afetch_query(query)
213+
if results:
214+
if idx >= len(results):
215+
return None
216+
id_to_be_deleted = results[idx].get("id")
217+
delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;"""
218+
result = await self.__afetch_query(delete_query)
219+
result = result[0]
220+
if result:
221+
return ChatMessage.model_validate(result.get("message"))
222+
return None
223+
return None
224+
225+
async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:
226+
"""Asynchronously deletes the last chat message associated with a given key.
227+
228+
Args:
229+
key (str): A unique identifier for the chat whose message is to be deleted.
230+
231+
Returns:
232+
Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
233+
was associated with the key or could be deleted.
234+
"""
235+
query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;"""
236+
results = await self.__afetch_query(query)
237+
if results:
238+
id_to_be_deleted = results[0].get("id")
239+
delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;"""
240+
result = await self.__afetch_query(delete_query)
241+
result = result[0]
242+
if result:
243+
return ChatMessage.model_validate(result.get("message"))
244+
return None
245+
return None
246+
247+
async def aget_keys(self) -> List[str]:
248+
"""Asynchronously retrieves a list of all keys.
249+
250+
Returns:
251+
Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned.
252+
"""
253+
query = (
254+
f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";"""
255+
)
256+
results = await self.__afetch_query(query)
257+
keys = []
258+
if results:
259+
keys = [row.get("key") for row in results]
260+
return keys
261+
262+
def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
263+
raise NotImplementedError(
264+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
265+
)
266+
267+
def get_messages(self, key: str) -> List[ChatMessage]:
268+
raise NotImplementedError(
269+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
270+
)
271+
272+
def add_message(self, key: str, message: ChatMessage) -> None:
273+
raise NotImplementedError(
274+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
275+
)
276+
277+
def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
278+
raise NotImplementedError(
279+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
280+
)
281+
282+
def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
283+
raise NotImplementedError(
284+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
285+
)
286+
287+
def delete_last_message(self, key: str) -> Optional[ChatMessage]:
288+
raise NotImplementedError(
289+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
290+
)
291+
292+
def get_keys(self) -> List[str]:
293+
raise NotImplementedError(
294+
"Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead."
295+
)

0 commit comments

Comments
 (0)