Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9500853
wip: adding Text and Hybrid queries
justin-cechmanek Mar 24, 2025
3fa93ff
tokenizer helper function
rbs333 Mar 24, 2025
3b8e2b6
adds TextQuery class
justin-cechmanek Mar 26, 2025
7e0f24d
adds nltk requirement
justin-cechmanek Mar 27, 2025
49b2aba
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
f7a4b9e
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
6e007f7
Validate passed-in Redis clients (#296)
abrookins Mar 21, 2025
298d055
Add batch_search to sync Index (#305)
abrookins Mar 29, 2025
94eea52
Support client-side schema validation using Pydantic (#304)
tylerhutcherson Mar 31, 2025
123ee22
Run API tests once (#306)
abrookins Mar 31, 2025
9025bfe
Add option to normalize vector distances on query (#298)
rbs333 Mar 31, 2025
ae69ae9
adds TextQuery class
justin-cechmanek Mar 26, 2025
e403934
makes stopwords user defined in TextQuery
justin-cechmanek Mar 27, 2025
9348583
adds hybrid aggregation query and tests. modifies search index to acc…
justin-cechmanek Apr 1, 2025
10f4474
cleans text and hybrid tests
justin-cechmanek Apr 2, 2025
018fe9f
merge conflicts
justin-cechmanek Apr 2, 2025
3518121
updates lock file
justin-cechmanek Apr 2, 2025
091148c
mypy cannot find defined methods
justin-cechmanek Apr 2, 2025
9069dd5
updates nltk requirement
justin-cechmanek Apr 2, 2025
c5ad696
I swear I have changed this 4 times now
justin-cechmanek Apr 2, 2025
ea5d087
wip: debugging aggregations and filters
justin-cechmanek Apr 2, 2025
1672ea3
fixes query string parsing. adds more tests
justin-cechmanek Apr 3, 2025
f32067a
test now checks default dialect is 2
justin-cechmanek Apr 3, 2025
9b1dc18
makes methods private
justin-cechmanek Apr 3, 2025
ff44041
abstracts AggregationQuery to follow BaseQuery calls in search index
justin-cechmanek Apr 3, 2025
c0be24f
updates docstrings
justin-cechmanek Apr 3, 2025
aae3949
fixes docstring
justin-cechmanek Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 31 additions & 37 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
BaseVectorQuery,
CountQuery,
FilterQuery,
HybridAggregationQuery,
HybridQuery,
)
from redisvl.query.filter import FilterExpression
from redisvl.redis.connection import (
Expand Down Expand Up @@ -686,35 +686,8 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

def aggregate_query(
self, aggregation_query: AggregationQuery
) -> List[Dict[str, Any]]:
"""Execute an aggretation query and processes the results.

This method takes an AggregationHyridQuery object directly, runs the search, and
handles post-processing of the search.

Args:
aggregation_query (AggregationQuery): The aggregation query to run.

Returns:
List[Result]: A list of search results.

.. code-block:: python

from redisvl.query import HybridAggregationQuery

aggregation = HybridAggregationQuery(
text="the text to search for",
text_field="description",
vector=[0.16, -0.34, 0.98, 0.23],
vector_field="embedding",
num_results=3
)

results = index.aggregate_query(aggregation_query)

"""
def _aggregate(self, aggregation_query: AggregationQuery) -> List[Dict[str, Any]]:
"""Execute an aggretation query and processes the results."""
results = self.aggregate(
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
)
Expand Down Expand Up @@ -846,7 +819,7 @@ def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
results = self.search(query.query, query_params=query.params)
return process_results(results, query=query, schema=self.schema)

def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
def query(self, query: Union[BaseQuery, AggregationQuery]) -> List[Dict[str, Any]]:
"""Execute a query on the index.

This method takes a BaseQuery or AggregationQuery object directly, runs the search, and
Expand All @@ -871,7 +844,10 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
results = index.query(query)

"""
return self._query(query)
if isinstance(query, AggregationQuery):
return self._aggregate(query)
else:
return self._query(query)

def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator:
"""Execute a given query against the index and return results in
Expand Down Expand Up @@ -1377,6 +1353,19 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

async def _aggregate(
    self, aggregation_query: AggregationQuery
) -> List[Dict[str, Any]]:
    """Run an aggregation query asynchronously and post-process its results.

    Args:
        aggregation_query (AggregationQuery): The aggregation query to run.
            Its ``params`` attribute is forwarded as the query parameters.

    Returns:
        List[Dict[str, Any]]: The output of ``process_aggregate_results``
            for the raw aggregation response.
    """
    # Read params first so the evaluation order matches the awaited call below.
    query_params = aggregation_query.params  # type: ignore[attr-defined]
    raw_results = await self.aggregate(aggregation_query, query_params=query_params)

    storage_type = self.schema.index.storage_type
    return process_aggregate_results(
        raw_results,
        query=aggregation_query,
        storage_type=storage_type,
    )

async def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.

Expand Down Expand Up @@ -1500,14 +1489,16 @@ async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
results = await self.search(query.query, query_params=query.params)
return process_results(results, query=query, schema=self.schema)

async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
async def query(
self, query: Union[BaseQuery, AggregationQuery]
) -> List[Dict[str, Any]]:
"""Asynchronously execute a query on the index.

This method takes a BaseQuery object directly, runs the search, and
handles post-processing of the search.
This method takes a BaseQuery or AggregationQuery object directly, runs
the search, and handles post-processing of the search.

Args:
query (BaseQuery): The query to run.
query (Union[BaseQuery, AggregationQuery]): The query to run.

Returns:
List[Result]: A list of search results.
Expand All @@ -1524,7 +1515,10 @@ async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:

results = await index.query(query)
"""
return await self._query(query)
if isinstance(query, AggregationQuery):
return await self._aggregate(query)
else:
return await self._query(query)

async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerator:
"""Execute a given query against the index and return results in
Expand Down
4 changes: 2 additions & 2 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from redisvl.query.aggregate import AggregationQuery, HybridAggregationQuery
from redisvl.query.aggregate import AggregationQuery, HybridQuery
from redisvl.query.query import (
BaseQuery,
BaseVectorQuery,
Expand All @@ -20,5 +20,5 @@
"CountQuery",
"TextQuery",
"AggregationQuery",
"HybridAggregationQuery",
"HybridQuery",
]
16 changes: 8 additions & 8 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def __init__(self, query_string):
super().__init__(query_string)


class HybridAggregationQuery(AggregationQuery):
class HybridQuery(AggregationQuery):
"""
HybridAggregationQuery combines text and vector search in Redis.
HybridQuery combines text and vector search in Redis.
It allows you to perform a hybrid search using both text and vector similarity.
It scores documents based on a weighted combination of text and vector similarity.
"""
Expand All @@ -45,7 +45,7 @@ def __init__(
dialect: int = 2,
):
"""
Instantiages a HybridAggregationQuery object.
Instantiates a HybridQuery object.

Args:
text (str): The text to search for.
Expand Down Expand Up @@ -75,12 +75,12 @@ def __init__(
TypeError: If the stopwords are not a set, list, or tuple of strings.

.. code-block:: python
from redisvl.query.aggregate import HybridAggregationQuery
from redisvl.query import HybridQuery
from redisvl.index import SearchIndex

index = SearchIndex("my_index")
index = SearchIndex.from_yaml("index.yaml")

query = HybridAggregationQuery(
query = HybridQuery(
text="example text",
text_field_name="text_field",
vector=[0.1, 0.2, 0.3],
Expand All @@ -92,10 +92,10 @@ def __init__(
num_results=10,
return_fields=["field1", "field2"],
stopwords="english",
dialect=4,
dialect=2,
)

results = index.aggregate_query(query)
results = index.query(query)
"""

if not text.strip():
Expand Down
21 changes: 20 additions & 1 deletion redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def __init__(
text (str): The text string to perform the text search with.
text_field_name (str): The name of the document field to perform text search on.
text_scorer (str, optional): The text scoring algorithm to use.
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, DOCNORM, DISMAX, DOCSCORE}.
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
filter_expression (Union[str, FilterExpression], optional): A filter to apply
along with the text search. Defaults to None.
Expand Down Expand Up @@ -740,6 +740,25 @@ def __init__(
Raises:
ValueError: if stopwords language string cannot be loaded.
TypeError: If stopwords is not a valid iterable set of strings.

.. code-block:: python
from redisvl.query import TextQuery
from redisvl.index import SearchIndex

index = SearchIndex.from_yaml("index.yaml")

query = TextQuery(
text="example text",
text_field_name="text_field",
text_scorer="BM25STD",
filter_expression=None,
num_results=10,
return_fields=["field1", "field2"],
stopwords="english",
dialect=2,
)

results = index.query(query)
"""
self._text = text
self._text_field = text_field_name
Expand Down
38 changes: 19 additions & 19 deletions tests/integration/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from redis.commands.search.result import Result

from redisvl.index import SearchIndex
from redisvl.query import HybridAggregationQuery
from redisvl.query import HybridQuery
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
from redisvl.redis.connection import compare_versions
from redisvl.redis.utils import array_to_buffer
Expand Down Expand Up @@ -70,15 +70,15 @@ def test_aggregation_query(index):
vector_field = "user_embedding"
return_fields = ["user", "credit_score", "age", "job", "location", "description"]

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
vector_field_name=vector_field,
return_fields=return_fields,
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert isinstance(results, list)
assert len(results) == 7
for doc in results:
Expand All @@ -96,15 +96,15 @@ def test_aggregation_query(index):
assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"]
assert doc["credit_score"] in ["high", "low", "medium"]

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
vector_field_name=vector_field,
num_results=3,
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 3
assert (
results[0]["hybrid_score"]
Expand All @@ -122,7 +122,7 @@ def test_empty_query_string():

# test if text is empty
with pytest.raises(ValueError):
hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -132,7 +132,7 @@ def test_empty_query_string():
# test if text becomes empty after stopwords are removed
text = "with a for but and" # will all be removed as default stopwords
with pytest.raises(ValueError):
hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -152,7 +152,7 @@ def test_aggregation_query_with_filter(index):
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
filter_expression = (Tag("credit_score") == ("high")) & (Num("age") > 30)

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -161,7 +161,7 @@ def test_aggregation_query_with_filter(index):
return_fields=return_fields,
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 2
for result in results:
assert result["credit_score"] == "high"
Expand All @@ -180,7 +180,7 @@ def test_aggregation_query_with_geo_filter(index):
return_fields = ["user", "credit_score", "age", "job", "location", "description"]
filter_expression = Geo("location") == GeoRadius(-122.4194, 37.7749, 1000, "m")

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -189,7 +189,7 @@ def test_aggregation_query_with_geo_filter(index):
return_fields=return_fields,
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 3
for result in results:
assert result["location"] is not None
Expand All @@ -206,15 +206,15 @@ def test_aggregate_query_alpha(index, alpha):
vector = [0.1, 0.1, 0.5]
vector_field = "user_embedding"

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
vector_field_name=vector_field,
alpha=alpha,
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 7
for result in results:
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
Expand All @@ -236,7 +236,7 @@ def test_aggregate_query_stopwords(index):
vector_field = "user_embedding"
alpha = 0.5

hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -250,7 +250,7 @@ def test_aggregate_query_stopwords(index):
assert "medical" not in query_string
assert "expertize" not in query_string

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 7
for result in results:
score = alpha * float(result["vector_similarity"]) + (1 - alpha) * float(
Expand All @@ -273,7 +273,7 @@ def test_aggregate_query_with_text_filter(index):
filter_expression = Text(text_field) == ("medical")

# make sure we can still apply filters to the same text field we are querying
hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -283,15 +283,15 @@ def test_aggregate_query_with_text_filter(index):
return_fields=["job", "description"],
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 2
for result in results:
assert "medical" in result[text_field].lower()

filter_expression = (Text(text_field) == ("medical")) & (
(Text(text_field) != ("research"))
)
hybrid_query = HybridAggregationQuery(
hybrid_query = HybridQuery(
text=text,
text_field_name=text_field,
vector=vector,
Expand All @@ -301,7 +301,7 @@ def test_aggregate_query_with_text_filter(index):
return_fields=["description"],
)

results = index.aggregate_query(hybrid_query)
results = index.query(hybrid_query)
assert len(results) == 2
for result in results:
assert "medical" in result[text_field].lower()
Expand Down
Loading