Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 77 additions & 8 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@
ERROR_LEXICAL_QUERY_ON_NONHYBRID_SEARCH = (
"Parameter 'lexical_query' cannot be passed for a non-hybrid search"
)
# Warning message for retrieving scores on a hybrid search
WARNING_HYBRID_SEARCH_WITH_SCORES = (
"Scores returned as part of a hybrid search, which come from the "
"reranking step, may not be deterministically computed solely based "
"on the query and result. Using the scores e.g. for "
"threshold-filtering may lead to unpredictable results and is discouraged."
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -2459,7 +2466,7 @@ def similarity_search(
"""
return [
doc
for (doc, _, _) in self.similarity_search_with_score_id(
for (doc, _, _) in self._similarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
Expand Down Expand Up @@ -2488,24 +2495,26 @@ def similarity_search_with_score(
Returns:
The list of (Document, score), the most similar to the query vector.
"""
if self.hybrid_search:
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
return [
(doc, score)
for (doc, score, _) in self.similarity_search_with_score_id(
for (doc, score, _) in self._similarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
lexical_query=lexical_query,
)
]

def similarity_search_with_score_id(
def _similarity_search_with_score_id_impl(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
lexical_query: str | None = None,
) -> list[tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.
"""Implementation for similarity_search_with_score_id.

Args:
query: Query to look up documents similar to.
Expand Down Expand Up @@ -2561,6 +2570,35 @@ def similarity_search_with_score_id(
filter_dict=filter,
)

def similarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
lexical_query: str | None = None,
) -> list[tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.

Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
lexical_query: for hybrid search, a specific query for the lexical
portion of the retrieval. If omitted or empty, defaults to the same
as 'query'. If passed on a non-hybrid search, an error is raised.

Returns:
The list of (Document, score, id), the most similar to the query.
"""
if self.hybrid_search:
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
return self._similarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
lexical_query=lexical_query,
)

@override
def similarity_search_by_vector(
self,
Expand Down Expand Up @@ -2727,7 +2765,7 @@ async def asimilarity_search(
"""
return [
doc
for (doc, _, _) in await self.asimilarity_search_with_score_id(
for (doc, _, _) in await self._asimilarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
Expand Down Expand Up @@ -2756,24 +2794,26 @@ async def asimilarity_search_with_score(
Returns:
The list of (Document, score), the most similar to the query vector.
"""
if self.hybrid_search:
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
return [
(doc, score)
for (doc, score, _) in await self.asimilarity_search_with_score_id(
for (doc, score, _) in await self._asimilarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
lexical_query=lexical_query,
)
]

async def asimilarity_search_with_score_id(
async def _asimilarity_search_with_score_id_impl(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
lexical_query: str | None = None,
) -> list[tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.
"""Implementation for asimilarity_search_with_score_id.

Args:
query: Query to look up documents similar to.
Expand Down Expand Up @@ -2829,6 +2869,35 @@ async def asimilarity_search_with_score_id(
filter_dict=filter,
)

async def asimilarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | None = None, # noqa: A002
lexical_query: str | None = None,
) -> list[tuple[Document, float, str]]:
"""Return docs most similar to the query with score and id.

Args:
query: Query to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filter on the metadata to apply.
lexical_query: for hybrid search, a specific query for the lexical
portion of the retrieval. If omitted or empty, defaults to the same
as 'query'. If passed on a non-hybrid search, an error is raised.

Returns:
The list of (Document, score, id), the most similar to the query.
"""
if self.hybrid_search:
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
return await self._asimilarity_search_with_score_id_impl(
query=query,
k=k,
filter=filter,
lexical_query=lexical_query,
)

@override
async def asimilarity_search_by_vector(
self,
Expand Down