Skip to content

Commit 944422c

Browse files
authored
refactor similarity searches with score and add warning if used from outside (#129)
1 parent e8c9482 commit 944422c

File tree

1 file changed

+77
-8
lines changed

1 file changed

+77
-8
lines changed

libs/astradb/langchain_astradb/vectorstores.py

Lines changed: 77 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,13 @@
8787
ERROR_LEXICAL_QUERY_ON_NONHYBRID_SEARCH = (
8888
"Parameter 'lexical_query' cannot be passed for a non-hybrid search"
8989
)
90+
# Warning message for retrieving scores on a hybrid search
91+
WARNING_HYBRID_SEARCH_WITH_SCORES = (
92+
"Scores returned as part of a hybrid search, which come from the "
93+
"reranking step, may not be deterministically computed solely based "
94+
"on the query and result. Using the scores e.g. for "
95+
"threshold-filtering may lead to unpredictable results and is discouraged."
96+
)
9097

9198
logger = logging.getLogger(__name__)
9299

@@ -2459,7 +2466,7 @@ def similarity_search(
24592466
"""
24602467
return [
24612468
doc
2462-
for (doc, _, _) in self.similarity_search_with_score_id(
2469+
for (doc, _, _) in self._similarity_search_with_score_id_impl(
24632470
query=query,
24642471
k=k,
24652472
filter=filter,
@@ -2488,24 +2495,26 @@ def similarity_search_with_score(
24882495
Returns:
24892496
The list of (Document, score), the most similar to the query vector.
24902497
"""
2498+
if self.hybrid_search:
2499+
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
24912500
return [
24922501
(doc, score)
2493-
for (doc, score, _) in self.similarity_search_with_score_id(
2502+
for (doc, score, _) in self._similarity_search_with_score_id_impl(
24942503
query=query,
24952504
k=k,
24962505
filter=filter,
24972506
lexical_query=lexical_query,
24982507
)
24992508
]
25002509

2501-
def similarity_search_with_score_id(
2510+
def _similarity_search_with_score_id_impl(
25022511
self,
25032512
query: str,
25042513
k: int = 4,
25052514
filter: dict[str, Any] | None = None, # noqa: A002
25062515
lexical_query: str | None = None,
25072516
) -> list[tuple[Document, float, str]]:
2508-
"""Return docs most similar to the query with score and id.
2517+
"""Implementation for similarity_search_with_score_id.
25092518
25102519
Args:
25112520
query: Query to look up documents similar to.
@@ -2561,6 +2570,35 @@ def similarity_search_with_score_id(
25612570
filter_dict=filter,
25622571
)
25632572

2573+
def similarity_search_with_score_id(
2574+
self,
2575+
query: str,
2576+
k: int = 4,
2577+
filter: dict[str, Any] | None = None, # noqa: A002
2578+
lexical_query: str | None = None,
2579+
) -> list[tuple[Document, float, str]]:
2580+
"""Return docs most similar to the query with score and id.
2581+
2582+
Args:
2583+
query: Query to look up documents similar to.
2584+
k: Number of Documents to return. Defaults to 4.
2585+
filter: Filter on the metadata to apply.
2586+
lexical_query: for hybrid search, a specific query for the lexical
2587+
portion of the retrieval. If omitted or empty, defaults to the same
2588+
as 'query'. If passed on a non-hybrid search, an error is raised.
2589+
2590+
Returns:
2591+
The list of (Document, score, id), the most similar to the query.
2592+
"""
2593+
if self.hybrid_search:
2594+
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
2595+
return self._similarity_search_with_score_id_impl(
2596+
query=query,
2597+
k=k,
2598+
filter=filter,
2599+
lexical_query=lexical_query,
2600+
)
2601+
25642602
@override
25652603
def similarity_search_by_vector(
25662604
self,
@@ -2727,7 +2765,7 @@ async def asimilarity_search(
27272765
"""
27282766
return [
27292767
doc
2730-
for (doc, _, _) in await self.asimilarity_search_with_score_id(
2768+
for (doc, _, _) in await self._asimilarity_search_with_score_id_impl(
27312769
query=query,
27322770
k=k,
27332771
filter=filter,
@@ -2756,24 +2794,26 @@ async def asimilarity_search_with_score(
27562794
Returns:
27572795
The list of (Document, score), the most similar to the query vector.
27582796
"""
2797+
if self.hybrid_search:
2798+
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
27592799
return [
27602800
(doc, score)
2761-
for (doc, score, _) in await self.asimilarity_search_with_score_id(
2801+
for (doc, score, _) in await self._asimilarity_search_with_score_id_impl(
27622802
query=query,
27632803
k=k,
27642804
filter=filter,
27652805
lexical_query=lexical_query,
27662806
)
27672807
]
27682808

2769-
async def asimilarity_search_with_score_id(
2809+
async def _asimilarity_search_with_score_id_impl(
27702810
self,
27712811
query: str,
27722812
k: int = 4,
27732813
filter: dict[str, Any] | None = None, # noqa: A002
27742814
lexical_query: str | None = None,
27752815
) -> list[tuple[Document, float, str]]:
2776-
"""Return docs most similar to the query with score and id.
2816+
"""Implementation for asimilarity_search_with_score_id.
27772817
27782818
Args:
27792819
query: Query to look up documents similar to.
@@ -2829,6 +2869,35 @@ async def asimilarity_search_with_score_id(
28292869
filter_dict=filter,
28302870
)
28312871

2872+
async def asimilarity_search_with_score_id(
2873+
self,
2874+
query: str,
2875+
k: int = 4,
2876+
filter: dict[str, Any] | None = None, # noqa: A002
2877+
lexical_query: str | None = None,
2878+
) -> list[tuple[Document, float, str]]:
2879+
"""Return docs most similar to the query with score and id.
2880+
2881+
Args:
2882+
query: Query to look up documents similar to.
2883+
k: Number of Documents to return. Defaults to 4.
2884+
filter: Filter on the metadata to apply.
2885+
lexical_query: for hybrid search, a specific query for the lexical
2886+
portion of the retrieval. If omitted or empty, defaults to the same
2887+
as 'query'. If passed on a non-hybrid search, an error is raised.
2888+
2889+
Returns:
2890+
The list of (Document, score, id), the most similar to the query.
2891+
"""
2892+
if self.hybrid_search:
2893+
warnings.warn(WARNING_HYBRID_SEARCH_WITH_SCORES, stacklevel=2)
2894+
return await self._asimilarity_search_with_score_id_impl(
2895+
query=query,
2896+
k=k,
2897+
filter=filter,
2898+
lexical_query=lexical_query,
2899+
)
2900+
28322901
@override
28332902
async def asimilarity_search_by_vector(
28342903
self,

0 commit comments

Comments
 (0)