diff --git a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py
index ae15403..919aadb 100644
--- a/libs/astradb/langchain_astradb/utils/vector_store_codecs.py
+++ b/libs/astradb/langchain_astradb/utils/vector_store_codecs.py
@@ -1,11 +1,11 @@
-"""Classes to handle encoding of documents on DB for the Vector Store.."""
+"""Classes to handle encoding of Documents on DB for the Vector Store."""

 from __future__ import annotations

 import logging
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Iterable

 from langchain_core.documents import Document
 from typing_extensions import override
@@ -16,14 +16,44 @@
 )

 FLATTEN_CONFLICT_MSG = "Cannot flatten metadata: field name overlap for '{field}'."

+STANDARD_INDEXING_OPTIONS_DEFAULT = {"allow": ["metadata"]}
+
 logger = logging.getLogger(__name__)


 def _default_decode_vector(astra_doc: dict[str, Any]) -> list[float] | None:
+    """Extract the embedding vector from an Astra DB document."""
     return astra_doc.get("$vector")


+def _default_metadata_key_to_field_identifier(md_key: str) -> str:
+    """Rewrite a metadata key name to its full path in the 'default' encoding.
+
+    The input `md_key` is an "abstract" metadata key, while the return value
+    identifies its actual full-path location on an Astra DB document encoded in the
+    'default' way (i.e. with a nested `metadata` dictionary).
+    """
+    return f"metadata.{md_key}"
+
+
+def _flat_metadata_key_to_field_identifier(md_key: str) -> str:
+    """Rewrite a metadata key name to its full path in the 'flat' encoding.
+
+    The input `md_key` is an "abstract" metadata key, while the return value
+    identifies its actual full-path location on an Astra DB document encoded in the
+    'flat' way (i.e. metadata fields appearing at top level in the Astra DB document).
+    """
+    return md_key
+
+
 def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
+    """Encode an "abstract" metadata condition for the 'default' encoding.
+
+    The input can express a query clause on metadata and uses just the metadata field
+    names, possibly connected/nested through ANDs and ORs. The output maps key names
+    to their full path identifiers (e.g. "metadata.xyz") according to the 'default'
+    encoding scheme for Astra DB documents.
+    """
     metadata_filter = {}
     for k, v in filter_dict.items():
         # Key in this dict starting with $ are supposedly operators and as such
@@ -37,32 +67,48 @@ def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
                 # assume each list item can be fed back to this function
                 metadata_filter[k] = _default_encode_filter(v)  # type: ignore[assignment]
         else:
-            metadata_filter[f"metadata.{k}"] = v
+            metadata_filter[_default_metadata_key_to_field_identifier(k)] = v
     return metadata_filter


-def _default_encode_id(filter_id: str) -> dict[str, Any]:
+def _astra_generic_encode_id(filter_id: str) -> dict[str, Any]:
+    """Encoding of a single Document ID as a query clause for an Astra DB document."""
     return {"_id": filter_id}


-def _default_encode_vector_sort(vector: list[float]) -> dict[str, Any]:
+def _astra_generic_encode_ids(filter_ids: list[str]) -> dict[str, Any]:
+    """Encoding of Document IDs as a query clause for an Astra DB document.
+
+    This function picks the right, and most concise, expression based on the
+    multiplicity of the provided IDs.
+    """
+    if len(filter_ids) == 1:
+        return _astra_generic_encode_id(filter_ids[0])
+    return {"_id": {"$in": filter_ids}}
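For reference, a quick sketch of what these ID-encoding helpers produce, derived from the docstrings above and the unit tests further below (illustrative, not part of the patch):

_astra_generic_encode_id("doc-1")
# -> {"_id": "doc-1"}
_astra_generic_encode_ids(["doc-1"])
# -> {"_id": "doc-1"}   (a single ID uses the concise form)
_astra_generic_encode_ids(["doc-1", "doc-2"])
# -> {"_id": {"$in": ["doc-1", "doc-2"]}}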
+ """ + if len(filter_ids) == 1: + return _astra_generic_encode_id(filter_ids[0]) + return {"_id": {"$in": filter_ids}} + + +def _astra_generic_encode_vector_sort(vector: list[float]) -> dict[str, Any]: + """Encoding of a vector-based sort as a query clause for an Astra DB document.""" return {"$vector": vector} class _AstraDBVectorStoreDocumentCodec(ABC): - """A document codec for the Astra DB vector store. + """A Document codec for the Astra DB vector store. - The document codec contains the information for consistent interaction + Document codecs hold the logic consistent interaction with documents as stored on the Astra DB collection. + In this context, 'Document' (capital D) refers to the LangChain class, + while 'Astra DB document' refers to the JSON-like object stored on DB. + Implementations of this class must: - define how to encode/decode documents consistently to and from Astra DB collections. The two operations must, so to speak, combine to the identity on both sides (except for the quirks of their signatures). - provide the adequate projection dictionaries for running find operations on Astra DB, with and without the field containing the vector. - - encode IDs to the `_id` field on Astra DB. + - encode Document IDs to the right field on Astra DB ("_id" for Collections). - define the name of the field storing the textual content of the Document. - define whether embeddings are computed server-side (with $vectorize) or not. """ @@ -132,17 +178,72 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]: """ @abstractmethod - def encode_id(self, filter_id: str) -> dict[str, Any]: - """Encode an ID as a filter for use in Astra DB queries. + def metadata_key_to_field_identifier(self, md_key: str) -> str: + """Express an 'abstract' metadata key as a full Data API field identifier.""" + + @property + @abstractmethod + def default_collection_indexing_policy(self) -> dict[str, list[str]]: + """Provide the default indexing policy if the collection must be created.""" + + def get_id(self, astra_document: dict[str, Any]) -> str: + """Return the ID of an encoded document (= a raw JSON read from DB).""" + return astra_document["_id"] + + def get_similarity(self, astra_document: dict[str, Any]) -> float: + """Return the similarity of an encoded document (= a raw JSON read from DB). + + This method assumes its argument comes from a suitable vector search. + """ + return astra_document["$similarity"] + + def encode_query( + self, + *, + ids: Iterable[str] | None = None, + filter_dict: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Prepare an encoded query according to the Astra DB document encoding. + + The method optionally accepts both IDs and metadata filters. The two, + if passed together, are automatically combined with an AND operation. + + In other words, if passing both IDs and a metadata filtering clause, + the resulting query would return Astra DB documents matching the metadata + clause AND having an ID among those provided to this method. If, instead, + an OR is required, one should run two separate queries and subsequently merge + the result (taking care of avoiding duplcates). Args: - filter_id: the ID value to filter on. + ids: an iterable over Document IDs. If provided, the resulting Astra DB + query dictionary expresses the requirement that returning documents + have an ID among those provided here. Passing an empty iterable, + or None, results in a query with no conditions on the IDs at all. + filter_dict: a metadata filtering part. 
+    def encode_query(
+        self,
+        *,
+        ids: Iterable[str] | None = None,
+        filter_dict: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Prepare an encoded query according to the Astra DB document encoding.
+
+        The method optionally accepts both IDs and metadata filters. The two,
+        if passed together, are automatically combined with an AND operation.
+
+        In other words, if passing both IDs and a metadata filtering clause,
+        the resulting query would return Astra DB documents matching the metadata
+        clause AND having an ID among those provided to this method. If, instead,
+        an OR is required, one should run two separate queries and subsequently
+        merge the results (taking care to avoid duplicates).

         Args:
-            filter_id: the ID value to filter on.
+            ids: an iterable over Document IDs. If provided, the resulting Astra DB
+                query dictionary expresses the requirement that returned documents
+                have an ID among those provided here. Passing an empty iterable,
+                or None, results in a query with no conditions on the IDs at all.
+            filter_dict: a metadata filtering clause. If provided, it must refer to
+                metadata keys by their bare name (such as `{"key": 123}`).
+                This filter can combine nested conditions with "$or"/"$and"
+                connectors, for example:
+                - `{"tag": "a"}`
+                - `{"$or": [{"tag": "a"}, {"label": "b"}]}`
+                - `{"$and": [{"tag": {"$in": ["a", "z"]}}, {"label": "b"}]}`

         Returns:
-            an filter clause for use in Astra DB's find queries.
+            a query dictionary ready to be used in an Astra DB find operation on
+            a collection.
         """
+        clauses: list[dict[str, Any]] = []
+        _ids_list = list(ids or [])
+        if _ids_list:
+            clauses.append(_astra_generic_encode_ids(_ids_list))
+        if filter_dict:
+            clauses.append(self.encode_filter(filter_dict))
+
+        if clauses:
+            if len(clauses) > 1:
+                return {"$and": clauses}
+            return clauses[0]
+        return {}

-    @abstractmethod
     def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
         """Encode a vector as a sort to use for Astra DB queries.
@@ -152,6 +253,7 @@ def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
         Returns:
             an order clause for use in Astra DB's find queries.
         """
+        return _astra_generic_encode_vector_sort(vector)


 class _DefaultVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
@@ -226,12 +328,12 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
         return _default_encode_filter(filter_dict)

     @override
-    def encode_id(self, filter_id: str) -> dict[str, Any]:
-        return _default_encode_id(filter_id)
+    def metadata_key_to_field_identifier(self, md_key: str) -> str:
+        return _default_metadata_key_to_field_identifier(md_key)

-    @override
-    def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
-        return _default_encode_vector_sort(vector)
+    @property
+    def default_collection_indexing_policy(self) -> dict[str, list[str]]:
+        return STANDARD_INDEXING_OPTIONS_DEFAULT


 class _DefaultVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
@@ -308,13 +410,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
     def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
         return _default_encode_filter(filter_dict)

-    @override
-    def encode_id(self, filter_id: str) -> dict[str, Any]:
-        return _default_encode_id(filter_id)
+    @property
+    def default_collection_indexing_policy(self) -> dict[str, list[str]]:
+        return STANDARD_INDEXING_OPTIONS_DEFAULT

     @override
-    def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
-        return _default_encode_vector_sort(vector)
+    def metadata_key_to_field_identifier(self, md_key: str) -> str:
+        return _default_metadata_key_to_field_identifier(md_key)


 class _FlatVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
@@ -396,13 +498,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
     def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
         return filter_dict

-    @override
-    def encode_id(self, filter_id: str) -> dict[str, Any]:
-        return _default_encode_id(filter_id)
+    @property
+    def default_collection_indexing_policy(self) -> dict[str, list[str]]:
+        return {"deny": [self.content_field]}

     @override
-    def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
-        return _default_encode_vector_sort(vector)
+    def metadata_key_to_field_identifier(self, md_key: str) -> str:
+        return _flat_metadata_key_to_field_identifier(md_key)


 class _FlatVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
@@ -477,10 +579,11 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
     def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
         return filter_dict
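To see how `encode_query` combines the two clause kinds across encodings, a sketch assuming `default_codec` and `flat_codec` are instances of the two codec families (the outputs match the unit tests added below):

default_codec.encode_query(ids=["id1"], filter_dict={"tag": "a"})
# -> {"$and": [{"_id": "id1"}, {"metadata.tag": "a"}]}
flat_codec.encode_query(ids=["id1", "id2"], filter_dict={"tag": "a"})
# -> {"$and": [{"_id": {"$in": ["id1", "id2"]}}, {"tag": "a"}]}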

-    @override
-    def encode_id(self, filter_id: str) -> dict[str, Any]:
-        return _default_encode_id(filter_id)
+    @property
+    def default_collection_indexing_policy(self) -> dict[str, list[str]]:
+        # $vectorize cannot be de-indexed explicitly (the API manages it entirely).
+        return {}

     @override
-    def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
-        return _default_encode_vector_sort(vector)
+    def metadata_key_to_field_identifier(self, md_key: str) -> str:
+        return _flat_metadata_key_to_field_identifier(md_key)
diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py
index 5d0300b..a34b237 100644
--- a/libs/astradb/langchain_astradb/vectorstores.py
+++ b/libs/astradb/langchain_astradb/vectorstores.py
@@ -8,7 +8,6 @@
 import uuid
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -64,8 +63,6 @@
 U = TypeVar("U")
 DocDict = Dict[str, Any]  # dicts expressing entries to insert

-# indexing options when creating a collection
-DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]}
 # error code to check for during bulk insertions
 DOCUMENT_ALREADY_EXISTS_API_ERROR_CODE = "DOCUMENT_ALREADY_EXISTS"
 # max number of errors shown in full insertion error messages
@@ -363,6 +360,7 @@ def _normalize_metadata_indexing_policy(
     metadata_indexing_include: Iterable[str] | None,
     metadata_indexing_exclude: Iterable[str] | None,
     collection_indexing_policy: dict[str, Any] | None,
+    document_codec: _AstraDBVectorStoreDocumentCodec,
 ) -> dict[str, Any]:
     """Normalize the constructor indexing parameters.
@@ -385,19 +383,21 @@
     if metadata_indexing_include is not None:
         return {
             "allow": [
-                f"metadata.{md_field}" for md_field in metadata_indexing_include
+                document_codec.metadata_key_to_field_identifier(md_field)
+                for md_field in metadata_indexing_include
             ]
         }
     if metadata_indexing_exclude is not None:
         return {
             "deny": [
-                f"metadata.{md_field}" for md_field in metadata_indexing_exclude
+                document_codec.metadata_key_to_field_identifier(md_field)
+                for md_field in metadata_indexing_exclude
             ]
         }
     return (
         collection_indexing_policy
         if collection_indexing_policy is not None
-        else DEFAULT_INDEXING_OPTIONS
+        else document_codec.default_collection_indexing_policy
     )

 def __init__(
@@ -621,6 +621,7 @@ def __init__(
                 metadata_indexing_include=metadata_indexing_include,
                 metadata_indexing_exclude=metadata_indexing_exclude,
                 collection_indexing_policy=collection_indexing_policy,
+                document_codec=self.document_codec,
             )
         else:
             logger.info(
@@ -654,11 +655,6 @@ def __init__(
                 msg = f"Collection '{self.collection_name}' not found."
                 raise ValueError(msg)
             # use the collection info to set the store properties
-            self.indexing_policy = self._normalize_metadata_indexing_policy(
-                metadata_indexing_include=None,
-                metadata_indexing_exclude=None,
-                collection_indexing_policy=c_descriptor.options.indexing,
-            )
             if c_descriptor.options.vector is None:
                 msg = "Non-vector collection detected."
                 raise ValueError(msg)
@@ -677,6 +673,12 @@ def __init__(
                 ignore_invalid_documents=ignore_invalid_documents,
                 norm_content_field=norm_content_field,
             )
+            self.indexing_policy = self._normalize_metadata_indexing_policy(
+                metadata_indexing_include=None,
+                metadata_indexing_exclude=None,
+                collection_indexing_policy=c_descriptor.options.indexing,
+                document_codec=self.document_codec,
+            )
         # validate embedding/vectorize compatibility and such.
         # Embedding and the server-side embeddings are mutually exclusive,
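With this change, the normalization resolves in three steps: explicit include/exclude arguments win, then a caller-supplied policy, then the codec's default. A sketch of the final fallback, where `flat_codec` is a hypothetical `_FlatVSDocumentCodec(content_field="content_x", ignore_invalid_documents=False)`:

policy = AstraDBVectorStore._normalize_metadata_indexing_policy(
    metadata_indexing_include=None,
    metadata_indexing_exclude=None,
    collection_indexing_policy=None,
    document_codec=flat_codec,  # hypothetical codec instance
)
# -> {"deny": ["content_x"]}, i.e. the flat codec's default_collection_indexing_policy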
@@ -708,7 +710,9 @@ def __init__(
             embedding_dimension=_embedding_dimension,
             metric=self.metric,
             requested_indexing_policy=self.indexing_policy,
-            default_indexing_policy=DEFAULT_INDEXING_OPTIONS,
+            default_indexing_policy=(
+                self.document_codec.default_collection_indexing_policy
+            ),
             collection_vector_service_options=self.collection_vector_service_options,
             collection_embedding_api_key=self.collection_embedding_api_key,
             ext_callers=ext_callers,
@@ -864,7 +868,9 @@ def delete_by_document_id(self, document_id: str) -> bool:
         """
         self.astra_env.ensure_db_setup()
         # self.collection is not None (by _ensure_astra_db_client)
-        deletion_response = self.astra_env.collection.delete_one({"_id": document_id})
+        deletion_response = self.astra_env.collection.delete_one(
+            self.document_codec.encode_query(ids=[document_id]),
+        )
         return deletion_response.deleted_count == 1

     async def adelete_by_document_id(self, document_id: str) -> bool:
@@ -878,7 +884,7 @@ async def adelete_by_document_id(self, document_id: str) -> bool:
         """
         await self.astra_env.aensure_db_setup()
         deletion_response = await self.astra_env.async_collection.delete_one(
-            {"_id": document_id},
+            self.document_codec.encode_query(ids=[document_id]),
         )
         return deletion_response.deleted_count == 1
@@ -1066,38 +1072,9 @@ def _get_documents_to_insert(
         # make unique by id, keeping the last
         return _unique_list(
             documents_to_insert[::-1],
-            itemgetter("_id"),
+            self.document_codec.get_id,
         )[::-1]

-    @staticmethod
-    def _get_missing_from_batch(
-        document_batch: list[DocDict], insert_result: dict[str, Any]
-    ) -> tuple[list[str], list[DocDict]]:
-        if "status" not in insert_result:
-            msg = f"API Exception while running bulk insertion: {insert_result}"
-            raise AstraDBVectorStoreError(msg)
-        batch_inserted = insert_result["status"]["insertedIds"]
-        # estimation of the preexisting documents that failed
-        missed_inserted_ids = {document["_id"] for document in document_batch} - set(
-            batch_inserted
-        )
-        errors = insert_result.get("errors", [])
-        # careful for other sources of error other than "doc already exists"
-        num_errors = len(errors)
-        unexpected_errors = any(
-            error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
-        )
-        if num_errors != len(missed_inserted_ids) or unexpected_errors:
-            msg = f"API Exception while running bulk insertion: {errors}"
-            raise AstraDBVectorStoreError(msg)
-        # deal with the missing insertions as upserts
-        missing_from_batch = [
-            document
-            for document in document_batch
-            if document["_id"] in missed_inserted_ids
-        ]
-        return batch_inserted, missing_from_batch
-
     @override
     def add_texts(
         self,
@@ -1159,7 +1136,7 @@ def add_texts(
         )

         # perform an AstraPy insert_many, catching exceptions for overwriting docs
-        ids_to_replace: list[int]
+        ids_to_replace: list[str]
         inserted_ids: list[str] = []
         try:
             insert_many_result = self.astra_env.collection.insert_many(
@@ -1177,9 +1154,10 @@
                 inserted_ids = err.partial_result.inserted_ids
                 inserted_ids_set = set(inserted_ids)
                 ids_to_replace = [
-                    document["_id"]
+                    doc_id
                     for document in documents_to_insert
-                    if document["_id"] not in inserted_ids_set
+                    if (doc_id := self.document_codec.get_id(document))
+                    not in inserted_ids_set
                 ]
             else:
                 full_err_message = _insertmany_error_message(err)
@@ -1190,7 +1168,7 @@
         documents_to_replace = [
             document
             for document in documents_to_insert
-            if document["_id"] in ids_to_replace
+            if self.document_codec.get_id(document) in ids_to_replace
         ]

         _max_workers = (
@@ -1203,10 +1181,11 @@ def add_texts(
             def _replace_document(
                 document: dict[str, Any],
             ) -> tuple[UpdateResult, str]:
+                doc_id = self.document_codec.get_id(document)
                 return self.astra_env.collection.replace_one(
-                    {"_id": document["_id"]},
+                    self.document_codec.encode_query(ids=[doc_id]),
                     document,
-                ), document["_id"]
+                ), doc_id

             replace_results = list(
                 executor.map(
@@ -1289,7 +1268,7 @@ async def aadd_texts(
         )

         # perform an AstraPy insert_many, catching exceptions for overwriting docs
-        ids_to_replace: list[int]
+        ids_to_replace: list[str]
         inserted_ids: list[str] = []
         try:
             insert_many_result = await self.astra_env.async_collection.insert_many(
@@ -1307,9 +1286,10 @@
                 inserted_ids = err.partial_result.inserted_ids
                 inserted_ids_set = set(inserted_ids)
                 ids_to_replace = [
-                    document["_id"]
+                    doc_id
                     for document in documents_to_insert
-                    if document["_id"] not in inserted_ids_set
+                    if (doc_id := self.document_codec.get_id(document))
+                    not in inserted_ids_set
                 ]
             else:
                 full_err_message = _insertmany_error_message(err)
@@ -1320,7 +1300,7 @@
         documents_to_replace = [
             document
             for document in documents_to_insert
-            if document["_id"] in ids_to_replace
+            if self.document_codec.get_id(document) in ids_to_replace
         ]

         sem = asyncio.Semaphore(
@@ -1333,10 +1313,11 @@ async def _replace_document(
             document: dict[str, Any],
         ) -> tuple[UpdateResult, str]:
             async with sem:
+                doc_id = self.document_codec.get_id(document)
                 return await _async_collection.replace_one(
-                    {"_id": document["_id"]},
+                    self.document_codec.encode_query(ids=[doc_id]),
                     document,
-                ), document["_id"]
+                ), doc_id

         tasks = [
             asyncio.create_task(_replace_document(document))
@@ -1395,7 +1376,7 @@ def _update_document(
             document_id, update_metadata = id_md_pair
             encoded_metadata = self.filter_to_query(update_metadata)
             return self.astra_env.collection.update_one(
-                {"_id": document_id},
+                self.document_codec.encode_query(ids=[document_id]),
                 {"$set": encoded_metadata},
             )
@@ -1448,7 +1429,7 @@ async def _update_document(
             encoded_metadata = self.filter_to_query(update_metadata)
             async with sem:
                 return await _async_collection.update_one(
-                    {"_id": document_id},
+                    self.document_codec.encode_query(ids=[document_id]),
                     {"$set": encoded_metadata},
                 )
@@ -1520,7 +1501,7 @@ def get_by_document_id(self, document_id: str) -> Document | None:
         self.astra_env.ensure_db_setup()
         # self.collection is not None (by _ensure_astra_db_client)
         hit = self.astra_env.collection.find_one(
-            {"_id": document_id},
+            self.document_codec.encode_query(ids=[document_id]),
             projection=self.document_codec.base_projection,
         )
         if hit is None:
@@ -1539,7 +1520,7 @@ async def aget_by_document_id(self, document_id: str) -> Document | None:
         await self.astra_env.aensure_db_setup()
         # self.collection is not None (by _ensure_astra_db_client)
         hit = await self.astra_env.async_collection.find_one(
-            {"_id": document_id},
+            self.document_codec.encode_query(ids=[document_id]),
             projection=self.document_codec.base_projection,
         )
         if hit is None:
@@ -1734,8 +1715,8 @@ def _similarity_search_with_score_id_by_sort(
             for (doc, sim, did) in (
                 (
                     self.document_codec.decode(hit),
-                    hit["$similarity"],
-                    hit["_id"],
+                    self.document_codec.get_similarity(hit),
+                    self.document_codec.get_id(hit),
                 )
                 for hit in hits_ite
             )
@@ -2113,8 +2094,8 @@ async def _asimilarity_search_with_score_id_by_sort(
             async for (doc, sim, did) in (
                 (
                     self.document_codec.decode(hit),
-                    hit["$similarity"],
-                    hit["_id"],
+                    self.document_codec.get_similarity(hit),
+                    self.document_codec.get_id(hit),
                 )
                 async for hit in self.astra_env.async_collection.find(
                     filter=metadata_parameter,
diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py
index c65763e..e871ae8 100644
--- a/libs/astradb/tests/integration_tests/conftest.py
+++ b/libs/astradb/tests/integration_tests/conftest.py
@@ -32,7 +32,10 @@
 from astrapy.info import CollectionVectorServiceOptions

 from langchain_astradb.utils.astradb import SetupMode
-from langchain_astradb.vectorstores import DEFAULT_INDEXING_OPTIONS, AstraDBVectorStore
+from langchain_astradb.utils.vector_store_codecs import (
+    STANDARD_INDEXING_OPTIONS_DEFAULT,
+)
+from langchain_astradb.vectorstores import AstraDBVectorStore
 from tests.conftest import IdentityLLM, ParserEmbeddings

 if TYPE_CHECKING:
@@ -207,7 +210,7 @@ def collection_d2(
         COLLECTION_NAME_D2,
         dimension=2,
         check_exists=False,
-        indexing=DEFAULT_INDEXING_OPTIONS,
+        indexing=STANDARD_INDEXING_OPTIONS_DEFAULT,
         metric="euclidean",
     )
     yield collection
@@ -400,7 +403,7 @@ def collection_vz(
         COLLECTION_NAME_VZ,
         dimension=1536,
         check_exists=False,
-        indexing=DEFAULT_INDEXING_OPTIONS,
+        indexing=STANDARD_INDEXING_OPTIONS_DEFAULT,
         metric="euclidean",
         service=OPENAI_VECTORIZE_OPTIONS_HEADER,
         embedding_api_key=openai_api_key,
diff --git a/libs/astradb/tests/unit_tests/test_vectorstores.py b/libs/astradb/tests/unit_tests/test_vectorstores.py
index 51a5ac2..1bb85a5 100644
--- a/libs/astradb/tests/unit_tests/test_vectorstores.py
+++ b/libs/astradb/tests/unit_tests/test_vectorstores.py
@@ -3,10 +3,11 @@
 from astrapy.info import CollectionVectorServiceOptions

 from langchain_astradb.utils.astradb import SetupMode
-from langchain_astradb.vectorstores import (
-    DEFAULT_INDEXING_OPTIONS,
-    AstraDBVectorStore,
+from langchain_astradb.utils.vector_store_codecs import (
+    _DefaultVSDocumentCodec,
+    _FlatVSDocumentCodec,
 )
+from langchain_astradb.vectorstores import AstraDBVectorStore
 from tests.conftest import ParserEmbeddings

 FAKE_TOKEN = "t"  # noqa: S105
@@ -86,27 +87,62 @@ def test_initialization(self) -> None:
         )

     def test_astradb_vectorstore_unit_indexing_normalization(self) -> None:
-        """Unit test of the indexing policy normalization"""
+        """Unit test of the indexing policy normalization.
+
+        We use just a couple of codecs to check the indexing policy fallbacks.
+ """ + + the_f_codec = _FlatVSDocumentCodec( + content_field="content_x", + ignore_invalid_documents=False, + ) + the_f_default_policy = the_f_codec.default_collection_indexing_policy + the_d_codec = _DefaultVSDocumentCodec( + content_field="content_y", + ignore_invalid_documents=False, + ) + + # default (non-flat): hardcoding expected indexing from including + al_d_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( + metadata_indexing_include=["a1", "a2"], + metadata_indexing_exclude=None, + collection_indexing_policy=None, + document_codec=the_d_codec, + ) + assert al_d_idx == {"allow": ["metadata.a1", "metadata.a2"]} + + # default (non-flat): hardcoding expected indexing from excluding + dl_d_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( + metadata_indexing_include=None, + metadata_indexing_exclude=["d1", "d2"], + collection_indexing_policy=None, + document_codec=the_d_codec, + ) + assert dl_d_idx == {"deny": ["metadata.d1", "metadata.d2"]} + n3_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=None, metadata_indexing_exclude=None, collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert n3_idx == DEFAULT_INDEXING_OPTIONS + assert n3_idx == the_f_default_policy al_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=["a1", "a2"], metadata_indexing_exclude=None, collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert al_idx == {"allow": ["metadata.a1", "metadata.a2"]} + assert al_idx == {"allow": ["a1", "a2"]} dl_idx = AstraDBVectorStore._normalize_metadata_indexing_policy( metadata_indexing_include=None, metadata_indexing_exclude=["d1", "d2"], collection_indexing_policy=None, + document_codec=the_f_codec, ) - assert dl_idx == {"deny": ["metadata.d1", "metadata.d2"]} + assert dl_idx == {"deny": ["d1", "d2"]} custom_policy = { "deny": ["myfield", "other_field.subfield", "metadata.long_text"] @@ -115,6 +151,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=None, metadata_indexing_exclude=None, collection_indexing_policy=custom_policy, + document_codec=the_f_codec, ) assert cip_idx == custom_policy @@ -129,6 +166,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=["b"], collection_indexing_policy=None, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -136,6 +174,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=None, collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -143,6 +182,7 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=None, metadata_indexing_exclude=["b"], collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) with pytest.raises(ValueError, match=error_msg): @@ -150,4 +190,5 @@ def test_astradb_vectorstore_unit_indexing_normalization(self) -> None: metadata_indexing_include=["a"], metadata_indexing_exclude=["b"], collection_indexing_policy={"a": "z"}, + document_codec=the_f_codec, ) diff --git a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py index 4e5d8c5..2df3022 100644 --- a/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py +++ b/libs/astradb/tests/unit_tests/test_vs_doc_codecs.py @@ -9,6 +9,7 
 from langchain_astradb.utils.vector_store_codecs import (
     NO_NULL_VECTOR_MSG,
     VECTOR_REQUIRED_PREAMBLE_MSG,
+    _AstraDBVectorStoreDocumentCodec,
     _DefaultVectorizeVSDocumentCodec,
     _DefaultVSDocumentCodec,
     _FlatVectorizeVSDocumentCodec,
@@ -21,7 +22,6 @@
 DOCUMENT_ID = "the_id"
 LC_DOCUMENT = Document(id=DOCUMENT_ID, page_content=CONTENT, metadata=METADATA)
 LC_FILTER = {"a0": 0, "$or": [{"b1": 1}, {"b2": 2}]}
-ID_FILTER = {"_id": DOCUMENT_ID}
 VECTOR_SORT = {"$vector": VECTOR}

 ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE = {
@@ -136,12 +136,86 @@ def test_default_novectorize_vector_decoding(self) -> None:
         assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_NOVECTORIZE) == VECTOR
         assert codec.decode_vector({}) is None

-    def test_default_novectorize_id_encoding(self) -> None:
-        """Test id-encoding for default, no-vectorize."""
-        codec = _DefaultVSDocumentCodec(
-            content_field="content_x", ignore_invalid_documents=False
-        )
-        assert codec.encode_id(DOCUMENT_ID) == ID_FILTER
+    @pytest.mark.parametrize(
+        ("default_codec_class", "codec_kwargs"),
+        [
+            (
+                _DefaultVSDocumentCodec,
+                {"content_field": "cf", "ignore_invalid_documents": False},
+            ),
+            (_DefaultVectorizeVSDocumentCodec, {"ignore_invalid_documents": False}),
+        ],
+        ids=("default", "default_vectorize"),
+    )
+    def test_default_query_encoding(
+        self,
+        default_codec_class: type[_AstraDBVectorStoreDocumentCodec],
+        codec_kwargs: dict[str, Any],
+    ) -> None:
+        """Test query-encoding for the default and default-vectorize codecs."""
+        codec = default_codec_class(**codec_kwargs)
+        assert codec.encode_query() == {}
+        assert codec.encode_query(ids=[DOCUMENT_ID]) == {"_id": DOCUMENT_ID}
+        assert codec.encode_query(ids=["id1", "id2"]) == {
+            "_id": {"$in": ["id1", "id2"]}
+        }
+        assert codec.encode_query(ids=[DOCUMENT_ID], filter_dict={"mdk": "mdv"}) == {
+            "$and": [{"_id": DOCUMENT_ID}, {"metadata.mdk": "mdv"}]
+        }
+        assert codec.encode_query(
+            ids=[DOCUMENT_ID], filter_dict={"mdk1": "mdv1", "mdk2": "mdv2"}
+        ) == {
+            "$and": [
+                {"_id": DOCUMENT_ID},
+                {"metadata.mdk1": "mdv1", "metadata.mdk2": "mdv2"},
+            ]
+        }
+        assert codec.encode_query(
+            ids=[DOCUMENT_ID], filter_dict={"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]}
+        ) == {
+            "$and": [
+                {"_id": DOCUMENT_ID},
+                {"$or": [{"metadata.mdk1": "mdv1"}, {"metadata.mdk2": "mdv2"}]},
+            ]
+        }
+
+    @pytest.mark.parametrize(
+        ("default_codec_class", "codec_kwargs"),
+        [
+            (
+                _FlatVSDocumentCodec,
+                {"content_field": "cf", "ignore_invalid_documents": False},
+            ),
+            (_FlatVectorizeVSDocumentCodec, {"ignore_invalid_documents": False}),
+        ],
+        ids=("flat", "flat_vectorize"),
+    )
+    def test_flat_query_encoding(
+        self,
+        default_codec_class: type[_AstraDBVectorStoreDocumentCodec],
+        codec_kwargs: dict[str, Any],
+    ) -> None:
+        """Test query-encoding for the flat and flat-vectorize codecs."""
+        codec = default_codec_class(**codec_kwargs)
+        assert codec.encode_query() == {}
+        assert codec.encode_query(ids=[DOCUMENT_ID]) == {"_id": DOCUMENT_ID}
+        assert codec.encode_query(ids=["id1", "id2"]) == {
+            "_id": {"$in": ["id1", "id2"]}
+        }
+        assert codec.encode_query(ids=[DOCUMENT_ID], filter_dict={"mdk": "mdv"}) == {
+            "$and": [{"_id": DOCUMENT_ID}, {"mdk": "mdv"}]
+        }
+        assert codec.encode_query(
+            ids=[DOCUMENT_ID], filter_dict={"mdk1": "mdv1", "mdk2": "mdv2"}
+        ) == {"$and": [{"_id": DOCUMENT_ID}, {"mdk1": "mdv1", "mdk2": "mdv2"}]}
+        assert codec.encode_query(
+            ids=[DOCUMENT_ID], filter_dict={"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]}
+        ) == {
+            "$and": [
+                {"_id": DOCUMENT_ID},
+                {"$or": [{"mdk1": "mdv1"}, {"mdk2": "mdv2"}]},
+            ]
+        }
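Before the vector-sort tests below, note that the sort encoding is now shared by all codecs via the base class. A one-line sketch (illustrative, `codec` being any codec instance):

codec.encode_vector_sort([0.1, 0.2])  # -> {"$vector": [0.1, 0.2]}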
    def test_default_novectorize_vectorsort_encoding(self) -> None:
        """Test vector-sort-encoding for default, no-vectorize."""
@@ -206,11 +280,6 @@ def test_default_vectorize_vector_decoding(self) -> None:
         assert codec.decode_vector(ASTRA_DEFAULT_DOCUMENT_VECTORIZE_READ) == VECTOR
         assert codec.decode_vector({}) is None

-    def test_default_vectorize_id_encoding(self) -> None:
-        """Test id-encoding for default, vectorize."""
-        codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False)
-        assert codec.encode_id(DOCUMENT_ID) == ID_FILTER
-
     def test_default_vectorize_vectorsort_encoding(self) -> None:
         """Test vector-sort-encoding for default, vectorize."""
         codec = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=False)
@@ -286,13 +355,6 @@ def test_flat_novectorize_vector_decoding(self) -> None:
         assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_NOVECTORIZE) == VECTOR
         assert codec.decode_vector({}) is None

-    def test_flat_novectorize_id_encoding(self) -> None:
-        """Test id-encoding for flat, no-vectorize."""
-        codec = _FlatVSDocumentCodec(
-            content_field="content_x", ignore_invalid_documents=False
-        )
-        assert codec.encode_id(DOCUMENT_ID) == ID_FILTER
-
     def test_flat_novectorize_vectorsort_encoding(self) -> None:
         """Test vector-sort-encoding for flat, no-vectorize."""
         codec = _FlatVSDocumentCodec(
@@ -356,12 +418,24 @@ def test_flat_vectorize_vector_decoding(self) -> None:
         assert codec.decode_vector(ASTRA_FLAT_DOCUMENT_VECTORIZE_READ) == VECTOR
         assert codec.decode_vector({}) is None

-    def test_flat_vectorize_id_encoding(self) -> None:
-        """Test id-encoding for flat, vectorize."""
-        codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False)
-        assert codec.encode_id(DOCUMENT_ID) == ID_FILTER
-
     def test_flat_vectorize_vectorsort_encoding(self) -> None:
         """Test vector-sort-encoding for flat, vectorize."""
         codec = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=False)
         assert codec.encode_vector_sort(VECTOR) == VECTOR_SORT
+
+    def test_codec_default_collection_indexing_policy(self) -> None:
+        """Test all codecs give their expected default indexing settings back."""
+        codec_d_n = _DefaultVSDocumentCodec(
+            content_field="content_x",
+            ignore_invalid_documents=False,
+        )
+        assert codec_d_n.default_collection_indexing_policy == {"allow": ["metadata"]}
+        codec_d_v = _DefaultVectorizeVSDocumentCodec(ignore_invalid_documents=True)
+        assert codec_d_v.default_collection_indexing_policy == {"allow": ["metadata"]}
+        codec_f_n = _FlatVSDocumentCodec(
+            content_field="content_x",
+            ignore_invalid_documents=False,
+        )
+        assert codec_f_n.default_collection_indexing_policy == {"deny": ["content_x"]}
+        codec_f_v = _FlatVectorizeVSDocumentCodec(ignore_invalid_documents=True)
+        assert codec_f_v.default_collection_indexing_policy == {}
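Tying the codec surface together, an end-to-end sketch of assumed usage, mirroring the assertions above (illustrative, not part of the patch):

from langchain_astradb.utils.vector_store_codecs import _FlatVSDocumentCodec

# a flat codec: metadata keys live at top level, content in a named field
codec = _FlatVSDocumentCodec(content_field="content_x", ignore_invalid_documents=False)
codec.metadata_key_to_field_identifier("tag")  # -> "tag" (bare key in the flat encoding)
codec.default_collection_indexing_policy       # -> {"deny": ["content_x"]}
codec.encode_query(ids=["id1"], filter_dict={"tag": "a"})
# -> {"$and": [{"_id": "id1"}, {"tag": "a"}]}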