Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ def __init__(
)
raise ValueError(msg)
_astra_db = astra_db_client.copy() if astra_db_client is not None else None
if async_astra_db_client is not None:
_async_astra_db = async_astra_db_client.copy()
else:
_async_astra_db = None
_async_astra_db = (
async_astra_db_client.copy()
if async_astra_db_client is not None
else None
)

# deprecation of the 'core classes' in constructor and conversion
# to token/endpoint(-environment) based init, with checks
Expand Down Expand Up @@ -234,43 +235,16 @@ def __init__(

self.async_setup_db_task: Task | None = None
if setup_mode == SetupMode.ASYNC:
async_database = self.async_database

async def _setup_db() -> None:
if pre_delete_collection:
await async_database.drop_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension

try:
await async_database.create_collection(
name=collection_name,
dimension=dimension,
metric=metric,
indexing=requested_indexing_policy,
# Used for enabling $vectorize on the collection
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc
async for coll_desc in async_database.list_collections()
]
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise

self.async_setup_db_task = asyncio.create_task(_setup_db())
self.async_setup_db_task = asyncio.create_task(
self._asetup_db(
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
default_indexing_policy=default_indexing_policy,
requested_indexing_policy=requested_indexing_policy,
collection_vector_service_options=collection_vector_service_options,
)
)
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.database.drop_collection(collection_name)
Expand All @@ -283,7 +257,7 @@ async def _setup_db() -> None:
try:
self.database.create_collection(
name=collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
dimension=embedding_dimension,
metric=metric,
indexing=requested_indexing_policy,
# Used for enabling $vectorize on the collection
Expand All @@ -303,6 +277,48 @@ async def _setup_db() -> None:
# other reasons for the exception
raise

async def _asetup_db(
self,
*,
pre_delete_collection: bool,
embedding_dimension: int | Awaitable[int] | None,
metric: str | None,
requested_indexing_policy: dict[str, Any] | None,
default_indexing_policy: dict[str, Any] | None,
collection_vector_service_options: CollectionVectorServiceOptions | None,
) -> None:
if pre_delete_collection:
await self.async_database.drop_collection(self.collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension

try:
await self.async_database.create_collection(
name=self.collection_name,
dimension=dimension,
metric=metric,
indexing=requested_indexing_policy,
# Used for enabling $vectorize on the collection
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise

@staticmethod
def _validate_indexing_policy(
collection_descriptors: list[CollectionDescriptor],
Expand All @@ -317,7 +333,7 @@ def _validate_indexing_policy(

Args:
collection_descriptors: collection descriptors for the database.
collection_name (str): the name of the collection whose attempted
collection_name: the name of the collection whose attempted
creation failed
requested_indexing_policy: the 'indexing' part of the collection
options, e.g. `{"deny": ["field1", "field2"]}`.
Expand Down
1 change: 0 additions & 1 deletion libs/astradb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ ignore = [
"ISC001", # Messes with the formatter
"PLR09", # TODO: do we enforce these ones (complexity) ?


"D101", # TODO
"D417", # TODO

Expand Down
4 changes: 2 additions & 2 deletions libs/astradb/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def astra_db_credentials() -> AstraDBCredentials:
def database(astra_db_credentials: AstraDBCredentials) -> Database:
return Database(
token=astra_db_credentials["token"],
api_endpoint=astra_db_credentials["api_endpoint"], # type: ignore[arg-type]
api_endpoint=astra_db_credentials["api_endpoint"],
namespace=astra_db_credentials["namespace"],
environment=astra_db_credentials["environment"],
)
Expand All @@ -64,7 +64,7 @@ def database(astra_db_credentials: AstraDBCredentials) -> Database:
def core_astra_db(astra_db_credentials: AstraDBCredentials) -> AstraDB:
return AstraDB(
token=astra_db_credentials["token"],
api_endpoint=astra_db_credentials["api_endpoint"], # type: ignore[arg-type]
api_endpoint=astra_db_credentials["api_endpoint"],
namespace=astra_db_credentials["namespace"],
)

Expand Down