Skip to content

Commit cba2998

Browse files
use enum for hybrid policy
1 parent 8d63bc4 commit cba2998

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

redisvl/query/query.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import Any, Dict, List, Optional, Union
23

34
from redis.commands.search.query import Query as RedisQuery
@@ -175,6 +176,13 @@ class BaseVectorQuery:
175176
VECTOR_PARAM: str = "vector"
176177

177178

179+
class HybridPolicy(str, Enum):
180+
"""Enum for valid hybrid policy options in vector queries."""
181+
182+
BATCHES = "BATCHES"
183+
ADHOC_BF = "ADHOC_BF"
184+
185+
178186
class VectorQuery(BaseVectorQuery, BaseQuery):
179187
def __init__(
180188
self,
@@ -236,7 +244,7 @@ def __init__(
236244
self._vector_field_name = vector_field_name
237245
self._dtype = dtype
238246
self._num_results = num_results
239-
self._hybrid_policy: Optional[str] = None
247+
self._hybrid_policy: Optional[HybridPolicy] = None
240248
self._batch_size: Optional[int] = None
241249
self.set_filter(filter_expression)
242250
query_string = self._build_query_string()
@@ -279,10 +287,10 @@ def _build_query_string(self) -> str:
279287

280288
# Add hybrid policy parameters if specified
281289
if self._hybrid_policy:
282-
knn_query += f" HYBRID_POLICY {self._hybrid_policy}"
290+
knn_query += f" HYBRID_POLICY {self._hybrid_policy.value}"
283291

284292
# Add batch size if specified and using BATCHES policy
285-
if self._hybrid_policy == "BATCHES" and self._batch_size:
293+
if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size:
286294
knn_query += f" BATCH_SIZE {self._batch_size}"
287295

288296
# Add distance field alias
@@ -300,9 +308,12 @@ def set_hybrid_policy(self, hybrid_policy: str):
300308
Raises:
301309
ValueError: If hybrid_policy is not one of the valid options
302310
"""
303-
if hybrid_policy not in {"BATCHES", "ADHOC_BF"}:
304-
raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}")
305-
self._hybrid_policy = hybrid_policy
311+
try:
312+
self._hybrid_policy = HybridPolicy(hybrid_policy)
313+
except ValueError:
314+
raise ValueError(
315+
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
316+
)
306317

307318
# Reset the query string
308319
self._query_string = self._build_query_string()
@@ -333,7 +344,7 @@ def hybrid_policy(self) -> Optional[str]:
333344
Returns:
334345
Optional[str]: The hybrid policy for the query.
335346
"""
336-
return self._hybrid_policy
347+
return self._hybrid_policy.value if self._hybrid_policy else None
337348

338349
@property
339350
def batch_size(self) -> Optional[int]:
@@ -433,7 +444,7 @@ def __init__(
433444
self._num_results = num_results
434445
self._distance_threshold: float = 0.2 # Initialize with default
435446
self._epsilon: Optional[float] = None
436-
self._hybrid_policy: Optional[str] = None
447+
self._hybrid_policy: Optional[HybridPolicy] = None
437448
self._batch_size: Optional[int] = None
438449

439450
if epsilon is not None:
@@ -517,9 +528,12 @@ def set_hybrid_policy(self, hybrid_policy: str):
517528
Raises:
518529
ValueError: If hybrid_policy is not one of the valid options
519530
"""
520-
if hybrid_policy not in {"BATCHES", "ADHOC_BF"}:
521-
raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}")
522-
self._hybrid_policy = hybrid_policy
531+
try:
532+
self._hybrid_policy = HybridPolicy(hybrid_policy)
533+
except ValueError:
534+
raise ValueError(
535+
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
536+
)
523537

524538
# Reset the query string
525539
self._query_string = self._build_query_string()
@@ -592,7 +606,7 @@ def hybrid_policy(self) -> Optional[str]:
592606
Returns:
593607
Optional[str]: The hybrid policy for the query.
594608
"""
595-
return self._hybrid_policy
609+
return self._hybrid_policy.value if self._hybrid_policy else None
596610

597611
@property
598612
def batch_size(self) -> Optional[int]:
@@ -622,9 +636,9 @@ def params(self) -> Dict[str, Any]:
622636

623637
# Add hybrid policy and batch size as query parameters (not in query string)
624638
if self._hybrid_policy:
625-
params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy
639+
params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy.value
626640

627-
if self._hybrid_policy == "BATCHES" and self._batch_size:
641+
if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size:
628642
params[self.BATCH_SIZE_PARAM] = self._batch_size
629643

630644
return params

0 commit comments

Comments
 (0)