1+ from enum import Enum
12from typing import Any , Dict , List , Optional , Union
23
34from redis .commands .search .query import Query as RedisQuery
@@ -175,6 +176,13 @@ class BaseVectorQuery:
175176 VECTOR_PARAM : str = "vector"
176177
177178
179+ class HybridPolicy (str , Enum ):
180+ """Enum for valid hybrid policy options in vector queries."""
181+
182+ BATCHES = "BATCHES"
183+ ADHOC_BF = "ADHOC_BF"
184+
185+
178186class VectorQuery (BaseVectorQuery , BaseQuery ):
179187 def __init__ (
180188 self ,
@@ -236,7 +244,7 @@ def __init__(
236244 self ._vector_field_name = vector_field_name
237245 self ._dtype = dtype
238246 self ._num_results = num_results
239- self ._hybrid_policy : Optional [str ] = None
247+ self ._hybrid_policy : Optional [HybridPolicy ] = None
240248 self ._batch_size : Optional [int ] = None
241249 self .set_filter (filter_expression )
242250 query_string = self ._build_query_string ()
@@ -279,10 +287,10 @@ def _build_query_string(self) -> str:
279287
280288 # Add hybrid policy parameters if specified
281289 if self ._hybrid_policy :
282- knn_query += f" HYBRID_POLICY { self ._hybrid_policy } "
290+ knn_query += f" HYBRID_POLICY { self ._hybrid_policy . value } "
283291
284292 # Add batch size if specified and using BATCHES policy
285- if self ._hybrid_policy == " BATCHES" and self ._batch_size :
293+ if self ._hybrid_policy == HybridPolicy . BATCHES and self ._batch_size :
286294 knn_query += f" BATCH_SIZE { self ._batch_size } "
287295
288296 # Add distance field alias
@@ -300,9 +308,12 @@ def set_hybrid_policy(self, hybrid_policy: str):
300308 Raises:
301309 ValueError: If hybrid_policy is not one of the valid options
302310 """
303- if hybrid_policy not in {"BATCHES" , "ADHOC_BF" }:
304- raise ValueError ("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}" )
305- self ._hybrid_policy = hybrid_policy
311+ try :
312+ self ._hybrid_policy = HybridPolicy (hybrid_policy )
313+ except ValueError :
314+ raise ValueError (
315+ f"hybrid_policy must be one of { ', ' .join ([p .value for p in HybridPolicy ])} "
316+ )
306317
307318 # Reset the query string
308319 self ._query_string = self ._build_query_string ()
@@ -333,7 +344,7 @@ def hybrid_policy(self) -> Optional[str]:
333344 Returns:
334345 Optional[str]: The hybrid policy for the query.
335346 """
336- return self ._hybrid_policy
347+ return self ._hybrid_policy . value if self . _hybrid_policy else None
337348
338349 @property
339350 def batch_size (self ) -> Optional [int ]:
@@ -433,7 +444,7 @@ def __init__(
433444 self ._num_results = num_results
434445 self ._distance_threshold : float = 0.2 # Initialize with default
435446 self ._epsilon : Optional [float ] = None
436- self ._hybrid_policy : Optional [str ] = None
447+ self ._hybrid_policy : Optional [HybridPolicy ] = None
437448 self ._batch_size : Optional [int ] = None
438449
439450 if epsilon is not None :
@@ -517,9 +528,12 @@ def set_hybrid_policy(self, hybrid_policy: str):
517528 Raises:
518529 ValueError: If hybrid_policy is not one of the valid options
519530 """
520- if hybrid_policy not in {"BATCHES" , "ADHOC_BF" }:
521- raise ValueError ("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}" )
522- self ._hybrid_policy = hybrid_policy
531+ try :
532+ self ._hybrid_policy = HybridPolicy (hybrid_policy )
533+ except ValueError :
534+ raise ValueError (
535+ f"hybrid_policy must be one of { ', ' .join ([p .value for p in HybridPolicy ])} "
536+ )
523537
524538 # Reset the query string
525539 self ._query_string = self ._build_query_string ()
@@ -592,7 +606,7 @@ def hybrid_policy(self) -> Optional[str]:
592606 Returns:
593607 Optional[str]: The hybrid policy for the query.
594608 """
595- return self ._hybrid_policy
609+ return self ._hybrid_policy . value if self . _hybrid_policy else None
596610
597611 @property
598612 def batch_size (self ) -> Optional [int ]:
@@ -622,9 +636,9 @@ def params(self) -> Dict[str, Any]:
622636
623637 # Add hybrid policy and batch size as query parameters (not in query string)
624638 if self ._hybrid_policy :
625- params [self .HYBRID_POLICY_PARAM ] = self ._hybrid_policy
639+ params [self .HYBRID_POLICY_PARAM ] = self ._hybrid_policy . value
626640
627- if self ._hybrid_policy == " BATCHES" and self ._batch_size :
641+ if self ._hybrid_policy == HybridPolicy . BATCHES and self ._batch_size :
628642 params [self .BATCH_SIZE_PARAM ] = self ._batch_size
629643
630644 return params
0 commit comments