From 59c712c3ea30d994eab87a386f9b359831da1f36 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Thu, 13 Feb 2025 15:07:50 +0100 Subject: [PATCH 1/4] centralize codec's id/vector encoding; add multi-ids encoding --- .../utils/vector_store_codecs.py | 57 +++++++------------ .../astradb/langchain_astradb/vectorstores.py | 28 ++++----- 2 files changed, 37 insertions(+), 48 deletions(-) diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py index ae15403..1803314 100644 --- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py +++ b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py @@ -46,6 +46,12 @@ def _default_encode_id(filter_id: str) -> dict[str, Any]: return {"_id": filter_id} +def _default_encode_ids(filter_ids: list[str]) -> dict[str, Any]: + if len(filter_ids) == 1: + return _default_encode_id(filter_ids[0]) + return {"_id": {"$in": filter_ids}} + + def _default_encode_vector_sort(vector: list[float]) -> dict[str, Any]: return {"$vector": vector} @@ -131,7 +137,6 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: an equivalent filter clause for use in Astra DB's find queries. """ - @abstractmethod def encode_id(self, filter_id: str) -> dict[str, Any]: """Encode an ID as a filter for use in Astra DB queries. @@ -139,10 +144,23 @@ def encode_id(self, filter_id: str) -> dict[str, Any]: filter_id: the ID value to filter on. Returns: - an filter clause for use in Astra DB's find queries. + a filter clause for use in Astra DB's find queries. """ + return _default_encode_id(filter_id) + + def encode_ids(self, filter_ids: list[str]) -> dict[str, Any]: + """Encode a list of IDs as an appropriate search filter. + + The resulting filter expresses condition: "document ID is among filter_ids". + + Args: + filter_ids: the ID values to filter on. + + Returns: + a filter clause for use in Astra DB's find queries. + """ + return _default_encode_ids(filter_ids) - @abstractmethod def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: """Encode a vector as a sort to use for Astra DB queries. @@ -152,6 +170,7 @@ def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: Returns: an order clause for use in Astra DB's find queries. """ + return _default_encode_vector_sort(vector) class _DefaultVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): @@ -225,14 +244,6 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return _default_encode_filter(filter_dict) - @override - def encode_id(self, filter_id: str) -> dict[str, Any]: - return _default_encode_id(filter_id) - - @override - def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: - return _default_encode_vector_sort(vector) - class _DefaultVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for the default vector store usage with server-side embeddings. 
@@ -308,14 +319,6 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return _default_encode_filter(filter_dict) - @override - def encode_id(self, filter_id: str) -> dict[str, Any]: - return _default_encode_id(filter_id) - - @override - def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: - return _default_encode_vector_sort(vector) - class _FlatVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for collections populated externally, with client-side embeddings. @@ -396,14 +399,6 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return filter_dict - @override - def encode_id(self, filter_id: str) -> dict[str, Any]: - return _default_encode_id(filter_id) - - @override - def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: - return _default_encode_vector_sort(vector) - class _FlatVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for collections populated externally, with server-side embeddings. @@ -476,11 +471,3 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: @override def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return filter_dict - - @override - def encode_id(self, filter_id: str) -> dict[str, Any]: - return _default_encode_id(filter_id) - - @override - def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: - return _default_encode_vector_sort(vector) diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 5d0300b..514364e 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -654,11 +654,6 @@ def __init__( msg = f"Collection '{self.collection_name}' not found." raise ValueError(msg) # use the collection info to set the store properties - self.indexing_policy = self._normalize_metadata_indexing_policy( - metadata_indexing_include=None, - metadata_indexing_exclude=None, - collection_indexing_policy=c_descriptor.options.indexing, - ) if c_descriptor.options.vector is None: msg = "Non-vector collection detected." raise ValueError(msg) @@ -677,6 +672,11 @@ def __init__( ignore_invalid_documents=ignore_invalid_documents, norm_content_field=norm_content_field, ) + self.indexing_policy = self._normalize_metadata_indexing_policy( + metadata_indexing_include=None, + metadata_indexing_exclude=None, + collection_indexing_policy=c_descriptor.options.indexing, + ) # validate embedding/vectorize compatibility and such. 
# Embedding and the server-side embeddings are mutually exclusive, @@ -864,7 +864,9 @@ def delete_by_document_id(self, document_id: str) -> bool: """ self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) - deletion_response = self.astra_env.collection.delete_one({"_id": document_id}) + deletion_response = self.astra_env.collection.delete_one( + self.document_codec.encode_id(document_id), + ) return deletion_response.deleted_count == 1 async def adelete_by_document_id(self, document_id: str) -> bool: @@ -878,7 +880,7 @@ async def adelete_by_document_id(self, document_id: str) -> bool: """ await self.astra_env.aensure_db_setup() deletion_response = await self.astra_env.async_collection.delete_one( - {"_id": document_id}, + self.document_codec.encode_id(document_id), ) return deletion_response.deleted_count == 1 @@ -1204,7 +1206,7 @@ def _replace_document( document: dict[str, Any], ) -> tuple[UpdateResult, str]: return self.astra_env.collection.replace_one( - {"_id": document["_id"]}, + self.document_codec.encode_id(document["_id"]), document, ), document["_id"] @@ -1334,7 +1336,7 @@ async def _replace_document( ) -> tuple[UpdateResult, str]: async with sem: return await _async_collection.replace_one( - {"_id": document["_id"]}, + self.document_codec.encode_id(document["_id"]), document, ), document["_id"] @@ -1395,7 +1397,7 @@ def _update_document( document_id, update_metadata = id_md_pair encoded_metadata = self.filter_to_query(update_metadata) return self.astra_env.collection.update_one( - {"_id": document_id}, + self.document_codec.encode_id(document_id), {"$set": encoded_metadata}, ) @@ -1448,7 +1450,7 @@ async def _update_document( encoded_metadata = self.filter_to_query(update_metadata) async with sem: return await _async_collection.update_one( - {"_id": document_id}, + self.document_codec.encode_id(document_id), {"$set": encoded_metadata}, ) @@ -1520,7 +1522,7 @@ def get_by_document_id(self, document_id: str) -> Document | None: self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) hit = self.astra_env.collection.find_one( - {"_id": document_id}, + self.document_codec.encode_id(document_id), projection=self.document_codec.base_projection, ) if hit is None: @@ -1539,7 +1541,7 @@ async def aget_by_document_id(self, document_id: str) -> Document | None: await self.astra_env.aensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) hit = await self.astra_env.async_collection.find_one( - {"_id": document_id}, + self.document_codec.encode_id(document_id), projection=self.document_codec.base_projection, ) if hit is None: From d92c15e10170e07538489121439461deaa16a047 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Thu, 13 Feb 2025 17:48:03 +0100 Subject: [PATCH 2/4] move all indexing, _id and similarity management into codecs so that it cleanly passes through the codec layer all the time --- .../utils/vector_store_codecs.py | 65 +++++++++++++- .../astradb/langchain_astradb/vectorstores.py | 85 +++++++------------ .../tests/integration_tests/conftest.py | 9 +- .../tests/unit_tests/test_vectorstores.py | 55 ++++++++++-- .../tests/unit_tests/test_vs_doc_codecs.py | 17 ++++ 5 files changed, 167 insertions(+), 64 deletions(-) diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py index 1803314..1052427 100644 --- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py +++
b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py @@ -16,6 +16,8 @@ ) FLATTEN_CONFLICT_MSG = "Cannot flatten metadata: field name overlap for '{field}'." +STANDARD_INDEXING_OPTIONS_DEFAULT = {"allow": ["metadata"]} + logger = logging.getLogger(__name__) @@ -23,6 +25,14 @@ def _default_decode_vector(astra_doc: dict[str, Any]) -> list[float] | None: return astra_doc.get("$vector") +def _default_metadata_key_to_field_identifier(md_key: str) -> str: + return f"metadata.{md_key}" + + +def _flat_metadata_key_to_field_identifier(md_key: str) -> str: + return md_key + + def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]: metadata_filter = {} for k, v in filter_dict.items(): @@ -37,7 +47,7 @@ def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]: # assume each list item can be fed back to this function metadata_filter[k] = _default_encode_filter(v) # type: ignore[assignment] else: - metadata_filter[f"metadata.{k}"] = v + metadata_filter[_default_metadata_key_to_field_identifier(k)] = v return metadata_filter @@ -137,6 +147,26 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: an equivalent filter clause for use in Astra DB's find queries. """ + @abstractmethod + def metadata_key_to_field_identifier(self, md_key: str) -> str: + """Express an 'abstract' metadata key as a full Data API field identifier.""" + + @property + @abstractmethod + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + """Provide the default indexing policy if the collection must be created.""" + + def get_id(self, astra_document: dict[str, Any]) -> str: + """Return the ID of an encoded document (= a raw JSON read from DB).""" + return astra_document["_id"] + + def get_similarity(self, astra_document: dict[str, Any]) -> float: + """Return the similarity of an encoded document (= a raw JSON read from DB). + + This method assumes its argument comes from a suitable vector search. + """ + return astra_document["$similarity"] + def encode_id(self, filter_id: str) -> dict[str, Any]: """Encode an ID as a filter for use in Astra DB queries. @@ -244,6 +274,14 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return _default_encode_filter(filter_dict) + @override + def metadata_key_to_field_identifier(self, md_key: str) -> str: + return _default_metadata_key_to_field_identifier(md_key) + + @property + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + return STANDARD_INDEXING_OPTIONS_DEFAULT + class _DefaultVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for the default vector store usage with server-side embeddings. @@ -319,6 +357,14 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return _default_encode_filter(filter_dict) + @property + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + return STANDARD_INDEXING_OPTIONS_DEFAULT + + @override + def metadata_key_to_field_identifier(self, md_key: str) -> str: + return _default_metadata_key_to_field_identifier(md_key) + class _FlatVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for collections populated externally, with client-side embeddings. 
@@ -399,6 +445,14 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return filter_dict + @property + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + return {"deny": [self.content_field]} + + @override + def metadata_key_to_field_identifier(self, md_key: str) -> str: + return _flat_metadata_key_to_field_identifier(md_key) + class _FlatVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec): """Codec for collections populated externally, with server-side embeddings. @@ -471,3 +525,12 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None: @override def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: return filter_dict + + @property + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + # $vectorize cannot be de-indexed explicitly (the API manages it entirely). + return {} + + @override + def metadata_key_to_field_identifier(self, md_key: str) -> str: + return _flat_metadata_key_to_field_identifier(md_key) diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 514364e..f4c58e0 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -8,7 +8,6 @@ import uuid import warnings from concurrent.futures import ThreadPoolExecutor -from operator import itemgetter from typing import ( TYPE_CHECKING, Any, @@ -64,8 +63,6 @@ U = TypeVar("U") DocDict = Dict[str, Any] # dicts expressing entries to insert -# indexing options when creating a collection -DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]} # error code to check for during bulk insertions DOCUMENT_ALREADY_EXISTS_API_ERROR_CODE = "DOCUMENT_ALREADY_EXISTS" # max number of errors shown in full insertion error messages @@ -363,6 +360,7 @@ def _normalize_metadata_indexing_policy( metadata_indexing_include: Iterable[str] | None, metadata_indexing_exclude: Iterable[str] | None, collection_indexing_policy: dict[str, Any] | None, + document_codec: _AstraDBVectorStoreDocumentCodec, ) -> dict[str, Any]: """Normalize the constructor indexing parameters. @@ -385,19 +383,21 @@ def _normalize_metadata_indexing_policy( if metadata_indexing_include is not None: return { "allow": [ - f"metadata.{md_field}" for md_field in metadata_indexing_include + document_codec.metadata_key_to_field_identifier(md_field) + for md_field in metadata_indexing_include ] } if metadata_indexing_exclude is not None: return { "deny": [ - f"metadata.{md_field}" for md_field in metadata_indexing_exclude + document_codec.metadata_key_to_field_identifier(md_field) + for md_field in metadata_indexing_exclude ] } return ( collection_indexing_policy if collection_indexing_policy is not None - else DEFAULT_INDEXING_OPTIONS + else document_codec.default_collection_indexing_policy ) def __init__( @@ -621,6 +621,7 @@ def __init__( metadata_indexing_include=metadata_indexing_include, metadata_indexing_exclude=metadata_indexing_exclude, collection_indexing_policy=collection_indexing_policy, + document_codec=self.document_codec, ) else: logger.info( @@ -676,6 +677,7 @@ def __init__( metadata_indexing_include=None, metadata_indexing_exclude=None, collection_indexing_policy=c_descriptor.options.indexing, + document_codec=self.document_codec, ) # validate embedding/vectorize compatibility and such. 
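To make the effect of threading the codec through `_normalize_metadata_indexing_policy` concrete, here is a minimal sketch (the expected values mirror the unit tests added later in this patch; the `content_x` field name is illustrative):

```python
from langchain_astradb.utils.vector_store_codecs import _FlatVSDocumentCodec
from langchain_astradb.vectorstores import AstraDBVectorStore

# A 'flat' codec: metadata keys live at the top level of the stored document.
flat_codec = _FlatVSDocumentCodec(
    content_field="content_x",
    ignore_invalid_documents=False,
)

# Include-lists now go through the codec's key-to-field mapping,
# so the flat codec emits bare key names (no "metadata." prefix).
assert AstraDBVectorStore._normalize_metadata_indexing_policy(
    metadata_indexing_include=["a1", "a2"],
    metadata_indexing_exclude=None,
    collection_indexing_policy=None,
    document_codec=flat_codec,
) == {"allow": ["a1", "a2"]}

# With no indexing parameters given, the fallback policy is codec-specific:
# the flat codec de-indexes its (potentially long) content field.
assert AstraDBVectorStore._normalize_metadata_indexing_policy(
    metadata_indexing_include=None,
    metadata_indexing_exclude=None,
    collection_indexing_policy=None,
    document_codec=flat_codec,
) == {"deny": ["content_x"]}
```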
@@ -708,7 +710,9 @@ def __init__( embedding_dimension=_embedding_dimension, metric=self.metric, requested_indexing_policy=self.indexing_policy, - default_indexing_policy=DEFAULT_INDEXING_OPTIONS, + default_indexing_policy=( + self.document_codec.default_collection_indexing_policy + ), collection_vector_service_options=self.collection_vector_service_options, collection_embedding_api_key=self.collection_embedding_api_key, ext_callers=ext_callers, @@ -1068,38 +1072,9 @@ def _get_documents_to_insert( # make unique by id, keeping the last return _unique_list( documents_to_insert[::-1], - itemgetter("_id"), + self.document_codec.get_id, )[::-1] - @staticmethod - def _get_missing_from_batch( - document_batch: list[DocDict], insert_result: dict[str, Any] - ) -> tuple[list[str], list[DocDict]]: - if "status" not in insert_result: - msg = f"API Exception while running bulk insertion: {insert_result}" - raise AstraDBVectorStoreError(msg) - batch_inserted = insert_result["status"]["insertedIds"] - # estimation of the preexisting documents that failed - missed_inserted_ids = {document["_id"] for document in document_batch} - set( - batch_inserted - ) - errors = insert_result.get("errors", []) - # careful for other sources of error other than "doc already exists" - num_errors = len(errors) - unexpected_errors = any( - error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors - ) - if num_errors != len(missed_inserted_ids) or unexpected_errors: - msg = f"API Exception while running bulk insertion: {errors}" - raise AstraDBVectorStoreError(msg) - # deal with the missing insertions as upserts - missing_from_batch = [ - document - for document in document_batch - if document["_id"] in missed_inserted_ids - ] - return batch_inserted, missing_from_batch - @override def add_texts( self, @@ -1161,7 +1136,7 @@ def add_texts( ) # perform an AstraPy insert_many, catching exceptions for overwriting docs - ids_to_replace: list[int] + ids_to_replace: list[str] inserted_ids: list[str] = [] try: insert_many_result = self.astra_env.collection.insert_many( @@ -1179,9 +1154,10 @@ def add_texts( inserted_ids = err.partial_result.inserted_ids inserted_ids_set = set(inserted_ids) ids_to_replace = [ - document["_id"] + doc_id for document in documents_to_insert - if document["_id"] not in inserted_ids_set + if (doc_id := self.document_codec.get_id(document)) + not in inserted_ids_set ] else: full_err_message = _insertmany_error_message(err) @@ -1192,7 +1168,7 @@ def add_texts( documents_to_replace = [ document for document in documents_to_insert - if document["_id"] in ids_to_replace + if self.document_codec.get_id(document) in ids_to_replace ] _max_workers = ( @@ -1205,10 +1181,11 @@ def add_texts( def _replace_document( document: dict[str, Any], ) -> tuple[UpdateResult, str]: + doc_id = self.document_codec.get_id(document) return self.astra_env.collection.replace_one( - self.document_codec.encode_id(document["_id"]), + self.document_codec.encode_id(doc_id), document, - ), document["_id"] + ), doc_id replace_results = list( executor.map( @@ -1291,7 +1268,7 @@ async def aadd_texts( ) # perform an AstraPy insert_many, catching exceptions for overwriting docs - ids_to_replace: list[int] + ids_to_replace: list[str] inserted_ids: list[str] = [] try: insert_many_result = await self.astra_env.async_collection.insert_many( @@ -1309,9 +1286,10 @@ async def aadd_texts( inserted_ids = err.partial_result.inserted_ids inserted_ids_set = set(inserted_ids) ids_to_replace = [ - document["_id"] + doc_id for document in 
documents_to_insert - if document["_id"] not in inserted_ids_set + if (doc_id := self.document_codec.get_id(document)) + not in inserted_ids_set ] else: full_err_message = _insertmany_error_message(err) @@ -1322,7 +1300,7 @@ async def aadd_texts( documents_to_replace = [ document for document in documents_to_insert - if document["_id"] in ids_to_replace + if self.document_codec.get_id(document) in ids_to_replace ] sem = asyncio.Semaphore( @@ -1335,10 +1313,11 @@ async def _replace_document( document: dict[str, Any], ) -> tuple[UpdateResult, str]: async with sem: + doc_id = self.document_codec.get_id(document) return await _async_collection.replace_one( - self.document_codec.encode_id(document["_id"]), + self.document_codec.encode_id(doc_id), document, - ), document["_id"] + ), doc_id tasks = [ asyncio.create_task(_replace_document(document)) @@ -1736,8 +1715,8 @@ def _similarity_search_with_score_id_by_sort( for (doc, sim, did) in ( ( self.document_codec.decode(hit), - hit["$similarity"], - hit["_id"], + self.document_codec.get_similarity(hit), + self.document_codec.get_id(hit), ) for hit in hits_ite ) @@ -2115,8 +2094,8 @@ async def _asimilarity_search_with_score_id_by_sort( async for (doc, sim, did) in ( ( self.document_codec.decode(hit), - hit["$similarity"], - hit["_id"], + self.document_codec.get_similarity(hit), + self.document_codec.get_id(hit), ) async for hit in self.astra_env.async_collection.find( filter=metadata_parameter, diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py index c65763e..e871ae8 100644 --- a/libs/astradb/tests/integration_tests/conftest.py +++ b/libs/astradb/tests/integration_tests/conftest.py @@ -32,7 +32,10 @@ from astrapy.info import CollectionVectorServiceOptions from langchain_astradb.utils.astradb import SetupMode -from langchain_astradb.vectorstores import DEFAULT_INDEXING_OPTIONS, AstraDBVectorStore +from langchain_astradb.utils.vector_store_codecs import ( + STANDARD_INDEXING_OPTIONS_DEFAULT, +) +from langchain_astradb.vectorstores import AstraDBVectorStore from tests.conftest import IdentityLLM, ParserEmbeddings if TYPE_CHECKING: @@ -207,7 +210,7 @@ def collection_d2( COLLECTION_NAME_D2, dimension=2, check_exists=False, - indexing=DEFAULT_INDEXING_OPTIONS, + indexing=STANDARD_INDEXING_OPTIONS_DEFAULT, metric="euclidean", ) yield collection @@ -400,7 +403,7 @@ def collection_vz( COLLECTION_NAME_VZ, dimension=1536, check_exists=False, - indexing=DEFAULT_INDEXING_OPTIONS, + indexing=STANDARD_INDEXING_OPTIONS_DEFAULT, metric="euclidean", service=OPENAI_VECTORIZE_OPTIONS_HEADER, embedding_api_key=openai_api_key, diff --git a/libs/astradb/tests/unit_tests/test_vectorstores.py b/libs/astradb/tests/unit_tests/test_vectorstores.py index 51a5ac2..1bb85a5 100644 --- a/libs/astradb/tests/unit_tests/test_vectorstores.py +++ b/libs/astradb/tests/unit_tests/test_vectorstores.py @@ -3,10 +3,11 @@ from astrapy.info import CollectionVectorServiceOptions from langchain_astradb.utils.astradb import SetupMode -from langchain_astradb.vectorstores import ( - DEFAULT_INDEXING_OPTIONS, - AstraDBVectorStore, +from langchain_astradb.utils.vector_store_codecs import ( + _DefaultVSDocumentCodec, + _FlatVSDocumentCodec, ) +from langchain_astradb.vectorstores import AstraDBVectorStore from tests.conftest import ParserEmbeddings FAKE_TOKEN = "t" # noqa: S105 @@ -86,27 +87,62 @@ def test_initialization(self) -> None: ) def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: - """Unit test of the indexing 
policy normalization""" + """Unit test of the indexing policy normalization. + + We use just a couple of codecs to check the idx policy fallbacks. + """ + + the_f_codec = _FlatVSDocumentCodec( + content_field="content_x", + ignore_invalid_documents=False, + ) + the_f_default_policy = the_f_codec.default_collection_indexing_policy + the_d_codec = _DefaultVSDocumentCodec( + content_field="content_y", + ignore_invalid_documents=False, + ) + + # default (non-flat): hardcoding expected indexing from including + al_d_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( + metadata_indexing_include=["a1", "a2"], + metadata_indexing_exclude=None, + collection_indexing_policy=None, + document_codec=the_d_codec, + ) + assert al_d_idx == {"allow": ["metadata.a1", "metadata.a2"]} + + # default (non-flat): hardcoding expected indexing from excluding + dl_d_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( + metadata_indexing_include=None, + metadata_indexing_exclude=["d1", "d2"], + collection_indexing_policy=None, + document_codec=the_d_codec, + ) + assert dl_d_idx == {"deny": ["metadata.d1", "metadata.d2"]} + n3_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=None, metadata_indexing_exclude=None, collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert n3_idx == DEFAULT_INDEXING_OPTIONS + assert n3_idx == the_f_default_policy al_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=["a1", "a2"], metadata_indexing_exclude=None, collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert al_idx == {"allow": ["metadata.a1", "metadata.a2"]} + assert al_idx == {"allow": ["a1", "a2"]} dl_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=None, metadata_indexing_exclude=["d1", "d2"], collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert dl_idx == {"deny": ["metadata.d1", "metadata.d2"]} + assert dl_idx == {"deny": ["d1", "d2"]} custom_policy = { "deny": ["myfield", "other_field.subfield", "metadata.long_text"] @@ -115,6 +151,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=None, metadata_indexing_exclude=None, collection_indexing_policy=custom_policy, + document_codec=the_f_codec, ) assert cip_idx == custom_policy @@ -129,6 +166,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=["b"], collection_indexing_policy=None, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -136,6 +174,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=None, collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -143,6 +182,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=None, metadata_indexing_exclude=["b"], collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -150,4 +190,5 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=["b"], collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) diff --git a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py index 
4e5d8c5..f49988d 100644 --- a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py +++ b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py @@ -365,3 +365,20 @@ def test_flat_vectorize_vectorsort_encoding(self) -> None: """Test vector-sort-encoding for flat, vectorize.""" codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT + + def test_codec_default_collection_indexing_policy(self) -> None: + """Test all codecs give their expected default indexing settings back.""" + codec_d_n = _DefaultVSDocumentCodec( + content_field="content_x", + ignore_invalid_documents=False, + ) + assert codec_d_n.default_collection_indexing_policy == {"allow": ["metadata"]} + codec_d_v = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=True) + assert codec_d_v.default_collection_indexing_policy == {"allow": ["metadata"]} + codec_f_n = _FlatVSDocumentCodec( + content_field="content_x", + ignore_invalid_documents=False, + ) + assert codec_f_n.default_collection_indexing_policy == {"deny": ["content_x"]} + codec_f_v = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=True) + assert codec_f_v.default_collection_indexing_policy == {} From 1797685389df3e1ae5fbff51dbe52ed03aeac27d Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 18 Feb 2025 00:49:55 +0100 Subject: [PATCH 3/4] trading encode_id[s] for encode_query --- .../utils/vector_store_codecs.py | 43 ++++--- .../astradb/langchain_astradb/vectorstores.py | 16 +-- .../tests/unit_tests/test_vs_doc_codecs.py | 105 ++++++++++++++---- 3 files changed, 108 insertions(+), 56 deletions(-) diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py index 1052427..ad01815 100644 --- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py +++ b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py @@ -5,7 +5,7 @@ import logging import warnings from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Iterable from langchain_core.documents import Document from typing_extensions import override @@ -167,29 +167,24 @@ def get_similarity(self, astra_document: dict[str, Any]) -> float: """ return astra_document["$similarity"] - def encode_id(self, filter_id: str) -> dict[str, Any]: - """Encode an ID as a filter for use in Astra DB queries. - - Args: - filter_id: the ID value to filter on. - - Returns: - a filter clause for use in Astra DB's find queries. - """ - return _default_encode_id(filter_id) - - def encode_ids(self, filter_ids: list[str]) -> dict[str, Any]: - """Encode a list of IDs as an appropriate search filter. - - The resulting filter expresses condition: "document ID is among filter_ids". - - Args: - filter_ids: the ID values to filter on. - - Returns: - a filter clause for use in Astra DB's find queries. - """ - return _default_encode_ids(filter_ids) + def encode_query( + self, + *, + ids: Iterable[str] | None = None, + filter_dict: dict[str, Any] | None = None, + ) -> dict[str, Any]: + clauses: list[dict[str, Any]] = [] + _ids_list = list(ids or []) + if _ids_list: + clauses.append(_default_encode_ids(_ids_list)) + if filter_dict: + clauses.append(self.encode_filter(filter_dict)) + + if clauses: + if len(clauses) > 1: + return {"$and": clauses} + return clauses[0] + return {} def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: """Encode a vector as a sort to use for Astra DB queries. 
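Before the vectorstores.py changes below, a minimal sketch of the clauses the new `encode_query` produces (the outputs mirror the unit tests in this patch; the constructor arguments shown are illustrative):

```python
from langchain_astradb.utils.vector_store_codecs import _DefaultVSDocumentCodec

codec = _DefaultVSDocumentCodec(
    content_field="content",
    ignore_invalid_documents=False,
)

# No arguments: an unrestricted query.
assert codec.encode_query() == {}

# A single ID collapses to a bare equality clause.
assert codec.encode_query(ids=["doc1"]) == {"_id": "doc1"}

# Several IDs are expressed through "$in".
assert codec.encode_query(ids=["doc1", "doc2"]) == {"_id": {"$in": ["doc1", "doc2"]}}

# IDs and a metadata filter combine under "$and"; metadata keys are expanded
# to their full paths by the codec ("metadata.*" for the default codec).
assert codec.encode_query(ids=["doc1"], filter_dict={"tag": "a"}) == {
    "$and": [{"_id": "doc1"}, {"metadata.tag": "a"}]
}
```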
diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index f4c58e0..a34b237 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -869,7 +869,7 @@ def delete_by_document_id(self, document_id: str) -> bool: self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) deletion_response = self.astra_env.collection.delete_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), ) return deletion_response.deleted_count == 1 @@ -884,7 +884,7 @@ async def adelete_by_document_id(self, document_id: str) -> bool: """ await self.astra_env.aensure_db_setup() deletion_response = await self.astra_env.async_collection.delete_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), ) return deletion_response.deleted_count == 1 @@ -1183,7 +1183,7 @@ def _replace_document( ) -> tuple[UpdateResult, str]: doc_id = self.document_codec.get_id(document) return self.astra_env.collection.replace_one( - self.document_codec.encode_id(doc_id), + self.document_codec.encode_query(ids=[doc_id]), document, ), doc_id @@ -1315,7 +1315,7 @@ async def _replace_document( async with sem: doc_id = self.document_codec.get_id(document) return await _async_collection.replace_one( - self.document_codec.encode_id(doc_id), + self.document_codec.encode_query(ids=[doc_id]), document, ), doc_id @@ -1376,7 +1376,7 @@ def _update_document( document_id, update_metadata = id_md_pair encoded_metadata = self.filter_to_query(update_metadata) return self.astra_env.collection.update_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), {"$set": encoded_metadata}, ) @@ -1429,7 +1429,7 @@ async def _update_document( encoded_metadata = self.filter_to_query(update_metadata) async with sem: return await _async_collection.update_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), {"$set": encoded_metadata}, ) @@ -1501,7 +1501,7 @@ def get_by_document_id(self, document_id: str) -> Document | None: self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) hit = self.astra_env.collection.find_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), projection=self.document_codec.base_projection, ) if hit is None: @@ -1520,7 +1520,7 @@ async def aget_by_document_id(self, document_id: str) -> Document | None: await self.astra_env.aensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) hit = await self.astra_env.async_collection.find_one( - self.document_codec.encode_id(document_id), + self.document_codec.encode_query(ids=[document_id]), projection=self.document_codec.base_projection, ) if hit is None: diff --git a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py index f49988d..2df3022 100644 --- a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py +++ b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py @@ -9,6 +9,7 @@ from langchain_astradb.utils.vector_store_codecs import ( NO_NULL_VECTOR_MSG, VECTOR_REQUIRED_PREAMBLE_MSG, + _AstraDBVectorStoreDocumentCodec, _DefaultVectorizeVSDocumentCodec, _DefaultVSDocumentCodec, _FlatVectorizeVSDocumentCodec, @@ -21,7 +22,6 @@ DOCUMENT_ID = "the_id" LC_DOCUMENT = Document(id=DOCUMENT_ID, page_content=CONTENT, 
metadata=METADATA) LC_FILTER = {"a0": 0, "$or": [{"b1": 1}, {"b2": 2}]} -ID_FILTER = {"_id": DOCUMENT_ID} VECTOR_SORT = {"$vector": VECTOR} ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE = { @@ -136,12 +136,86 @@ def test_default_novectorize_vector_decoding(self) -> None: assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE) == VECTOR assert codec.decode_vector({}) is None - def test_default_novectorize_id_encoding(self) -> None: - """Test id-encoding for default, no-vectorize.""" - codec = _DefaultVSDocumentCodec( - content_field="content_x", ignore_invalid_documents=False - ) - assert codec.encode_id(DOCUMENT_ID) == ID_FILTER + @pytest.mark.parametrize( + ("default_codec_class", "codec_kwargs"), + [ + ( + _DefaultVSDocumentCodec, + {"content_field": "cf", "ignore_invalid_documents": False}, + ), + (_DefaultVectorizeVSDocumentCodec, {"ignore_invalid_documents": False}), + ], + ids=("default", "default_vectorize"), + ) + def test_default_query_encoding( + self, + default_codec_class: type[_AstraDBVectorStoreDocumentCodec], + codec_kwargs: dict[str, Any], + ) -> None: + """Test query-encoding for the default codecs, with and without vectorize.""" + codec = default_codec_class(**codec_kwargs) + assert codec.encode_query() == {} + assert codec.encode_query(ids=[DOCUMENT_ID]) == {"_id": DOCUMENT_ID} + assert codec.encode_query(ids=["id1", "id2"]) == { + "_id": {"$in": ["id1", "id2"]} + } + assert codec.encode_query(ids=[DOCUMENT_ID], filter_dict={"mdk": "mdv"}) == { + "$and": [{"_id": DOCUMENT_ID}, {"metadata.mdk": "mdv"}] + } + assert codec.encode_query( + ids=[DOCUMENT_ID], filter_dict={"mdk1": "mdv1", "mdk2": "mdv2"} + ) == { + "$and": [ + {"_id": DOCUMENT_ID}, + {"metadata.mdk1": "mdv1", "metadata.mdk2": "mdv2"}, + ] + } + assert codec.encode_query( + ids=[DOCUMENT_ID], filter_dict={"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]} + ) == { + "$and": [ + {"_id": DOCUMENT_ID}, + {"$or": [{"metadata.mdk1": "mdv1"}, {"metadata.mdk2": "mdv2"}]}, + ] + } + + @pytest.mark.parametrize( + ("default_codec_class", "codec_kwargs"), + [ + ( + _FlatVSDocumentCodec, + {"content_field": "cf", "ignore_invalid_documents": False}, + ), + (_FlatVectorizeVSDocumentCodec, {"ignore_invalid_documents": False}), + ], + ids=("flat", "flat_vectorize"), + ) + def test_flat_query_encoding( + self, + default_codec_class: type[_AstraDBVectorStoreDocumentCodec], + codec_kwargs: dict[str, Any], + ) -> None: + """Test query-encoding for the flat codecs, with and without vectorize.""" + codec = default_codec_class(**codec_kwargs) + assert codec.encode_query() == {} + assert codec.encode_query(ids=[DOCUMENT_ID]) == {"_id": DOCUMENT_ID} + assert codec.encode_query(ids=["id1", "id2"]) == { + "_id": {"$in": ["id1", "id2"]} + } + assert codec.encode_query(ids=[DOCUMENT_ID], filter_dict={"mdk": "mdv"}) == { + "$and": [{"_id": DOCUMENT_ID}, {"mdk": "mdv"}] + } + assert codec.encode_query( + ids=[DOCUMENT_ID], filter_dict={"mdk1": "mdv1", "mdk2": "mdv2"} + ) == {"$and": [{"_id": DOCUMENT_ID}, {"mdk1": "mdv1", "mdk2": "mdv2"}]} + assert codec.encode_query( + ids=[DOCUMENT_ID], filter_dict={"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]} + ) == { + "$and": [ + {"_id": DOCUMENT_ID}, + {"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]}, + ] + } def test_default_novectorize_vectorsort_encoding(self) -> None: """Test vector-sort-encoding for default, no-vectorize.""" @@ -206,11 +280,6 @@ def test_default_vectorize_vector_decoding(self) -> None: assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_VECTORIZE_READ) == VECTOR assert codec.decode_vector({}) is None - def
test_default_vectorize_id_encoding(self) -> None: - """Test id-encoding for default, vectorize.""" - codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) - assert codec.encode_id(DOCUMENT_ID) == ID_FILTER - def test_default_vectorize_vectorsort_encoding(self) -> None: """Test vector-sort-encoding for default, vectorize.""" codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False) @@ -286,13 +355,6 @@ def test_flat_novectorize_vector_decoding(self) -> None: assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_NOVECTORIZE) == VECTOR assert codec.decode_vector({}) is None - def test_flat_novectorize_id_encoding(self) -> None: - """Test id-encoding for flat, no-vectorize.""" - codec = _FlatVSDocumentCodec( - content_field="content_x", ignore_invalid_documents=False - ) - assert codec.encode_id(DOCUMENT_ID) == ID_FILTER - def test_flat_novectorize_vectorsort_encoding(self) -> None: """Test vector-sort-encoding for flat, no-vectorize.""" codec = _FlatVSDocumentCodec( @@ -356,11 +418,6 @@ def test_flat_vectorize_vector_decoding(self) -> None: assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_VECTORIZE_READ) == VECTOR assert codec.decode_vector({}) is None - def test_flat_vectorize_id_encoding(self) -> None: - """Test id-encoding for flat, vectorize.""" - codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) - assert codec.encode_id(DOCUMENT_ID) == ID_FILTER - def test_flat_vectorize_vectorsort_encoding(self) -> None: """Test vector-sort-encoding for flat, vectorize.""" codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False) assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT From 1112af041b9269cccb0baa5960a570da6a1da0b6 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 18 Feb 2025 12:38:20 +0100 Subject: [PATCH 4/4] improve docstrings --- .../utils/vector_store_codecs.py | 78 ++++++++++++++++--- 1 file changed, 68 insertions(+), 10 deletions(-) diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py index ad01815..919aadb 100644 --- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py +++ b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py @@ -1,4 +1,4 @@ -"""Classes to handle encoding of documents on DB for the Vector Store.."""  +"""Classes to handle encoding of Documents on DB for the Vector Store."""  from __future__ import annotations @@ -22,18 +22,38 @@ def _default_decode_vector(astra_doc: dict[str, Any]) -> list[float] | None: + """Extract the embedding vector from an Astra DB document.""" return astra_doc.get("$vector") def _default_metadata_key_to_field_identifier(md_key: str) -> str: + """Rewrite a metadata key name to its full path in the 'default' encoding. + + The input `md_key` is an "abstract" metadata key, while the return value + identifies its actual full-path location on an Astra DB document encoded in the + 'default' way (i.e. with a nested `metadata` dictionary). + """ return f"metadata.{md_key}" def _flat_metadata_key_to_field_identifier(md_key: str) -> str: + """Rewrite a metadata key name to its full path in the 'flat' encoding. + + The input `md_key` is an "abstract" metadata key, while the return value + identifies its actual full-path location on an Astra DB document encoded in the + 'flat' way (i.e. metadata fields appearing at top-level in the Astra DB document). + """ return md_key def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]: + """Encode an "abstract" metadata condition for the 'default' encoding.
+ The input can express a query clause on metadata and uses just the metadata field + names, possibly connected/nested through AND and ORs. The output makes key names + into their full path-identifiers (e.g. "metadata.xyz") according to the 'default' + encoding scheme for Astra DB documents. + """ metadata_filter = {} for k, v in filter_dict.items(): # Key in this dict starting with $ are supposedly operators and as such @@ -52,33 +72,43 @@ def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]: return metadata_filter -def _default_encode_id(filter_id: str) -> dict[str, Any]: +def _astra_generic_encode_id(filter_id: str) -> dict[str, Any]: + """Encoding of a single Document ID as a query clause for an Astra DB document.""" return {"_id": filter_id} -def _default_encode_ids(filter_ids: list[str]) -> dict[str, Any]: +def _astra_generic_encode_ids(filter_ids: list[str]) -> dict[str, Any]: + """Encoding of Document IDs as a query clause for an Astra DB document. + + This function picks the right, and most concise, expression based on the + multiplicity of the provided IDs. + """ if len(filter_ids) == 1: - return _default_encode_id(filter_ids[0]) + return _astra_generic_encode_id(filter_ids[0]) return {"_id": {"$in": filter_ids}} -def _default_encode_vector_sort(vector: list[float]) -> dict[str, Any]: +def _astra_generic_encode_vector_sort(vector: list[float]) -> dict[str, Any]: + """Encoding of a vector-based sort as a query clause for an Astra DB document.""" return {"$vector": vector} class _AstraDBVectorStoreDocumentCodec(ABC): - """A document codec for the Astra DB vector store. + """A Document codec for the Astra DB vector store. - The document codec contains the information for consistent interaction + Document codecs hold the logic for consistent interaction with documents as stored on the Astra DB collection. + In this context, 'Document' (capital D) refers to the LangChain class, + while 'Astra DB document' refers to the JSON-like object stored on DB. + Implementations of this class must: - define how to encode/decode documents consistently to and from Astra DB collections. The two operations must, so to speak, combine to the identity on both sides (except for the quirks of their signatures). - provide the adequate projection dictionaries for running find operations on Astra DB, with and without the field containing the vector. - - encode IDs to the `_id` field on Astra DB. + - encode Document IDs to the right field on Astra DB ("_id" for Collections). - define the name of the field storing the textual content of the Document. - define whether embeddings are computed server-side (with $vectorize) or not. """ @@ -173,10 +203,38 @@ def encode_query( ids: Iterable[str] | None = None, filter_dict: dict[str, Any] | None = None, ) -> dict[str, Any]: + """Prepare an encoded query according to the Astra DB document encoding. + + The method optionally accepts both IDs and metadata filters. The two, + if passed together, are automatically combined with an AND operation. + + In other words, if passing both IDs and a metadata filtering clause, + the resulting query would return Astra DB documents matching the metadata + clause AND having an ID among those provided to this method. If, instead, + an OR is required, one should run two separate queries and subsequently merge + the result (taking care of avoiding duplicates). + + Args: + ids: an iterable over Document IDs.
If provided, the resulting Astra DB + query dictionary expresses the requirement that returned documents + have an ID among those provided here. Passing an empty iterable, + or None, results in a query with no conditions on the IDs at all. + filter_dict: a metadata filtering clause. If provided, it must refer to + metadata keys by their bare name (such as `{"key": 123}`). + This filter can combine nested conditions with "$or"/"$and" connectors, + for example: + - `{"tag": "a"}` + - `{"$or": [{"tag": "a"}, {"label": "b"}]}` + - `{"$and": [{"tag": {"$in": ["a", "z"]}}, {"label": "b"}]}` + + Returns: + a query dictionary ready to be used in an Astra DB find operation on + a collection. + """ clauses: list[dict[str, Any]] = [] _ids_list = list(ids or []) if _ids_list: - clauses.append(_default_encode_ids(_ids_list)) + clauses.append(_astra_generic_encode_ids(_ids_list)) if filter_dict: clauses.append(self.encode_filter(filter_dict)) @@ -195,7 +253,7 @@ def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]: Returns: an order clause for use in Astra DB's find queries. """ - return _default_encode_vector_sort(vector) + return _astra_generic_encode_vector_sort(vector) class _DefaultVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
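Taken together, the series makes the codec the single authority for document layout (IDs, similarity, metadata paths, and indexing defaults). A short sketch contrasting the two layouts, with behavior as exercised by the unit tests above (the field names are illustrative):

```python
from langchain_astradb.utils.vector_store_codecs import (
    _DefaultVSDocumentCodec,
    _FlatVSDocumentCodec,
)

default_codec = _DefaultVSDocumentCodec(
    content_field="content", ignore_invalid_documents=False
)
flat_codec = _FlatVSDocumentCodec(
    content_field="content", ignore_invalid_documents=False
)

# 'default' layout: metadata is nested under a "metadata" dictionary ...
assert default_codec.metadata_key_to_field_identifier("tag") == "metadata.tag"
# ... and the default policy indexes exactly that subtree.
assert default_codec.default_collection_indexing_policy == {"allow": ["metadata"]}

# 'flat' layout: metadata keys sit at the top level of the document ...
assert flat_codec.metadata_key_to_field_identifier("tag") == "tag"
# ... so the default policy de-indexes the content field instead.
assert flat_codec.default_collection_indexing_policy == {"deny": ["content"]}
```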