Skip to content

Commit 2c05e2b

Browse files
committed
Adds support for HYBRID_POLICY on KNN queries (VectorQuery and VectorRangeQuery) with filters
1 parent 29cb397 commit 2c05e2b

File tree

3 files changed

+657
-26
lines changed

3 files changed

+657
-26
lines changed

redisvl/query/query.py

Lines changed: 257 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def __init__(
188188
dialect: int = 2,
189189
sort_by: Optional[str] = None,
190190
in_order: bool = False,
191+
hybrid_policy: Optional[str] = None,
192+
batch_size: Optional[int] = None,
191193
):
192194
"""A query for running a vector search along with an optional filter
193195
expression.
@@ -213,6 +215,16 @@ def __init__(
213215
in_order (bool): Requires the terms in the field to have
214216
the same order as the terms in the query filter, regardless of
215217
the offsets between them. Defaults to False.
218+
hybrid_policy (Optional[str]): Controls how filters are applied during vector search.
219+
Options are "BATCHES" (paginates through small batches of nearest neighbors) or
220+
"ADHOC_BF" (computes scores for all vectors passing the filter).
221+
"BATCHES" mode is typically faster for queries with selective filters.
222+
"ADHOC_BF" mode is better when filters match a large portion of the dataset.
223+
Defaults to None, which lets Redis auto-select the optimal policy.
224+
batch_size (Optional[int]): When hybrid_policy is "BATCHES", controls the number
225+
of vectors to fetch in each batch. Larger values may improve performance
226+
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
227+
Defaults to None, which lets Redis auto-select an appropriate batch size.
216228
217229
Raises:
218230
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -224,6 +236,8 @@ def __init__(
224236
self._vector_field_name = vector_field_name
225237
self._dtype = dtype
226238
self._num_results = num_results
239+
self._hybrid_policy: Optional[str] = None
240+
self._batch_size: Optional[int] = None
227241
self.set_filter(filter_expression)
228242
query_string = self._build_query_string()
229243

@@ -246,12 +260,89 @@ def __init__(
246260
if in_order:
247261
self.in_order()
248262

263+
if hybrid_policy is not None:
264+
self.set_hybrid_policy(hybrid_policy)
265+
266+
if batch_size is not None:
267+
self.set_batch_size(batch_size)
268+
249269
def _build_query_string(self) -> str:
    """Compose the KNN query string, prefixed by the filter expression.

    The result has the form::

        <filter>=>[KNN <k> @<field> $<vector param>
                   [HYBRID_POLICY <policy> [BATCH_SIZE <n>]] AS <distance id>]

    HYBRID_POLICY is emitted only when a policy has been set, and BATCH_SIZE
    only when that policy is "BATCHES" and a batch size has been set.

    Returns:
        str: The assembled query string.
    """
    prefilter = self._filter_expression
    if isinstance(prefilter, FilterExpression):
        prefilter = str(prefilter)

    # Collect the space-separated pieces of the KNN clause in order.
    clauses = [
        f"KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM}"
    ]

    policy = self._hybrid_policy
    if policy:
        clauses.append(f"HYBRID_POLICY {policy}")
        # A batch size is only meaningful for the BATCHES policy.
        if policy == "BATCHES" and self._batch_size:
            clauses.append(f"BATCH_SIZE {self._batch_size}")

    # The distance alias always closes the clause.
    clauses.append(f"AS {self.DISTANCE_ID}")

    return f"{prefilter}=>[{' '.join(clauses)}]"
292+
293+
def set_hybrid_policy(self, hybrid_policy: str):
    """Store the hybrid policy and rebuild the query string.

    Args:
        hybrid_policy (str): Policy name; must be exactly "BATCHES" or
            "ADHOC_BF" (the comparison is case-sensitive).

    Raises:
        ValueError: If hybrid_policy is not one of the accepted names.
    """
    accepted = {"BATCHES", "ADHOC_BF"}
    if hybrid_policy in accepted:
        self._hybrid_policy = hybrid_policy
        # The policy is embedded in the query string, so regenerate it.
        self._query_string = self._build_query_string()
        return
    raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}")
309+
310+
def set_batch_size(self, batch_size: int):
    """Set the batch size for the query and rebuild the query string.

    Args:
        batch_size (int): The batch size to use when hybrid_policy is "BATCHES".
            Must be a positive integer; booleans are not accepted.

    Raises:
        TypeError: If batch_size is not an integer (this includes bool)
        ValueError: If batch_size is not positive
    """
    # bool is a subclass of int, so a plain isinstance(..., int) check would
    # let True/False through and render them as "BATCH_SIZE True".
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise TypeError("batch_size must be an integer")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    self._batch_size = batch_size

    # Reset the query string so it reflects the new batch size
    self._query_string = self._build_query_string()
328+
329+
@property
def hybrid_policy(self) -> Optional[str]:
    """Hybrid policy currently stored on this query.

    Returns:
        Optional[str]: The stored policy name, or None when none has been set.
    """
    policy = self._hybrid_policy
    return policy
337+
338+
@property
def batch_size(self) -> Optional[int]:
    """Batch size currently stored on this query.

    Returns:
        Optional[int]: The stored batch size, or None when none has been set.
    """
    size = self._batch_size
    return size
255346

256347
@property
257348
def params(self) -> Dict[str, Any]:
@@ -265,11 +356,16 @@ def params(self) -> Dict[str, Any]:
265356
else:
266357
vector = array_to_buffer(self._vector, dtype=self._dtype)
267358

268-
return {self.VECTOR_PARAM: vector}
359+
params = {self.VECTOR_PARAM: vector}
360+
361+
return params
269362

270363

271364
class VectorRangeQuery(BaseVectorQuery, BaseQuery):
272365
DISTANCE_THRESHOLD_PARAM: str = "distance_threshold"
366+
EPSILON_PARAM: str = "EPSILON" # Parameter name for epsilon
367+
HYBRID_POLICY_PARAM: str = "HYBRID_POLICY" # Parameter name for hybrid policy
368+
BATCH_SIZE_PARAM: str = "BATCH_SIZE" # Parameter name for batch size
273369

274370
def __init__(
275371
self,
@@ -279,11 +375,14 @@ def __init__(
279375
filter_expression: Optional[Union[str, FilterExpression]] = None,
280376
dtype: str = "float32",
281377
distance_threshold: float = 0.2,
378+
epsilon: Optional[float] = None,
282379
num_results: int = 10,
283380
return_score: bool = True,
284381
dialect: int = 2,
285382
sort_by: Optional[str] = None,
286383
in_order: bool = False,
384+
hybrid_policy: Optional[str] = None,
385+
batch_size: Optional[int] = None,
287386
):
288387
"""A query for running a filtered vector search based on semantic
289388
distance threshold.
@@ -298,9 +397,14 @@ def __init__(
298397
along with the range query. Defaults to None.
299398
dtype (str, optional): The dtype of the vector. Defaults to
300399
"float32".
301-
distance_threshold (str, float): The threshold for vector distance.
400+
distance_threshold (float): The threshold for vector distance.
302401
A smaller threshold indicates a stricter semantic search.
303402
Defaults to 0.2.
403+
epsilon (Optional[float]): The relative factor for vector range queries,
404+
setting boundaries for candidates within radius * (1 + epsilon).
405+
This controls how extensive the search is beyond the specified radius.
406+
Higher values increase recall at the expense of performance.
407+
Defaults to None, which uses the index-defined epsilon (typically 0.01).
304408
num_results (int): The MAX number of results to return.
305409
Defaults to 10.
306410
return_score (bool, optional): Whether to return the vector
@@ -312,18 +416,35 @@ def __init__(
312416
in_order (bool): Requires the terms in the field to have
313417
the same order as the terms in the query filter, regardless of
314418
the offsets between them. Defaults to False.
315-
316-
Raises:
317-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
318-
319-
Note:
320-
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
321-
419+
hybrid_policy (Optional[str]): Controls how filters are applied during vector search.
420+
Options are "BATCHES" (paginates through small batches of nearest neighbors) or
421+
"ADHOC_BF" (computes scores for all vectors passing the filter).
422+
"BATCHES" mode is typically faster for queries with selective filters.
423+
"ADHOC_BF" mode is better when filters match a large portion of the dataset.
424+
Defaults to None, which lets Redis auto-select the optimal policy.
425+
batch_size (Optional[int]): When hybrid_policy is "BATCHES", controls the number
426+
of vectors to fetch in each batch. Larger values may improve performance
427+
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
428+
Defaults to None, which lets Redis auto-select an appropriate batch size.
322429
"""
323430
self._vector = vector
324431
self._vector_field_name = vector_field_name
325432
self._dtype = dtype
326433
self._num_results = num_results
434+
self._distance_threshold: float = 0.2 # Initialize with default
435+
self._epsilon: Optional[float] = None
436+
self._hybrid_policy: Optional[str] = None
437+
self._batch_size: Optional[int] = None
438+
439+
if epsilon is not None:
440+
self.set_epsilon(epsilon)
441+
442+
if hybrid_policy is not None:
443+
self.set_hybrid_policy(hybrid_policy)
444+
445+
if batch_size is not None:
446+
self.set_batch_size(batch_size)
447+
327448
self.set_distance_threshold(distance_threshold)
328449
self.set_filter(filter_expression)
329450
query_string = self._build_query_string()
@@ -347,27 +468,104 @@ def __init__(
347468
if in_order:
348469
self.in_order()
349470

471+
def set_distance_threshold(self, distance_threshold: float):
    """Validate and store the vector distance threshold, then rebuild the query.

    Args:
        distance_threshold (float): Vector distance threshold; must be a
            non-negative float or int.

    Raises:
        TypeError: If distance_threshold is neither a float nor an int
        ValueError: If distance_threshold is below zero
    """
    if not isinstance(distance_threshold, (float, int)):
        raise TypeError("distance_threshold must be of type float or int")
    elif distance_threshold < 0:
        raise ValueError("distance_threshold must be non-negative")

    self._distance_threshold = distance_threshold

    # Regenerate the cached query string after the change.
    rebuilt = self._build_query_string()
    self._query_string = rebuilt
489+
490+
def set_epsilon(self, epsilon: float):
    """Validate and store the range-query epsilon, then rebuild the query.

    Args:
        epsilon (float): The relative factor for vector range queries,
            setting boundaries for candidates within radius * (1 + epsilon).
            Must be a non-negative float or int.

    Raises:
        TypeError: If epsilon is neither a float nor an int
        ValueError: If epsilon is below zero
    """
    if not isinstance(epsilon, (float, int)):
        raise TypeError("epsilon must be of type float or int")
    elif epsilon < 0:
        raise ValueError("epsilon must be non-negative")

    self._epsilon = epsilon

    # Epsilon is rendered into the query string, so regenerate it.
    rebuilt = self._build_query_string()
    self._query_string = rebuilt
509+
510+
def set_hybrid_policy(self, hybrid_policy: str):
    """Store the hybrid policy and rebuild the query string.

    Args:
        hybrid_policy (str): Policy name; must be exactly "BATCHES" or
            "ADHOC_BF" (the comparison is case-sensitive).

    Raises:
        ValueError: If hybrid_policy is not one of the accepted names.
    """
    accepted = {"BATCHES", "ADHOC_BF"}
    if hybrid_policy in accepted:
        self._hybrid_policy = hybrid_policy
        # Regenerate the cached query string after the change.
        self._query_string = self._build_query_string()
        return
    raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}")
526+
527+
def set_batch_size(self, batch_size: int):
    """Set the batch size for the query and rebuild the query string.

    Args:
        batch_size (int): The batch size to use when hybrid_policy is "BATCHES".
            Must be a positive integer; booleans are not accepted.

    Raises:
        TypeError: If batch_size is not an integer (this includes bool)
        ValueError: If batch_size is not positive
    """
    # bool is a subclass of int, so a plain isinstance(..., int) check would
    # accept True/False as a batch size.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise TypeError("batch_size must be an integer")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    self._batch_size = batch_size

    # Reset the query string
    self._query_string = self._build_query_string()
545+
350546
def _build_query_string(self) -> str:
    """Compose the vector range query string, optionally joined with a filter.

    The vector range clause is followed by a query-attributes section that
    always yields the distance under DISTANCE_ID and, when an epsilon has
    been set, also carries $EPSILON. If the filter expression is "*", the
    clause is returned by itself; otherwise the clause and the filter are
    wrapped together in parentheses.

    Returns:
        str: The assembled query string.
    """
    range_clause = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]"

    # Query attributes, rendered as "=>{a; b}" after the range clause.
    attributes = [f"$YIELD_DISTANCE_AS: {self.DISTANCE_ID}"]
    if self._epsilon is not None:
        attributes.append(f"$EPSILON: {self._epsilon}")
    vector_query = f"{range_clause}=>{{{'; '.join(attributes)}}}"

    prefilter = self._filter_expression
    if isinstance(prefilter, FilterExpression):
        prefilter = str(prefilter)

    # "*" means no filtering, so the range clause stands alone.
    if prefilter == "*":
        return vector_query
    return f"({vector_query} {prefilter})"
371569

372570
@property
373571
def distance_threshold(self) -> float:
@@ -378,6 +576,33 @@ def distance_threshold(self) -> float:
378576
"""
379577
return self._distance_threshold
380578

579+
@property
def epsilon(self) -> Optional[float]:
    """Epsilon currently stored on this query.

    Returns:
        Optional[float]: The stored epsilon, or None when none has been set.
    """
    value = self._epsilon
    return value
587+
588+
@property
def hybrid_policy(self) -> Optional[str]:
    """Hybrid policy currently stored on this query.

    Returns:
        Optional[str]: The stored policy name, or None when none has been set.
    """
    policy = self._hybrid_policy
    return policy
596+
597+
@property
def batch_size(self) -> Optional[int]:
    """Batch size currently stored on this query.

    Returns:
        Optional[int]: The stored batch size, or None when none has been set.
    """
    size = self._batch_size
    return size
605+
381606
@property
382607
def params(self) -> Dict[str, Any]:
383608
"""Return the parameters for the query.
@@ -390,11 +615,20 @@ def params(self) -> Dict[str, Any]:
390615
else:
391616
vector_param = array_to_buffer(self._vector, dtype=self._dtype)
392617

393-
return {
618+
params = {
394619
self.VECTOR_PARAM: vector_param,
395620
self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
396621
}
397622

623+
# Add hybrid policy and batch size as query parameters (not in query string)
624+
if self._hybrid_policy:
625+
params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy
626+
627+
if self._hybrid_policy == "BATCHES" and self._batch_size:
628+
params[self.BATCH_SIZE_PARAM] = self._batch_size
629+
630+
return params
631+
398632

399633
class RangeQuery(VectorRangeQuery):
400634
# keep for backwards compatibility

0 commit comments

Comments
 (0)