Skip to content

Commit 2210d7f

Browse files
committed
Lazily construct the query string
1 parent b4b875e commit 2210d7f

File tree

1 file changed

+65
-36
lines changed

1 file changed

+65
-36
lines changed

redisvl/query/query.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,23 @@
1010

1111

1212
class BaseQuery(RedisQuery):
13-
"""Base query class used to subclass many query types."""
13+
"""
14+
Base query class used to subclass many query types.
15+
16+
NOTE: In the base class, the `_query_string` field is set once on
17+
initialization, and afterward, the redis-py codebase expects to be able to
18+
read it. By contrast, our query subclasses allow users to call methods that
19+
alter the query string at runtime. To avoid having to rebuild
20+
`_query_string` every time one of these methods is called, we lazily build
21+
the query string when a user calls `query()` or accesses the property
22+
`_query_string`, when the underlying `__query_string` field is None. Any
23+
method that alters the query string should set `__query_string` to None so
24+
that the next time the query string is accessed, it is rebuilt.
25+
"""
1426

1527
_params: Dict[str, Any] = {}
1628
_filter_expression: Union[str, FilterExpression] = FilterExpression("*")
29+
__query_string: Optional[str] = None
1730

1831
def __init__(self, query_string: str = "*"):
1932
"""
@@ -54,8 +67,8 @@ def set_filter(
5467
"filter_expression must be of type FilterExpression or string or None"
5568
)
5669

57-
# Reset the query string
58-
self._query_string = self._build_query_string()
70+
# Invalidate the query string
71+
self.__query_string = None
5972

6073
@property
6174
def filter(self) -> Union[str, FilterExpression]:
@@ -72,6 +85,18 @@ def params(self) -> Dict[str, Any]:
7285
"""Return the query parameters."""
7386
return self._params
7487

88+
@property
89+
def _query_string(self) -> str:
90+
"""Maintains compatibility with parent class while providing lazy loading."""
91+
if self.__query_string is None:
92+
self.__query_string = self._build_query_string()
93+
return self.__query_string
94+
95+
@_query_string.setter
96+
def _query_string(self, value: Optional[str]):
97+
"""Setter for _query_string to maintain compatibility with parent class."""
98+
self.__query_string = value
99+
75100

76101
class FilterQuery(BaseQuery):
77102
def __init__(
@@ -107,9 +132,9 @@ def __init__(
107132

108133
self._num_results = num_results
109134

110-
# Initialize the base query with the full query string constructed from the filter expression
111-
query_string = self._build_query_string()
112-
super().__init__(query_string)
135+
# Initialize the base query with the query string from the property
136+
super().__init__("*")
137+
self.__query_string = None # Ensure it's invalidated after initialization
113138

114139
# Handle query settings
115140
if return_fields:
@@ -161,9 +186,9 @@ def __init__(
161186
if params:
162187
self._params = params
163188

164-
# Initialize the base query with the full query string constructed from the filter expression
165-
query_string = self._build_query_string()
166-
super().__init__(query_string)
189+
# Initialize the base query with the query string from the property
190+
super().__init__("*")
191+
self.__query_string = None
167192

168193
# Query specific modifications
169194
self.no_content().paging(0, 0).dialect(dialect)
@@ -268,9 +293,10 @@ def __init__(
268293
self._ef_runtime: Optional[int] = None
269294
self._normalize_vector_distance = normalize_vector_distance
270295
self.set_filter(filter_expression)
271-
query_string = self._build_query_string()
272296

273-
super().__init__(query_string)
297+
# Initialize the base query
298+
super().__init__("*")
299+
self.__query_string = None
274300

275301
# Handle query modifiers
276302
if return_fields:
@@ -343,8 +369,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
343369
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
344370
)
345371

346-
# Reset the query string
347-
self._query_string = self._build_query_string()
372+
# Invalidate the query string
373+
self.__query_string = None
348374

349375
def set_batch_size(self, batch_size: int):
350376
"""Set the batch size for the query.
@@ -362,8 +388,8 @@ def set_batch_size(self, batch_size: int):
362388
raise ValueError("batch_size must be positive")
363389
self._batch_size = batch_size
364390

365-
# Reset the query string
366-
self._query_string = self._build_query_string()
391+
# Invalidate the query string
392+
self.__query_string = None
367393

368394
def set_ef_runtime(self, ef_runtime: int):
369395
"""Set the EF_RUNTIME parameter for the query.
@@ -382,8 +408,8 @@ def set_ef_runtime(self, ef_runtime: int):
382408
raise ValueError("ef_runtime must be positive")
383409
self._ef_runtime = ef_runtime
384410

385-
# Reset the query string
386-
self._query_string = self._build_query_string()
411+
# Invalidate the query string
412+
self.__query_string = None
387413

388414
@property
389415
def hybrid_policy(self) -> Optional[str]:
@@ -526,6 +552,11 @@ def __init__(
526552
self._hybrid_policy: Optional[HybridPolicy] = None
527553
self._batch_size: Optional[int] = None
528554
self._ef_runtime: Optional[int] = None
555+
self._normalize_vector_distance = normalize_vector_distance
556+
557+
# Initialize the base query
558+
super().__init__("*")
559+
self.__query_string = None
529560

530561
if epsilon is not None:
531562
self.set_epsilon(epsilon)
@@ -539,12 +570,8 @@ def __init__(
539570
if ef_runtime is not None:
540571
self.set_ef_runtime(ef_runtime)
541572

542-
self._normalize_vector_distance = normalize_vector_distance
543573
self.set_distance_threshold(distance_threshold)
544574
self.set_filter(filter_expression)
545-
query_string = self._build_query_string()
546-
547-
super().__init__(query_string)
548575

549576
# Handle query modifiers
550577
if return_fields:
@@ -587,8 +614,8 @@ def set_distance_threshold(self, distance_threshold: float):
587614
distance_threshold = denorm_cosine_distance(distance_threshold)
588615
self._distance_threshold = distance_threshold
589616

590-
# Reset the query string
591-
self._query_string = self._build_query_string()
617+
# Invalidate the query string
618+
self.__query_string = None
592619

593620
def set_epsilon(self, epsilon: float):
594621
"""Set the epsilon parameter for the range query.
@@ -607,8 +634,8 @@ def set_epsilon(self, epsilon: float):
607634
raise ValueError("epsilon must be non-negative")
608635
self._epsilon = epsilon
609636

610-
# Reset the query string
611-
self._query_string = self._build_query_string()
637+
# Invalidate the query string
638+
self.__query_string = None
612639

613640
def set_hybrid_policy(self, hybrid_policy: str):
614641
"""Set the hybrid policy for the query.
@@ -627,8 +654,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
627654
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
628655
)
629656

630-
# Reset the query string
631-
self._query_string = self._build_query_string()
657+
# Invalidate the query string
658+
self.__query_string = None
632659

633660
def set_batch_size(self, batch_size: int):
634661
"""Set the batch size for the query.
@@ -646,8 +673,8 @@ def set_batch_size(self, batch_size: int):
646673
raise ValueError("batch_size must be positive")
647674
self._batch_size = batch_size
648675

649-
# Reset the query string
650-
self._query_string = self._build_query_string()
676+
# Invalidate the query string
677+
self.__query_string = None
651678

652679
def set_ef_runtime(self, ef_runtime: int):
653680
"""Set the EF_RUNTIME parameter for the query.
@@ -666,8 +693,8 @@ def set_ef_runtime(self, ef_runtime: int):
666693
raise ValueError("ef_runtime must be positive")
667694
self._ef_runtime = ef_runtime
668695

669-
# Reset the query string
670-
self._query_string = self._build_query_string()
696+
# Invalidate the query string
697+
self.__query_string = None
671698

672699
def _build_query_string(self) -> str:
673700
"""Build the full query string for vector range queries with optional filtering"""
@@ -856,7 +883,7 @@ def __init__(
856883
TypeError: If stopwords is not a valid iterable set of strings.
857884
"""
858885
self._text = text
859-
self._text_field = text_field_name
886+
self._text_field_name = text_field_name
860887
self._num_results = num_results
861888

862889
self._set_stopwords(stopwords)
@@ -865,9 +892,9 @@ def __init__(
865892
if params:
866893
self._params = params
867894

868-
# initialize the base query with the full query string and filter expression
869-
query_string = self._build_query_string()
870-
super().__init__(query_string)
895+
# Initialize the base query
896+
super().__init__("*")
897+
self.__query_string = None
871898

872899
# handle query settings
873900
self.scorer(text_scorer)
@@ -953,7 +980,9 @@ def _build_query_string(self) -> str:
953980
else:
954981
filter_expression = ""
955982

956-
text = f"@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"
983+
text = (
984+
f"@{self._text_field_name}:({self._tokenize_and_escape_query(self._text)})"
985+
)
957986
if filter_expression and filter_expression != "*":
958987
text += f" AND {filter_expression}"
959988
return text

0 commit comments

Comments
 (0)