Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/llama_index_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,91 @@ def init_index_store_table(
)
)

async def _ainit_chat_store_table(
self,
table_name: str,
schema_name: str = "public",
overwrite_existing: bool = False,
) -> None:
"""
Create an table to save chat store.
Args:
table_name (str): The table name to store chat history.
schema_name (str): The schema name to store the chat store table.
Default: "public".
overwrite_existing (bool): Whether to drop existing table.
Default: False.
Returns:
None
"""
if overwrite_existing:
async with self._pool.connect() as conn:
await conn.execute(
text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
)
await conn.commit()

create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
id SERIAL PRIMARY KEY,
key VARCHAR NOT NULL,
message JSON NOT NULL
);"""
create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);"""
async with self._pool.connect() as conn:
await conn.execute(text(create_table_query))
await conn.execute(text(create_index_query))
await conn.commit()

async def ainit_chat_store_table(
self,
table_name: str,
schema_name: str = "public",
overwrite_existing: bool = False,
) -> None:
"""
Create an table to save chat store.
Args:
table_name (str): The table name to store chat store.
schema_name (str): The schema name to store the chat store table.
Default: "public".
overwrite_existing (bool): Whether to drop existing table.
Default: False.
Returns:
None
"""
await self._run_as_async(
self._ainit_chat_store_table(
table_name,
schema_name,
overwrite_existing,
)
)

def init_chat_store_table(
self,
table_name: str,
schema_name: str = "public",
overwrite_existing: bool = False,
) -> None:
"""
Create an table to save chat store.
Args:
table_name (str): The table name to store chat store.
schema_name (str): The schema name to store the chat store table.
Default: "public".
overwrite_existing (bool): Whether to drop existing table.
Default: False.
Returns:
None
"""
self._run_as_sync(
self._ainit_chat_store_table(
table_name,
schema_name,
overwrite_existing,
)
)

async def _aload_table_schema(
self, table_name: str, schema_name: str = "public"
) -> Table:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4())
DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4())
DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4())
DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4())
DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4())
VECTOR_SIZE = 768


Expand Down Expand Up @@ -113,6 +115,7 @@ async def engine(self, db_project, db_region, db_instance, db_name):
await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"')
await engine.close()

async def test_password(
Expand Down Expand Up @@ -296,6 +299,22 @@ async def test_init_index_store(self, engine):
for row in results:
assert row in expected

async def test_init_chat_store(self, engine):
await engine.ainit_chat_store_table(
table_name=DEFAULT_CS_TABLE,
schema_name="public",
overwrite_existing=True,
)
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';"
results = await afetch(engine, stmt)
expected = [
{"column_name": "id", "data_type": "integer"},
{"column_name": "key", "data_type": "character varying"},
{"column_name": "message", "data_type": "json"},
]
for row in results:
assert row in expected


@pytest.mark.asyncio
class TestEngineSync:
Expand Down Expand Up @@ -343,6 +362,7 @@ async def engine(self, db_project, db_region, db_instance, db_name):
await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"')
await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"')
await engine.close()

async def test_password(
Expand Down Expand Up @@ -461,3 +481,19 @@ async def test_init_index_store(self, engine):
]
for row in results:
assert row in expected

async def test_init_chat_store(self, engine):
engine.init_chat_store_table(
table_name=DEFAULT_CS_TABLE_SYNC,
schema_name="public",
overwrite_existing=True,
)
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';"
results = await afetch(engine, stmt)
expected = [
{"column_name": "id", "data_type": "integer"},
{"column_name": "key", "data_type": "character varying"},
{"column_name": "message", "data_type": "json"},
]
for row in results:
assert row in expected
Loading