Skip to content

Commit 0398191

Browse files
committed
Adds support for HYBRID_POLICY on KNN queries with filters
1 parent 20c69bd commit 0398191

File tree

3 files changed

+366
-5
lines changed

3 files changed

+366
-5
lines changed

redisvl/query/query.py

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def __init__(
188188
dialect: int = 2,
189189
sort_by: Optional[str] = None,
190190
in_order: bool = False,
191+
hybrid_policy: Optional[str] = None,
192+
batch_size: Optional[int] = None,
191193
):
192194
"""A query for running a vector search along with an optional filter
193195
expression.
@@ -224,6 +226,8 @@ def __init__(
224226
self._vector_field_name = vector_field_name
225227
self._dtype = dtype
226228
self._num_results = num_results
229+
self._hybrid_policy: Optional[str] = None
230+
self._batch_size: Optional[int] = None
227231
self.set_filter(filter_expression)
228232
query_string = self._build_query_string()
229233

@@ -246,12 +250,89 @@ def __init__(
246250
if in_order:
247251
self.in_order()
248252

253+
if hybrid_policy is not None:
254+
self.set_hybrid_policy(hybrid_policy)
255+
256+
if batch_size is not None:
257+
self.set_batch_size(batch_size)
258+
249259
def _build_query_string(self) -> str:
    """Assemble the KNN query string, prefixed by the active filter.

    The result has the form
    ``<filter>=>[KNN <k> @<field> $<vector param> ... AS <distance alias>]``.
    A ``HYBRID_POLICY`` clause is added when a policy is set, and a
    ``BATCH_SIZE`` clause only when that policy is "BATCHES" and a batch
    size has been set.

    Returns:
        str: The complete query string.
    """
    prefix = self._filter_expression
    if isinstance(prefix, FilterExpression):
        prefix = str(prefix)

    # Collect the space-separated pieces of the KNN clause in order.
    clauses = [
        f"KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM}"
    ]

    policy = self._hybrid_policy
    if policy:
        clauses.append(f"HYBRID_POLICY {policy}")
        # BATCH_SIZE is only emitted together with the BATCHES policy.
        if policy == "BATCHES" and self._batch_size:
            clauses.append(f"BATCH_SIZE {self._batch_size}")

    # The distance alias always closes the KNN clause.
    clauses.append(f"AS {self.DISTANCE_ID}")

    return f"{prefix}=>[{' '.join(clauses)}]"
282+
283+
def set_hybrid_policy(self, hybrid_policy: str):
    """Set the hybrid policy for the query and rebuild the query string.

    Args:
        hybrid_policy (str): The hybrid policy to use. Options are "BATCHES"
            or "ADHOC_BF".

    Raises:
        ValueError: If hybrid_policy is not one of the valid options.
    """
    # Tuple membership compares by equality, so an unhashable argument
    # (e.g. a list) is rejected with the documented ValueError instead of
    # leaking an "unhashable type" TypeError from a set lookup.
    if hybrid_policy not in ("BATCHES", "ADHOC_BF"):
        raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}")
    self._hybrid_policy = hybrid_policy

    # The stored query string embeds the policy, so regenerate it.
    self._query_string = self._build_query_string()
299+
300+
def set_batch_size(self, batch_size: int):
    """Set the batch size for the query and rebuild the query string.

    Args:
        batch_size (int): The batch size to use when hybrid_policy is "BATCHES".

    Raises:
        TypeError: If batch_size is not an integer (booleans are rejected)
        ValueError: If batch_size is not positive
    """
    # bool is a subclass of int, so a plain isinstance(..., int) check would
    # accept True and store it as the batch size, which is then interpolated
    # verbatim when the query string is built.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise TypeError("batch_size must be an integer")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    self._batch_size = batch_size

    # The stored query string may embed the batch size, so regenerate it.
    self._query_string = self._build_query_string()
318+
319+
@property
def hybrid_policy(self) -> Optional[str]:
    """Hybrid policy currently configured on this query.

    Returns:
        Optional[str]: The stored policy value, or None when no policy
            has been set.
    """
    current_policy = self._hybrid_policy
    return current_policy
327+
328+
@property
def batch_size(self) -> Optional[int]:
    """Batch size currently configured on this query.

    Returns:
        Optional[int]: The stored batch size, or None when no batch size
            has been set.
    """
    current_size = self._batch_size
    return current_size
255336

256337
@property
257338
def params(self) -> Dict[str, Any]:
@@ -265,7 +346,9 @@ def params(self) -> Dict[str, Any]:
265346
else:
266347
vector = array_to_buffer(self._vector, dtype=self._dtype)
267348

268-
return {self.VECTOR_PARAM: vector}
349+
params = {self.VECTOR_PARAM: vector}
350+
351+
return params
269352

270353

271354
class VectorRangeQuery(BaseVectorQuery, BaseQuery):
@@ -279,6 +362,7 @@ def __init__(
279362
filter_expression: Optional[Union[str, FilterExpression]] = None,
280363
dtype: str = "float32",
281364
distance_threshold: float = 0.2,
365+
epsilon: Optional[float] = None,
282366
num_results: int = 10,
283367
return_score: bool = True,
284368
dialect: int = 2,
@@ -301,6 +385,11 @@ def __init__(
301385
distance_threshold (str, float): The threshold for vector distance.
302386
A smaller threshold indicates a stricter semantic search.
303387
Defaults to 0.2.
388+
epsilon (Optional[float]): The relative factor for vector range queries,
389+
setting boundaries for candidates within radius * (1 + epsilon).
390+
This controls how extensive the search is beyond the specified radius.
391+
Higher values increase recall at the expense of performance.
392+
Defaults to None, which uses the index-defined epsilon (typically 0.01).
304393
num_results (int): The MAX number of results to return.
305394
Defaults to 10.
306395
return_score (bool, optional): Whether to return the vector
@@ -324,6 +413,11 @@ def __init__(
324413
self._vector_field_name = vector_field_name
325414
self._dtype = dtype
326415
self._num_results = num_results
416+
self._epsilon: Optional[float] = None
417+
418+
if epsilon is not None:
419+
self.set_epsilon(epsilon)
420+
327421
self.set_distance_threshold(distance_threshold)
328422
self.set_filter(filter_expression)
329423
query_string = self._build_query_string()
@@ -349,15 +443,20 @@ def __init__(
349443

350444
def _build_query_string(self) -> str:
    """Assemble the vector range query string, combined with any filter.

    The range clause always yields the distance under ``DISTANCE_ID``; an
    ``$EPSILON`` attribute is appended only when an epsilon has been set.
    When the filter is the wildcard ``"*"`` the range clause is returned on
    its own, otherwise the clause and the filter are wrapped together in
    parentheses.

    Returns:
        str: The complete query string.
    """
    # Attributes inside the "=>{...}" section are separated by "; ".
    attributes = [f"$yield_distance_as: {self.DISTANCE_ID}"]
    if self._epsilon is not None:
        attributes.append(f"$EPSILON: {self._epsilon}")

    range_clause = (
        f"@{self._vector_field_name}:[VECTOR_RANGE "
        f"${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]"
        "=>{" + "; ".join(attributes) + "}"
    )

    active_filter = self._filter_expression
    if isinstance(active_filter, FilterExpression):
        active_filter = str(active_filter)

    if active_filter == "*":
        return range_clause
    return f"({range_clause} {active_filter})"
361460

362461
def set_distance_threshold(self, distance_threshold: float):
363462
"""Set the distance threshold for the query.
@@ -369,6 +468,35 @@ def set_distance_threshold(self, distance_threshold: float):
369468
raise TypeError("distance_threshold must be of type int or float")
370469
self._distance_threshold = distance_threshold
371470

471+
def set_epsilon(self, epsilon: float):
    """Set the epsilon parameter for the range query and rebuild the query string.

    Args:
        epsilon (float): The relative factor for vector range queries,
            setting boundaries for candidates within radius * (1 + epsilon).

    Raises:
        TypeError: If epsilon is not a float or int (booleans are rejected)
        ValueError: If epsilon is negative or NaN
    """
    # bool is a subclass of int; reject it explicitly so True/False are not
    # stored and interpolated verbatim into the query string.
    if isinstance(epsilon, bool) or not isinstance(epsilon, (float, int)):
        raise TypeError("epsilon must be of type float or int")
    # Written as "not >= 0" rather than "< 0" so that NaN, for which every
    # ordering comparison is False, is rejected as well.
    if not epsilon >= 0:
        raise ValueError("epsilon must be non-negative")
    self._epsilon = epsilon

    # The stored query string embeds epsilon, so regenerate it.
    self._query_string = self._build_query_string()
490+
491+
@property
def epsilon(self) -> Optional[float]:
    """Epsilon currently configured on this range query.

    Returns:
        Optional[float]: The stored epsilon, or None when it has not
            been set.
    """
    current_epsilon = self._epsilon
    return current_epsilon
499+
372500
@property
373501
def distance_threshold(self) -> float:
374502
"""Return the distance threshold for the query.
@@ -390,11 +518,14 @@ def params(self) -> Dict[str, Any]:
390518
else:
391519
vector_param = array_to_buffer(self._vector, dtype=self._dtype)
392520

393-
return {
521+
# Only include the necessary parameters, not EPSILON
522+
params = {
394523
self.VECTOR_PARAM: vector_param,
395524
self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
396525
}
397526

527+
return params
528+
398529

399530
class RangeQuery(VectorRangeQuery):
400531
# keep for backwards compatibility

tests/integration/test_query.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from redisvl.index import SearchIndex
55
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
66
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
7+
from redisvl.query.query import VectorRangeQuery
78
from redisvl.redis.utils import array_to_buffer
89

910
# TODO expand to multiple schema types and sync + async
@@ -466,3 +467,93 @@ def test_query_with_chunk_number_zero():
466467
assert (
467468
str(filter_conditions) == expected_query_str
468469
), "Query with chunk_number zero is incorrect"
470+
471+
472+
def test_hybrid_policy_batches_mode(index, vector_query):
    """A BATCHES-policy KNN query renders its clauses and still applies the filter."""
    # Configure the hybrid policy and its batch size on the query.
    vector_query.set_hybrid_policy("BATCHES")
    vector_query.set_batch_size(2)

    # Restrict the search to documents tagged with a "high" credit score.
    high_credit = Tag("credit_score") == "high"
    vector_query.set_filter(high_credit)

    # Both clauses must appear, in this order, in the rendered query.
    rendered = str(vector_query)
    assert "HYBRID_POLICY BATCHES BATCH_SIZE 2" in rendered

    # Run against the index: there must be hits and every hit must match the filter.
    hits = index.query(vector_query)
    assert len(hits) > 0
    assert all(hit["credit_score"] == "high" for hit in hits)
494+
495+
496+
def test_hybrid_policy_adhoc_bf_mode(index, vector_query):
    """An ADHOC_BF-policy KNN query renders its clause and still applies the filter."""
    # Configure the hybrid policy on the query.
    vector_query.set_hybrid_policy("ADHOC_BF")

    # Restrict the search to documents tagged with a "high" credit score.
    high_credit = Tag("credit_score") == "high"
    vector_query.set_filter(high_credit)

    # The policy clause must appear in the rendered query.
    rendered = str(vector_query)
    assert "HYBRID_POLICY ADHOC_BF" in rendered

    # Run against the index: there must be hits and every hit must match the filter.
    hits = index.query(vector_query)
    assert len(hits) > 0
    assert all(hit["credit_score"] == "high" for hit in hits)
517+
518+
519+
def test_range_query_with_epsilon(index):
    """Epsilon is rendered as a query attribute and is exposed via the property."""

    def make_query(**extra):
        # Each call builds fresh argument lists so the two queries share no state.
        return VectorRangeQuery(
            vector=[0.1, 0.1, 0.5],
            vector_field_name="user_embedding",
            return_fields=["user", "credit_score", "age", "job"],
            distance_threshold=0.2,
            **extra,
        )

    # One query without an explicit epsilon, one with a custom value.
    default_query = make_query()
    custom_query = make_query(epsilon=0.05)

    # The custom epsilon shows up as an attribute in the query string ...
    rendered = str(custom_query)
    assert "$EPSILON: 0.05" in rendered

    # ... and is readable through the property.
    assert custom_query.epsilon == 0.05

    # Changing epsilon updates both the property and the rendered query.
    custom_query.set_epsilon(0.1)
    assert custom_query.epsilon == 0.1
    assert "$EPSILON: 0.1" in str(custom_query)

    # Epsilon travels in the query string, not in the bound parameters.
    assert "EPSILON" not in custom_query.params

    # Only the query without an explicit epsilon is executed against the index.
    hits = index.query(default_query)

    # Every hit must lie within the requested distance threshold.
    assert len(hits) > 0
    for hit in hits:
        assert float(hit["vector_distance"]) <= 0.2

0 commit comments

Comments
 (0)