From a6b31b0a4d291578a74da61d3b72b997ee1e0645 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Tue, 7 Jan 2025 08:00:59 +0000 Subject: [PATCH 1/2] feat: Add Async Chat Store --- .../async_chat_store.py | 295 ++++++++++++++++++ tests/test_async_chat_store.py | 224 +++++++++++++ 2 files changed, 519 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/async_chat_store.py create mode 100644 tests/test_async_chat_store.py diff --git a/src/llama_index_cloud_sql_pg/async_chat_store.py b/src/llama_index_cloud_sql_pg/async_chat_store.py new file mode 100644 index 0000000..8d80543 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_chat_store.py @@ -0,0 +1,295 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + + +class AsyncPostgresChatStore(BaseChatStore): + """Chat Store Table stored in an CloudSQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + ): + """AsyncPostgresChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. 
+ engine (PostgresEngine): Database connection pool. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPostgresChatStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> AsyncPostgresChatStore: + """Create a new AsyncPostgresChatStore instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AsyncPostgresChatStore: A newly created instance of AsyncPostgresChatStore. 
+ """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncPostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. 
+ + Returns: + None + + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """ + await self.__aexecute_query(query) + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + + params = [ + { + "key": key, + "message": json.dumps(message.dict()), + } + for message in messages + ] + + await self.__aexecute_query(insert_query, params) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return [] + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + params = {"key": key, "message": json.dumps(message.dict())} + + await self.__aexecute_query(insert_query, params) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. 
+ + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """ + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return None + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + if idx >= len(results): + return None + id_to_be_deleted = results[idx].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. 
+ """ + query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;""" + results = await self.__afetch_query(query) + if results: + id_to_be_deleted = results[0].get("id") + delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;""" + result = await self.__afetch_query(delete_query) + result = result[0] + if result: + return ChatMessage.model_validate(result.get("message")) + return None + return None + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + query = ( + f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";""" + ) + results = await self.__afetch_query(query) + keys = [] + if results: + keys = [row.get("key") for row in results] + return keys + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def add_message(self, key: str, message: ChatMessage) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . 
Use PostgresChatStore interface instead." + ) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) + + def get_keys(self) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresChatStore . Use PostgresChatStore interface instead." + ) diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py new file mode 100644 index 0000000..b2aad02 --- /dev/null +++ b/tests/test_async_chat_store.py @@ -0,0 +1,224 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# File: tests/test_async_chat_store.py
#
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
from typing import Sequence

import pytest
import pytest_asyncio
from llama_index.core.llms import ChatMessage
from sqlalchemy import RowMapping, text

from llama_index_cloud_sql_pg import PostgresEngine
from llama_index_cloud_sql_pg.async_chat_store import AsyncPostgresChatStore

# A unique table per test run keeps concurrent runs from colliding.
default_table_name_async = "chat_store_" + str(uuid.uuid4())


async def aexecute(engine: PostgresEngine, query: str) -> None:
    """Run a statement on the engine's pool and commit it."""
    async with engine._pool.connect() as conn:
        await conn.execute(text(query))
        await conn.commit()


async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
    """Run a query on the engine's pool and return all rows as mappings."""
    async with engine._pool.connect() as conn:
        result = await conn.execute(text(query))
        result_map = result.mappings()
        result_fetch = result_map.fetchall()
    return result_fetch


def get_env_var(key: str, desc: str) -> str:
    """Return the environment variable ``key`` or raise a descriptive error."""
    v = os.environ.get(key)
    if v is None:
        raise ValueError(f"Must set env var {key} to: {desc}")
    return v


@pytest.mark.asyncio(loop_scope="class")
class TestAsyncPostgresChatStores:
    """Integration tests for AsyncPostgresChatStore against a live database."""

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        return get_env_var("REGION", "region for Cloud SQL instance")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")

    @pytest.fixture(scope="module")
    def user(self) -> str:
        return get_env_var("DB_USER", "database user for Cloud SQL")

    @pytest.fixture(scope="module")
    def password(self) -> str:
        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")

    @pytest_asyncio.fixture(scope="class")
    async def async_engine(
        self,
        db_project,
        db_region,
        db_instance,
        db_name,
    ):
        async_engine = await PostgresEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )

        yield async_engine

        await async_engine.close()

    @pytest_asyncio.fixture(scope="class")
    async def chat_store(self, async_engine):
        await async_engine._ainit_chat_store_table(table_name=default_table_name_async)

        chat_store = await AsyncPostgresChatStore.create(
            engine=async_engine, table_name=default_table_name_async
        )

        yield chat_store

        query = f'DROP TABLE IF EXISTS "{default_table_name_async}"'
        await aexecute(async_engine, query)

    async def test_init_with_constructor(self, async_engine):
        # Pass every argument, with a foreign key object, so the failure comes
        # from the constructor guard rather than from a missing-argument
        # TypeError.
        with pytest.raises(Exception, match="Only create class through"):
            AsyncPostgresChatStore(
                key=object(),
                engine=async_engine._pool,
                table_name=default_table_name_async,
            )

    async def test_async_add_message(self, async_engine, chat_store):
        key = "test_add_key"

        message = ChatMessage(content="add_message_test", role="user")
        await chat_store.async_add_message(key, message=message)

        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
        results = await afetch(async_engine, query)
        result = results[0]
        assert result["message"] == message.dict()

    async def test_aset_and_aget_messages(self, chat_store):
        message_1 = ChatMessage(content="First message", role="user")
        message_2 = ChatMessage(content="Second message", role="user")
        messages = [message_1, message_2]
        key = "test_set_and_get_key"
        await chat_store.aset_messages(key, messages)

        results = await chat_store.aget_messages(key)

        assert len(results) == 2
        assert results[0].content == message_1.content
        assert results[1].content == message_2.content

    async def test_adelete_messages(self, async_engine, chat_store):
        messages = [ChatMessage(content="Message to delete", role="user")]
        key = "test_delete_key"
        await chat_store.aset_messages(key, messages)

        deleted = await chat_store.adelete_messages(key)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert len(results) == 0
        assert deleted is not None
        assert [m.content for m in deleted] == ["Message to delete"]

    async def test_adelete_message(self, async_engine, chat_store):
        message_1 = ChatMessage(content="Keep me", role="user")
        message_2 = ChatMessage(content="Delete me", role="user")
        messages = [message_1, message_2]
        key = "test_delete_message_key"
        await chat_store.aset_messages(key, messages)

        deleted = await chat_store.adelete_message(key, 1)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert deleted is not None
        assert deleted.content == message_2.content
        assert len(results) == 1
        assert results[0]["message"] == message_1.dict()

    async def test_adelete_last_message(self, async_engine, chat_store):
        message_1 = ChatMessage(content="Message 1", role="user")
        message_2 = ChatMessage(content="Message 2", role="user")
        message_3 = ChatMessage(content="Message 3", role="user")
        messages = [message_1, message_2, message_3]
        key = "test_delete_last_message_key"
        await chat_store.aset_messages(key, messages)

        deleted = await chat_store.adelete_last_message(key)
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert deleted is not None
        assert deleted.content == message_3.content
        assert len(results) == 2
        assert results[0]["message"] == message_1.dict()
        assert results[1]["message"] == message_2.dict()

    async def test_aget_keys(self, async_engine, chat_store):
        message_1 = [ChatMessage(content="First message", role="user")]
        message_2 = [ChatMessage(content="Second message", role="user")]
        key_1 = "key1"
        key_2 = "key2"
        await chat_store.aset_messages(key_1, message_1)
        await chat_store.aset_messages(key_2, message_2)

        keys = await chat_store.aget_keys()

        assert key_1 in keys
        assert key_2 in keys

    async def test_set_exisiting_key(self, async_engine, chat_store):
        message_1 = [ChatMessage(content="First message", role="user")]
        key = "test_set_exisiting_key"
        await chat_store.aset_messages(key, message_1)

        # ORDER BY id makes the positional assertions below deterministic.
        query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
        results = await afetch(async_engine, query)

        assert len(results) == 1
        result = results[0]
        assert result["message"] == message_1[0].dict()

        message_2 = ChatMessage(content="Second message", role="user")
        message_3 = ChatMessage(content="Third message", role="user")
        messages = [message_2, message_3]

        await chat_store.aset_messages(key, messages)

        results = await afetch(async_engine, query)

        # Assert the previous messages are deleted and only the newest ones exist.
        assert len(results) == 2

        assert results[0]["message"] == message_2.dict()
        assert results[1]["message"] == message_3.dict()