Skip to content

Commit 6bf0948

Browse files
feat: Adding Async Index Store (#4)
* feat: Adding Async Index Store and its tests. * Added warning to docstring. * Remove warning from docstring * Linter fix --------- Co-authored-by: Averi Kitsch <[email protected]>
1 parent b6ecae7 commit 6bf0948

File tree

2 files changed

+377
-0
lines changed

2 files changed

+377
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
import warnings
19+
from typing import List, Optional
20+
21+
from llama_index.core.constants import DATA_KEY
22+
from llama_index.core.data_structs.data_structs import IndexStruct
23+
from llama_index.core.storage.index_store.types import BaseIndexStore
24+
from llama_index.core.storage.index_store.utils import (
25+
index_struct_to_json,
26+
json_to_index_struct,
27+
)
28+
from sqlalchemy import text
29+
from sqlalchemy.ext.asyncio import AsyncEngine
30+
31+
from .engine import PostgresEngine
32+
33+
34+
class AsyncPostgresIndexStore(BaseIndexStore):
    """Index Store Table stored in a Cloud SQL for PostgreSQL database.

    Only the async methods (``a*``) are implemented; every sync method raises
    ``NotImplementedError``. Instances must be built with :meth:`create`,
    which validates the table schema first.
    """

    # Private sentinel handed to __init__ by `create`; direct construction
    # without it is rejected.
    __create_key = object()

    def __init__(
        self,
        key: object,
        engine: AsyncEngine,
        table_name: str,
        schema_name: str = "public",
    ):
        """AsyncPostgresIndexStore constructor.

        Args:
            key (object): Key to prevent direct constructor usage.
            engine (AsyncEngine): Database connection pool.
            table_name (str): Table name that stores the index metadata.
            schema_name (str): The schema name where the table is located. Defaults to "public"

        Raises:
            Exception: If constructor is directly called by the user.
        """
        # Identity check: only the exact sentinel object is accepted, so an
        # arbitrary object with a permissive __eq__/__ne__ cannot bypass it.
        if key is not AsyncPostgresIndexStore.__create_key:
            raise Exception("Only create class through 'create' method!")
        self._engine = engine
        self._table_name = table_name
        self._schema_name = schema_name

    @classmethod
    async def create(
        cls,
        engine: PostgresEngine,
        table_name: str,
        schema_name: str = "public",
    ) -> AsyncPostgresIndexStore:
        """Create a new AsyncPostgresIndexStore instance.

        Args:
            engine (PostgresEngine): Postgres engine to use.
            table_name (str): Table name that stores the index metadata.
            schema_name (str): The schema name where the table is located. Defaults to "public"

        Raises:
            ValueError: If the table provided does not contain required schema.

        Returns:
            AsyncPostgresIndexStore: A newly created instance of AsyncPostgresIndexStore.
        """
        table_schema = await engine._aload_table_schema(table_name, schema_name)
        column_names = table_schema.columns.keys()

        required_columns = ["index_id", "type", "index_data"]

        if not (all(x in column_names for x in required_columns)):
            raise ValueError(
                f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n"
                f"Expected column names: {required_columns}\n"
                f"Provided column names: {column_names}\n"
                "Please create the table with the following schema:\n"
                f"CREATE TABLE {schema_name}.{table_name} (\n"
                "    index_id VARCHAR PRIMARY KEY,\n"
                "    type VARCHAR NOT NULL,\n"
                "    index_data JSONB NOT NULL\n"
                ");"
            )

        return cls(cls.__create_key, engine._pool, table_name, schema_name)

    async def __aexecute_query(self, query: str, params: Optional[dict] = None):
        """Execute a statement that returns no rows and commit it.

        Args:
            query (str): SQL text; may contain ``:name`` bind placeholders.
            params (Optional[dict]): Values for the bind placeholders.
        """
        async with self._engine.connect() as conn:
            await conn.execute(text(query), params)
            await conn.commit()

    async def __afetch_query(self, query: str, params: Optional[dict] = None):
        """Execute a query and return all rows as mappings.

        Args:
            query (str): SQL text; may contain ``:name`` bind placeholders.
            params (Optional[dict]): Values for the bind placeholders.

        Returns:
            The fetched rows, each addressable by column name.
        """
        async with self._engine.connect() as conn:
            result = await conn.execute(text(query), params)
            result_map = result.mappings()
            results = result_map.fetchall()
            await conn.commit()
        return results

    async def aindex_structs(self) -> List[IndexStruct]:
        """Get all index structs.

        Returns:
            List[IndexStruct]: index structs

        """
        query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";"""
        index_list = await self.__afetch_query(query)

        if index_list:
            return [json_to_index_struct(index["index_data"]) for index in index_list]
        return []

    async def aadd_index_struct(self, index_struct: IndexStruct) -> None:
        """Add an index struct, replacing any existing row with the same id.

        Args:
            index_struct (IndexStruct): index struct

        """
        key = index_struct.index_id
        data = index_struct_to_json(index_struct)
        struct_type = index_struct.get_type()

        index_row = {
            "index_id": key,
            "type": struct_type,
            "index_data": json.dumps(data),
        }

        insert_query = f'INSERT INTO "{self._schema_name}"."{self._table_name}"(index_id, type, index_data) '
        values_statement = "VALUES (:index_id, :type, :index_data)"
        # Upsert: a second add for the same index_id overwrites type and data.
        upsert_statement = " ON CONFLICT (index_id) DO UPDATE SET type = EXCLUDED.type, index_data = EXCLUDED.index_data;"

        query = insert_query + values_statement + upsert_statement
        await self.__aexecute_query(query, index_row)

    async def adelete_index_struct(self, key: str) -> None:
        """Delete an index struct.

        Args:
            key (str): index struct key

        """
        # The key is passed as a bound parameter rather than interpolated into
        # the SQL text, so quotes or SQL fragments in it are treated as data.
        query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE index_id = :index_id; """
        await self.__aexecute_query(query, {"index_id": key})

    async def aget_index_struct(
        self, struct_id: Optional[str] = None
    ) -> Optional[IndexStruct]:
        """Get an index struct.

        Args:
            struct_id (Optional[str]): index struct id

        Returns:
            Optional[IndexStruct]: The matching struct. When ``struct_id`` is
            None, the only stored struct if exactly one exists; otherwise a
            warning is emitted and None is returned. None is also returned
            when no row matches ``struct_id``.

        """
        if struct_id is None:
            structs = await self.aindex_structs()
            if len(structs) == 1:
                return structs[0]
            # NOTE(review): this also fires when the table is empty, although
            # the message only mentions the "more than one" case.
            warnings.warn("No struct_id specified and more than one struct exists.")
            return None

        # struct_id is bound as a parameter to avoid SQL injection.
        query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE index_id = :index_id;"""
        rows = await self.__afetch_query(query, {"index_id": struct_id})
        if not rows:
            return None

        index_data = rows[0].get("index_data")
        if index_data:
            return json_to_index_struct(index_data)
        return None

    def index_structs(self) -> List[IndexStruct]:
        """Sync variant; not supported by this class."""
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead."
        )

    def add_index_struct(self, index_struct: IndexStruct) -> None:
        """Sync variant; not supported by this class."""
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead."
        )

    def delete_index_struct(self, key: str) -> None:
        """Sync variant; not supported by this class."""
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead."
        )

    def get_index_struct(
        self, struct_id: Optional[str] = None
    ) -> Optional[IndexStruct]:
        """Sync variant; not supported by this class."""
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead."
        )

tests/test_async_index_store.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import uuid
17+
import warnings
18+
from typing import Sequence
19+
20+
import pytest
21+
import pytest_asyncio
22+
from llama_index.core.data_structs.data_structs import IndexDict, IndexGraph, IndexList
23+
from sqlalchemy import RowMapping, text
24+
25+
from llama_index_cloud_sql_pg import PostgresEngine
26+
from llama_index_cloud_sql_pg.async_index_store import AsyncPostgresIndexStore
27+
28+
# Unique table name per test run so concurrent runs do not collide.
default_table_name_async = f"index_store_{uuid.uuid4()}"
29+
30+
31+
async def aexecute(engine: PostgresEngine, query: str) -> None:
    """Run one SQL statement on a connection from the engine's pool and commit it."""
    async with engine._pool.connect() as connection:
        statement = text(query)
        await connection.execute(statement)
        await connection.commit()
35+
36+
37+
async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
    """Run one SQL query on a pooled connection and return every row as a mapping."""
    async with engine._pool.connect() as connection:
        cursor = await connection.execute(text(query))
        rows = cursor.mappings().fetchall()
    return rows
43+
44+
45+
def get_env_var(key: str, desc: str) -> str:
    """Return the value of environment variable *key*.

    Args:
        key: Name of the environment variable to read.
        desc: Human-readable description used in the error message.

    Raises:
        ValueError: If the variable is not set.
    """
    value = os.environ.get(key)
    if value is not None:
        return value
    raise ValueError(f"Must set env var {key} to: {desc}")
50+
51+
52+
@pytest.mark.asyncio(loop_scope="class")
class TestAsyncPostgresIndexStore:
    """Integration tests for AsyncPostgresIndexStore.

    Connection settings are read from environment variables, and all tests
    share one engine and one index store table for the class.
    """

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        return get_env_var("REGION", "region for Cloud SQL instance")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")

    @pytest.fixture(scope="module")
    def user(self) -> str:
        return get_env_var("DB_USER", "database user for Cloud SQL")

    @pytest.fixture(scope="module")
    def password(self) -> str:
        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")

    @pytest_asyncio.fixture(scope="class")
    async def async_engine(self, db_project, db_region, db_instance, db_name):
        """Yield a PostgresEngine for the configured instance; close it afterwards."""
        async_engine = await PostgresEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )

        yield async_engine

        await async_engine.close()

    @pytest_asyncio.fixture(scope="class")
    async def index_store(self, async_engine):
        """Create the index store table and store; drop the table afterwards."""
        await async_engine._ainit_index_store_table(table_name=default_table_name_async)

        index_store = await AsyncPostgresIndexStore.create(
            engine=async_engine, table_name=default_table_name_async
        )

        yield index_store

        query = f'DROP TABLE IF EXISTS "{default_table_name_async}"'
        await aexecute(async_engine, query)

    async def test_init_with_constructor(self, async_engine):
        # Pass every required argument, including a foreign key object, so the
        # failure comes from the constructor's own guard and not from a
        # TypeError for a missing positional argument.
        with pytest.raises(
            Exception, match="Only create class through 'create' method!"
        ):
            AsyncPostgresIndexStore(
                key=object(),
                engine=async_engine,
                table_name=default_table_name_async,
            )

    async def test_add_and_delete_index(self, index_store, async_engine):
        index_struct = IndexGraph()
        index_id = index_struct.index_id
        index_type = index_struct.get_type()
        await index_store.aadd_index_struct(index_struct)

        query = f"""select * from "public"."{default_table_name_async}" where index_id = '{index_id}';"""
        results = await afetch(async_engine, query)
        result = results[0]
        assert result.get("type") == index_type

        await index_store.adelete_index_struct(index_id)
        results = await afetch(async_engine, query)
        assert results == []

    async def test_get_index(self, index_store):
        index_struct = IndexGraph()
        index_id = index_struct.index_id
        await index_store.aadd_index_struct(index_struct)

        ind_struct = await index_store.aget_index_struct(index_id)

        assert index_struct == ind_struct

    async def test_aindex_structs(self, index_store):
        index_dict_struct = IndexDict()
        index_list_struct = IndexList()
        index_graph_struct = IndexGraph()

        await index_store.aadd_index_struct(index_dict_struct)
        await index_store.aadd_index_struct(index_graph_struct)
        await index_store.aadd_index_struct(index_list_struct)

        indexes = await index_store.aindex_structs()

        assert index_dict_struct in indexes
        assert index_list_struct in indexes
        assert index_graph_struct in indexes

    async def test_warning(self, index_store):
        index_dict_struct = IndexDict()
        index_list_struct = IndexList()

        await index_store.aadd_index_struct(index_dict_struct)
        await index_store.aadd_index_struct(index_list_struct)

        # pytest.warns matches on the message, so the check does not depend on
        # the active warning filters or on unrelated warnings being recorded.
        with pytest.warns(
            UserWarning,
            match="No struct_id specified and more than one struct exists.",
        ):
            index_struct = await index_store.aget_index_struct()

        assert index_struct is None

0 commit comments

Comments
 (0)