Skip to content

Commit e2ac342

Browse files
committed
Build query strings dynamically
1 parent 2210d7f commit e2ac342

File tree

2 files changed

+82
-50
lines changed

redisvl/query/query.py

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@ class BaseQuery(RedisQuery):
1414
Base query class used to subclass many query types.
1515
1616
NOTE: In the base class, the `_query_string` field is set once on
17-
initialization, and afterward, the redis-py codebase expects to be able to
18-
read it. By contrast, our query subclasses allow users to call methods that
19-
alter the query string at runtime. To avoid having to rebuild
20-
`_query_string` every time one of these methods is called, we lazily build
21-
the query string when a user calls `query()` or accesses the property
22-
`_query_string`, when the underlying `__query_string` field is None. Any
23-
method that alters the query string should set `__query_string` to None so
24-
that the next time the query string is accessed, it is rebuilt.
17+
initialization, and afterward, redis-py expects to be able to read it. By
18+
contrast, our query subclasses allow users to call methods that alter the
19+
query string at runtime. To avoid having to rebuild `_query_string` every
20+
time one of these methods is called, we lazily build the query string when a
21+
user calls `query()` or accesses the property `_query_string`, when the
22+
underlying `_built_query_string` field is None. Any method that alters the query
23+
string should set `_built_query_string` to None so that the next time the query
24+
string is accessed, it is rebuilt.
2525
"""
2626

2727
_params: Dict[str, Any] = {}
2828
_filter_expression: Union[str, FilterExpression] = FilterExpression("*")
29-
__query_string: Optional[str] = None
29+
_built_query_string: Optional[str] = None
3030

3131
def __init__(self, query_string: str = "*"):
3232
"""
@@ -35,8 +35,15 @@ def __init__(self, query_string: str = "*"):
3535
Args:
3636
query_string (str, optional): The query string to use. Defaults to '*'.
3737
"""
38+
# The parent class expects a query string, so we pass it in, but we'll
39+
# actually manage building it dynamically.
3840
super().__init__(query_string)
3941

42+
# This private field caches the built query string. We reset it to None
43+
# here to mark the string as not yet built, so that the next access of
44+
# `_query_string` rebuilds it.
45+
self._built_query_string = None
46+
4047
def __str__(self) -> str:
4148
"""Return the string representation of the query."""
4249
return " ".join([str(x) for x in self.get_args()])
@@ -68,7 +75,7 @@ def set_filter(
6875
)
6976

7077
# Invalidate the query string
71-
self.__query_string = None
78+
self._built_query_string = None
7279

7380
@property
7481
def filter(self) -> Union[str, FilterExpression]:
@@ -88,14 +95,14 @@ def params(self) -> Dict[str, Any]:
8895
@property
8996
def _query_string(self) -> str:
9097
"""Maintains compatibility with parent class while providing lazy loading."""
91-
if self.__query_string is None:
92-
self.__query_string = self._build_query_string()
93-
return self.__query_string
98+
if self._built_query_string is None:
99+
self._built_query_string = self._build_query_string()
100+
return self._built_query_string
94101

95102
@_query_string.setter
96103
def _query_string(self, value: Optional[str]):
97104
"""Setter for _query_string to maintain compatibility with parent class."""
98-
self.__query_string = value
105+
self._built_query_string = value
99106

100107

101108
class FilterQuery(BaseQuery):
@@ -134,7 +141,7 @@ def __init__(
134141

135142
# Initialize the base query with the query string from the property
136143
super().__init__("*")
137-
self.__query_string = None # Ensure it's invalidated after initialization
144+
self._built_query_string = None # Ensure it's invalidated after initialization
138145

139146
# Handle query settings
140147
if return_fields:
@@ -188,7 +195,7 @@ def __init__(
188195

189196
# Initialize the base query with the query string from the property
190197
super().__init__("*")
191-
self.__query_string = None
198+
self._built_query_string = None
192199

193200
# Query specific modifications
194201
self.no_content().paging(0, 0).dialect(dialect)
@@ -203,7 +210,8 @@ def _build_query_string(self) -> str:
203210
class BaseVectorQuery:
204211
DISTANCE_ID: str = "vector_distance"
205212
VECTOR_PARAM: str = "vector"
206-
EF_RUNTIME_PARAM: str = "EF_RUNTIME"
213+
EF_RUNTIME: str = "EF_RUNTIME"
214+
EF_RUNTIME_PARAM: str = "EF"
207215

208216
_normalize_vector_distance: bool = False
209217

@@ -296,7 +304,7 @@ def __init__(
296304

297305
# Initialize the base query
298306
super().__init__("*")
299-
self.__query_string = None
307+
self._built_query_string = None
300308

301309
# Handle query modifiers
302310
if return_fields:
@@ -345,7 +353,7 @@ def _build_query_string(self) -> str:
345353

346354
# Add EF_RUNTIME parameter if specified
347355
if self._ef_runtime:
348-
knn_query += f" {self.EF_RUNTIME_PARAM} {self._ef_runtime}"
356+
knn_query += f" {self.EF_RUNTIME} ${self.EF_RUNTIME_PARAM}"
349357

350358
# Add distance field alias
351359
knn_query += f" AS {self.DISTANCE_ID}"
@@ -370,7 +378,7 @@ def set_hybrid_policy(self, hybrid_policy: str):
370378
)
371379

372380
# Invalidate the query string
373-
self.__query_string = None
381+
self._built_query_string = None
374382

375383
def set_batch_size(self, batch_size: int):
376384
"""Set the batch size for the query.
@@ -389,7 +397,7 @@ def set_batch_size(self, batch_size: int):
389397
self._batch_size = batch_size
390398

391399
# Invalidate the query string
392-
self.__query_string = None
400+
self._built_query_string = None
393401

394402
def set_ef_runtime(self, ef_runtime: int):
395403
"""Set the EF_RUNTIME parameter for the query.
@@ -409,7 +417,7 @@ def set_ef_runtime(self, ef_runtime: int):
409417
self._ef_runtime = ef_runtime
410418

411419
# Invalidate the query string
412-
self.__query_string = None
420+
self._built_query_string = None
413421

414422
@property
415423
def hybrid_policy(self) -> Optional[str]:
@@ -556,7 +564,7 @@ def __init__(
556564

557565
# Initialize the base query
558566
super().__init__("*")
559-
self.__query_string = None
567+
self._built_query_string = None
560568

561569
if epsilon is not None:
562570
self.set_epsilon(epsilon)
@@ -615,7 +623,7 @@ def set_distance_threshold(self, distance_threshold: float):
615623
self._distance_threshold = distance_threshold
616624

617625
# Invalidate the query string
618-
self.__query_string = None
626+
self._built_query_string = None
619627

620628
def set_epsilon(self, epsilon: float):
621629
"""Set the epsilon parameter for the range query.
@@ -635,7 +643,7 @@ def set_epsilon(self, epsilon: float):
635643
self._epsilon = epsilon
636644

637645
# Invalidate the query string
638-
self.__query_string = None
646+
self._built_query_string = None
639647

640648
def set_hybrid_policy(self, hybrid_policy: str):
641649
"""Set the hybrid policy for the query.
@@ -655,7 +663,7 @@ def set_hybrid_policy(self, hybrid_policy: str):
655663
)
656664

657665
# Invalidate the query string
658-
self.__query_string = None
666+
self._built_query_string = None
659667

660668
def set_batch_size(self, batch_size: int):
661669
"""Set the batch size for the query.
@@ -674,7 +682,7 @@ def set_batch_size(self, batch_size: int):
674682
self._batch_size = batch_size
675683

676684
# Invalidate the query string
677-
self.__query_string = None
685+
self._built_query_string = None
678686

679687
def set_ef_runtime(self, ef_runtime: int):
680688
"""Set the EF_RUNTIME parameter for the query.
@@ -694,7 +702,7 @@ def set_ef_runtime(self, ef_runtime: int):
694702
self._ef_runtime = ef_runtime
695703

696704
# Invalidate the query string
697-
self.__query_string = None
705+
self._built_query_string = None
698706

699707
def _build_query_string(self) -> str:
700708
"""Build the full query string for vector range queries with optional filtering"""
@@ -708,6 +716,9 @@ def _build_query_string(self) -> str:
708716
if self._epsilon is not None:
709717
attr_parts.append(f"$EPSILON: {self._epsilon}")
710718

719+
if self._ef_runtime is not None:
720+
attr_parts.append(f"$EF_RUNTIME: {self._ef_runtime}")
721+
711722
# Add query attributes section
712723
attr_section = f"=>{{{'; '.join(attr_parts)}}}"
713724

@@ -782,10 +793,6 @@ def params(self) -> Dict[str, Any]:
782793
self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
783794
}
784795

785-
# Add EPSILON parameter if specified
786-
if self._epsilon is not None:
787-
params[self.EPSILON_PARAM] = self._epsilon
788-
789796
# Add hybrid policy and batch size as query parameters (not in query string)
790797
if self._hybrid_policy is not None:
791798
params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy.value
@@ -795,10 +802,6 @@ def params(self) -> Dict[str, Any]:
795802
):
796803
params[self.BATCH_SIZE_PARAM] = self._batch_size
797804

798-
# Add EF_RUNTIME parameter if specified
799-
if self._ef_runtime is not None:
800-
params[self.EF_RUNTIME_PARAM] = self._ef_runtime
801-
802805
return params
803806

804807

@@ -894,7 +897,7 @@ def __init__(
894897

895898
# Initialize the base query
896899
super().__init__("*")
897-
self.__query_string = None
900+
self._built_query_string = None
898901

899902
# handle query settings
900903
self.scorer(text_scorer)

tests/unit/test_query_types.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def test_vector_range_query_construction():
535535
assert "$YIELD_DISTANCE_AS: vector_distance" in query_string
536536
assert "$EPSILON: 0.05" in query_string
537537
assert epsilon_query.epsilon == 0.05
538-
assert epsilon_query.params.get("EPSILON") == 0.05
538+
assert "EPSILON" not in epsilon_query.params
539539

540540
# Range query with hybrid policy
541541
hybrid_query = VectorRangeQuery(
@@ -571,6 +571,20 @@ def test_vector_range_query_construction():
571571
assert batch_query.params["HYBRID_POLICY"] == "BATCHES"
572572
assert batch_query.params["BATCH_SIZE"] == 50
573573

574+
# Range query with ef_runtime
575+
ef_runtime_query = VectorRangeQuery(
576+
vector=[0.1, 0.1, 0.5],
577+
vector_field_name="user_embedding",
578+
return_fields=["user", "credit_score"],
579+
distance_threshold=0.2,
580+
ef_runtime=100,
581+
)
582+
583+
# EF_RUNTIME should be in the query string, not params
584+
query_string = str(ef_runtime_query)
585+
assert "$EF_RUNTIME: 100" in query_string
586+
assert "EF_RUNTIME" not in ef_runtime_query.params
587+
574588

575589
def test_vector_range_query_setter_methods():
576590
"""Unit test: Test setter methods for VectorRangeQuery parameters."""
@@ -594,6 +608,11 @@ def test_vector_range_query_setter_methods():
594608
assert query.epsilon == 0.1
595609
assert "$EPSILON: 0.1" in str(query)
596610

611+
# Set ef_runtime
612+
query.set_ef_runtime(100)
613+
assert query.ef_runtime == 100
614+
assert "$EF_RUNTIME: 100" in str(query)
615+
597616
# Set hybrid policy
598617
query.set_hybrid_policy("BATCHES")
599618
assert query.hybrid_policy == "BATCHES"
@@ -635,6 +654,13 @@ def test_vector_range_query_error_handling():
635654
with pytest.raises(ValueError, match="batch_size must be positive"):
636655
query.set_batch_size(-10)
637656

657+
# Test invalid ef_runtime
658+
with pytest.raises(TypeError, match="ef_runtime must be an integer"):
659+
query.set_ef_runtime("hey") # type: ignore
660+
661+
with pytest.raises(ValueError, match="ef_runtime must be positive"):
662+
query.set_ef_runtime(-10)
663+
638664

639665
def test_vector_query_ef_runtime():
640666
"""Test that VectorQuery correctly handles EF_RUNTIME parameter."""
@@ -646,10 +672,10 @@ def test_vector_query_ef_runtime():
646672

647673
# Check query string
648674
query_string = str(vector_query)
649-
assert "EF_RUNTIME 100" in query_string
675+
assert f"{VectorQuery.EF_RUNTIME} ${VectorQuery.EF_RUNTIME_PARAM}" in query_string
650676

651677
# Check params dictionary
652-
assert vector_query.params.get("EF_RUNTIME") == 100
678+
assert vector_query.params.get(VectorQuery.EF_RUNTIME_PARAM) == 100
653679

654680
# Test with different value
655681
vector_query = VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field", ef_runtime=50)
@@ -659,7 +685,7 @@ def test_vector_query_ef_runtime():
659685

660686
# Check query string
661687
query_string = str(vector_query)
662-
assert "EF_RUNTIME 50" in query_string
688+
assert "EF_RUNTIME $EF" in query_string
663689

664690

665691
def test_vector_query_set_ef_runtime():
@@ -669,7 +695,9 @@ def test_vector_query_set_ef_runtime():
669695

670696
# Initially no ef_runtime
671697
assert vector_query.ef_runtime is None
672-
assert "EF_RUNTIME" not in str(vector_query)
698+
assert f"{VectorQuery.EF_RUNTIME} ${VectorQuery.EF_RUNTIME_PARAM}" not in str(
699+
vector_query
700+
)
673701

674702
# Set ef_runtime
675703
vector_query.set_ef_runtime(200)
@@ -679,17 +707,17 @@ def test_vector_query_set_ef_runtime():
679707

680708
# Check query string
681709
query_string = str(vector_query)
682-
assert "EF_RUNTIME 200" in query_string
710+
assert f"{VectorQuery.EF_RUNTIME} ${VectorQuery.EF_RUNTIME_PARAM}" in query_string
683711

684712
# Check params dictionary
685-
assert vector_query.params.get("EF_RUNTIME") == 200
713+
assert vector_query.params.get(VectorQuery.EF_RUNTIME_PARAM) == 200
686714

687715

688716
def test_vector_query_invalid_ef_runtime():
689717
"""Test error handling for invalid EF_RUNTIME values."""
690718
# Test with invalid ef_runtime type
691719
with pytest.raises(TypeError, match="ef_runtime must be an integer"):
692-
VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field", ef_runtime="100")
720+
VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field", ef_runtime="hey") # type: ignore
693721

694722
# Test with invalid ef_runtime value
695723
with pytest.raises(ValueError, match="ef_runtime must be positive"):
@@ -703,7 +731,7 @@ def test_vector_query_invalid_ef_runtime():
703731

704732
# Test with invalid ef_runtime type
705733
with pytest.raises(TypeError, match="ef_runtime must be an integer"):
706-
vector_query.set_ef_runtime("100")
734+
vector_query.set_ef_runtime("hey") # type: ignore
707735

708736
# Test with invalid ef_runtime value
709737
with pytest.raises(ValueError, match="ef_runtime must be positive"):
@@ -724,11 +752,12 @@ def test_vector_range_query_ef_runtime():
724752
assert range_query.ef_runtime == 100
725753
assert range_query.distance_threshold == 0.3
726754

727-
# EF_RUNTIME should be in params but not in query string (like hybrid_policy)
728-
assert "EF_RUNTIME" not in str(range_query)
729-
assert range_query.params.get("EF_RUNTIME") == 100
755+
# EF_RUNTIME should not be in params and should be in query string
756+
assert f"${VectorRangeQuery.EF_RUNTIME}: 100" in str(range_query)
757+
assert VectorRangeQuery.EF_RUNTIME_PARAM not in range_query.params
730758

731759
# Test setting ef_runtime
732760
range_query.set_ef_runtime(150)
733761
assert range_query.ef_runtime == 150
734-
assert range_query.params.get("EF_RUNTIME") == 150
762+
assert f"${VectorRangeQuery.EF_RUNTIME}: 150" in str(range_query)
763+
assert VectorRangeQuery.EF_RUNTIME_PARAM not in range_query.params

0 commit comments

Comments
 (0)