From eb4024e485382c4f0469641fb72f6888c8771246 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Tue, 7 Jan 2025 08:17:32 +0000 Subject: [PATCH] feat: Add chat store init methods --- src/llama_index_cloud_sql_pg/engine.py | 85 ++++++++++++++++++++++++++ tests/test_engine.py | 36 +++++++++++ 2 files changed, 121 insertions(+) diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py index cc067db..2faa943 100644 --- a/src/llama_index_cloud_sql_pg/engine.py +++ b/src/llama_index_cloud_sql_pg/engine.py @@ -756,6 +756,91 @@ def init_index_store_table( ) ) + async def _ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + if overwrite_existing: + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + key VARCHAR NOT NULL, + message JSON NOT NULL + );""" + create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.execute(text(create_index_query)) + await conn.commit() + + async def ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + + def init_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + self._run_as_sync( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + async def _aload_table_schema( self, table_name: str, schema_name: str = "public" ) -> Table: diff --git a/tests/test_engine.py b/tests/test_engine.py index fe89197..46af5d0 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -113,6 +115,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -296,6 +299,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -343,6 +362,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() async def test_password( @@ -461,3 +481,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected