Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 136 additions & 33 deletions libs/astradb/langchain_astradb/utils/vector_store_codecs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Classes to handle encoding of documents on DB for the Vector Store.."""
"""Classes to handle encoding of Documents on DB for the Vector Store.."""

from __future__ import annotations

import logging
import warnings
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Iterable

from langchain_core.documents import Document
from typing_extensions import override
Expand All @@ -16,14 +16,44 @@
)
FLATTEN_CONFLICT_MSG = "Cannot flatten metadata: field name overlap for '{field}'."

STANDARD_INDEXING_OPTIONS_DEFAULT = {"allow": ["metadata"]}

logger = logging.getLogger(__name__)


def _default_decode_vector(astra_doc: dict[str, Any]) -> list[float] | None:
"""Extract the embedding vector from an Astra DB document."""
return astra_doc.get("$vector")


def _default_metadata_key_to_field_identifier(md_key: str) -> str:
"""Rewrite a metadata key name to its full path in the 'default' encoding.

The input `md_key` is an "abstract" metadata key, while the return value
identifies its actual full-path location on an Astra DB document encoded in the
'default' way (i.e. with a nested `metadata` dictionary).
"""
return f"metadata.{md_key}"


def _flat_metadata_key_to_field_identifier(md_key: str) -> str:
"""Rewrite a metadata key name to its full path in the 'flat' encoding.

The input `md_key` is an "abstract" metadata key, while the return value
identifies its actual full-path location on an Astra DB document encoded in the
'flat' way (i.e. metadata fields appearing at top-level in the Astra DB document).
"""
return md_key


def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
"""Encode an "abstract" metadata condition for the 'default' encoding.

The input can express a query clause on metadata and uses just the metadata field
names, possibly connected/nested through AND and ORs. The output makes key names
into their full path-identifiers (e.g. "metadata.xyz") according to the 'default'
encoding scheme for Astra DB documents.
"""
metadata_filter = {}
for k, v in filter_dict.items():
# Key in this dict starting with $ are supposedly operators and as such
Expand All @@ -37,32 +67,48 @@ def _default_encode_filter(filter_dict: dict[str, Any]) -> dict[str, Any]:
# assume each list item can be fed back to this function
metadata_filter[k] = _default_encode_filter(v) # type: ignore[assignment]
else:
metadata_filter[f"metadata.{k}"] = v
metadata_filter[_default_metadata_key_to_field_identifier(k)] = v

return metadata_filter


def _default_encode_id(filter_id: str) -> dict[str, Any]:
def _astra_generic_encode_id(filter_id: str) -> dict[str, Any]:
"""Encoding of a single Document ID as a query clause for an Astra DB document."""
return {"_id": filter_id}


def _default_encode_vector_sort(vector: list[float]) -> dict[str, Any]:
def _astra_generic_encode_ids(filter_ids: list[str]) -> dict[str, Any]:
"""Encoding of Document IDs as a query clause for an Astra DB document.

This function picks the right, and most concise, expression based on the
multiplicity of the provided IDs.
"""
if len(filter_ids) == 1:
return _astra_generic_encode_id(filter_ids[0])
return {"_id": {"$in": filter_ids}}


def _astra_generic_encode_vector_sort(vector: list[float]) -> dict[str, Any]:
"""Encoding of a vector-based sort as a query clause for an Astra DB document."""
return {"$vector": vector}


class _AstraDBVectorStoreDocumentCodec(ABC):
"""A document codec for the Astra DB vector store.
"""A Document codec for the Astra DB vector store.

The document codec contains the information for consistent interaction
Document codecs hold the logic consistent interaction
with documents as stored on the Astra DB collection.

In this context, 'Document' (capital D) refers to the LangChain class,
while 'Astra DB document' refers to the JSON-like object stored on DB.

Implementations of this class must:
- define how to encode/decode documents consistently to and from
Astra DB collections. The two operations must, so to speak, combine
to the identity on both sides (except for the quirks of their signatures).
- provide the adequate projection dictionaries for running find
operations on Astra DB, with and without the field containing the vector.
- encode IDs to the `_id` field on Astra DB.
- encode Document IDs to the right field on Astra DB ("_id" for Collections).
- define the name of the field storing the textual content of the Document.
- define whether embeddings are computed server-side (with $vectorize) or not.
"""
Expand Down Expand Up @@ -132,17 +178,72 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
"""

@abstractmethod
def encode_id(self, filter_id: str) -> dict[str, Any]:
"""Encode an ID as a filter for use in Astra DB queries.
def metadata_key_to_field_identifier(self, md_key: str) -> str:
"""Express an 'abstract' metadata key as a full Data API field identifier."""

@property
@abstractmethod
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
"""Provide the default indexing policy if the collection must be created."""

def get_id(self, astra_document: dict[str, Any]) -> str:
"""Return the ID of an encoded document (= a raw JSON read from DB)."""
return astra_document["_id"]

def get_similarity(self, astra_document: dict[str, Any]) -> float:
"""Return the similarity of an encoded document (= a raw JSON read from DB).

This method assumes its argument comes from a suitable vector search.
"""
return astra_document["$similarity"]

def encode_query(
self,
*,
ids: Iterable[str] | None = None,
filter_dict: dict[str, Any] | None = None,
) -> dict[str, Any]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May merit pydoc indicating the implicit $and.

"""Prepare an encoded query according to the Astra DB document encoding.

The method optionally accepts both IDs and metadata filters. The two,
if passed together, are automatically combined with an AND operation.

In other words, if passing both IDs and a metadata filtering clause,
the resulting query would return Astra DB documents matching the metadata
clause AND having an ID among those provided to this method. If, instead,
an OR is required, one should run two separate queries and subsequently merge
the result (taking care of avoiding duplcates).

Args:
filter_id: the ID value to filter on.
ids: an iterable over Document IDs. If provided, the resulting Astra DB
query dictionary expresses the requirement that returning documents
have an ID among those provided here. Passing an empty iterable,
or None, results in a query with no conditions on the IDs at all.
filter_dict: a metadata filtering part. If provided, if must refer to
metadata keys by their bare name (such as `{"key": 123}`).
This filter can combine nested conditions with "$or"/"$and" connectors,
for example:
- `{"tag": "a"}`
- `{"$or": [{"tag": "a"}, "label": "b"]}`
- `{"$and": [{"tag": {"$in": ["a", "z"]}}, "label": "b"]}`

Returns:
an filter clause for use in Astra DB's find queries.
a query dictionary ready to be used in an Astra DB find operation on
a collection.
"""
clauses: list[dict[str, Any]] = []
_ids_list = list(ids or [])
if _ids_list:
clauses.append(_astra_generic_encode_ids(_ids_list))
if filter_dict:
clauses.append(self.encode_filter(filter_dict))

if clauses:
if len(clauses) > 1:
return {"$and": clauses}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed this makes sense. In general, my rationale is:

  1. If you want ids OR filter, then run separate queries -- the IDs only query is likely to be fast and the filter-only query is likely to do a scan.
  2. If you want ids AND filter then there is no option beyond running them together (unless you somehow emulate the full filtering semantics on the client side).

So, it seems like the AND is the only reasonable choice.

return clauses[0]
return {}

@abstractmethod
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
"""Encode a vector as a sort to use for Astra DB queries.

Expand All @@ -152,6 +253,7 @@ def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
Returns:
an order clause for use in Astra DB's find queries.
"""
return _astra_generic_encode_vector_sort(vector)


class _DefaultVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -226,12 +328,12 @@ def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return _default_encode_filter(filter_dict)

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _default_metadata_key_to_field_identifier(md_key)

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return STANDARD_INDEXING_OPTIONS_DEFAULT


class _DefaultVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -308,13 +410,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return _default_encode_filter(filter_dict)

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return STANDARD_INDEXING_OPTIONS_DEFAULT

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _default_metadata_key_to_field_identifier(md_key)


class _FlatVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -396,13 +498,13 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return filter_dict

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
return {"deny": [self.content_field]}

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _flat_metadata_key_to_field_identifier(md_key)


class _FlatVectorizeVSDocumentCodec(_AstraDBVectorStoreDocumentCodec):
Expand Down Expand Up @@ -477,10 +579,11 @@ def decode_vector(self, astra_document: dict[str, Any]) -> list[float] | None:
def encode_filter(self, filter_dict: dict[str, Any]) -> dict[str, Any]:
return filter_dict

@override
def encode_id(self, filter_id: str) -> dict[str, Any]:
return _default_encode_id(filter_id)
@property
def default_collection_indexing_policy(self) -> dict[str, list[str]]:
# $vectorize cannot be de-indexed explicitly (the API manages it entirely).
return {}

@override
def encode_vector_sort(self, vector: list[float]) -> dict[str, Any]:
return _default_encode_vector_sort(vector)
def metadata_key_to_field_identifier(self, md_key: str) -> str:
return _flat_metadata_key_to_field_identifier(md_key)
Loading