Skip to content

Commit 09b898e

Browse files
committed
Use parameter substitution for vector data in hybrid query
1 parent b5ba960 commit 09b898e

File tree

3 files changed

+55
-19
lines changed

3 files changed

+55
-19
lines changed

redisvl/index/index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,7 @@ def _hybrid_search(self, query: HybridQuery, **kwargs) -> List[Dict[str, Any]]:
10641064
if query.postprocessing_config.build_args()
10651065
else None
10661066
),
1067+
params_substitution=query.params, # type: ignore[arg-type]
10671068
**kwargs,
10681069
) # type: ignore
10691070
return [convert_bytes(r) for r in results.results] # type: ignore[union-attr]
@@ -1938,6 +1939,7 @@ async def _hybrid_search(
19381939
if query.postprocessing_config.build_args()
19391940
else None
19401941
),
1942+
params_substitution=query.params, # type: ignore[arg-type]
19411943
**kwargs,
19421944
) # type: ignore
19431945
return [convert_bytes(r) for r in results.results] # type: ignore[union-attr]

redisvl/query/hybrid.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
text_field_name: str,
5050
vector: Union[bytes, List[float]],
5151
vector_field_name: str,
52+
vector_param_name: str = "vector",
5253
text_scorer: str = "BM25STD",
5354
yield_text_score_as: Optional[str] = None,
5455
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
@@ -76,6 +77,7 @@ def __init__(
7677
text_field_name: The text field name to search in.
7778
vector: The vector to perform vector similarity search.
7879
vector_field_name: The vector field name to search in.
80+
vector_param_name: The name of the parameter substitution containing the vector blob.
7981
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
8082
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
8183
information about supported scoring algorithms,
@@ -123,6 +125,7 @@ def __init__(
123125
ValueError: If `vector_search_method` is "KNN" and `knn_k` is not provided.
124126
ValueError: If `vector_search_method` is "RANGE" and `range_radius` is not provided.
125127
"""
128+
126129
try:
127130
from redis.commands.search.hybrid_query import (
128131
CombineResultsMethod,
@@ -146,9 +149,18 @@ def __init__(
146149
text, text_field_name, filter_expression
147150
)
148151

152+
if isinstance(vector, bytes):
153+
vector_data = vector
154+
else:
155+
vector_data = array_to_buffer(vector, dtype)
156+
157+
self.params = {
158+
vector_param_name: vector_data,
159+
}
160+
149161
self.query = build_base_query(
150162
text_query=query_string,
151-
vector=vector,
163+
vector_param_name=vector_param_name,
152164
vector_field_name=vector_field_name,
153165
text_scorer=text_scorer,
154166
yield_text_score_as=yield_text_score_as,
@@ -159,7 +171,6 @@ def __init__(
159171
range_epsilon=range_epsilon,
160172
yield_vsim_score_as=yield_vsim_score_as,
161173
filter_expression=filter_expression,
162-
dtype=dtype,
163174
)
164175

165176
if combination_method:
@@ -178,7 +189,7 @@ def __init__(
178189

179190
def build_base_query(
180191
text_query: str,
181-
vector: Union[bytes, List[float]],
192+
vector_param_name: str,
182193
vector_field_name: str,
183194
text_scorer: str = "BM25STD",
184195
yield_text_score_as: Optional[str] = None,
@@ -189,13 +200,12 @@ def build_base_query(
189200
range_epsilon: Optional[float] = None,
190201
yield_vsim_score_as: Optional[str] = None,
191202
filter_expression: Optional[Union[str, FilterExpression]] = None,
192-
dtype: str = "float32",
193203
):
194204
"""Build a Redis HybridQuery for performing hybrid search.
195205
196206
Args:
197207
text_query: The query for the text search.
198-
vector: The vector to perform vector similarity search.
208+
vector_param_name: The name of the parameter substitution containing the vector blob.
199209
vector_field_name: The vector field name to search in.
200210
text_scorer: The text scorer to use. Options are {TFIDF, TFIDF.DOCNORM,
201211
BM25STD, BM25STD.NORM, BM25STD.TANH, DISMAX, DOCSCORE, HAMMING}. Defaults to "BM25STD". For more
@@ -210,7 +220,6 @@ def build_base_query(
210220
accuracy of the search.
211221
yield_vsim_score_as: The name of the field to yield the vector similarity score as.
212222
filter_expression: The filter expression to use for the vector similarity search. Defaults to None.
213-
dtype: The data type of the vector. Defaults to "float32".
214223
215224
Notes:
216225
If RRF combination method is used, then at least one of `rrf_window` or `rrf_constant` must be provided.
@@ -242,11 +251,6 @@ def build_base_query(
242251
yield_score_as=yield_text_score_as,
243252
)
244253

245-
if isinstance(vector, bytes):
246-
vector_data = vector
247-
else:
248-
vector_data = array_to_buffer(vector, dtype)
249-
250254
# Serialize vector similarity search method and params, if specified
251255
vsim_search_method: Optional[VectorSearchMethods] = None
252256
vsim_search_method_params: Dict[str, Any] = {}
@@ -284,7 +288,7 @@ def build_base_query(
284288
# Serialize the vector similarity query
285289
vsim_query = HybridVsimQuery(
286290
vector_field_name="@" + vector_field_name,
287-
vector_data=vector_data,
291+
vector_data="$" + vector_param_name,
288292
vsim_search_method=vsim_search_method,
289293
vsim_search_method_params=vsim_search_method_params,
290294
filter=vsim_filter,

tests/unit/test_hybrid_types.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def get_query_pieces(query: HybridQuery) -> List[str]:
4949
pieces.extend(query.combination_method.get_args())
5050
if query.postprocessing_config.build_args():
5151
pieces.extend(query.postprocessing_config.build_args())
52+
if query.params:
53+
params = [
54+
"PARAMS",
55+
len(query.params) * 2,
56+
] + [item for pair in query.params.items() for item in pair]
57+
pieces.extend(params)
5258
return pieces
5359

5460

@@ -76,10 +82,14 @@ def test_hybrid_query_basic_initialization():
7682
"BM25STD",
7783
"VSIM",
7884
"@embedding",
79-
bytes_vector,
85+
"$vector",
8086
"LIMIT",
8187
"0",
8288
"10",
89+
"PARAMS",
90+
2,
91+
"vector",
92+
bytes_vector,
8393
]
8494

8595
# Verify that no combination method is set
@@ -126,7 +136,7 @@ def test_hybrid_query_with_all_parameters():
126136
"text_score",
127137
"VSIM",
128138
"@embedding",
129-
bytes_vector,
139+
"$vector",
130140
"KNN",
131141
4,
132142
"K",
@@ -149,11 +159,15 @@ def test_hybrid_query_with_all_parameters():
149159
"LIMIT",
150160
"0",
151161
"10",
162+
"PARAMS",
163+
2,
164+
"vector",
165+
bytes_vector,
152166
]
153167

154168
# Add post-processing and verify that it is reflected in the query
155169
hybrid_query.postprocessing_config.limit(offset=10, num=20)
156-
assert get_query_pieces(hybrid_query)[-3:] == ["LIMIT", "10", "20"]
170+
assert hybrid_query.postprocessing_config.build_args() == ["LIMIT", "10", "20"]
157171

158172

159173
# Stopwords tests
@@ -376,12 +390,16 @@ def test_hybrid_query_with_string_filter():
376390
"BM25STD",
377391
"VSIM",
378392
"@embedding",
379-
bytes_vector,
393+
"$vector",
380394
"FILTER",
381395
"@category:{tech|science|engineering}",
382396
"LIMIT",
383397
"0",
384398
"10",
399+
"PARAMS",
400+
2,
401+
"vector",
402+
bytes_vector,
385403
]
386404

387405

@@ -405,12 +423,16 @@ def test_hybrid_query_with_tag_filter():
405423
"BM25STD",
406424
"VSIM",
407425
"@embedding",
408-
bytes_vector,
426+
"$vector",
409427
"FILTER",
410428
"@genre:{comedy}",
411429
"LIMIT",
412430
"0",
413431
"10",
432+
"PARAMS",
433+
2,
434+
"vector",
435+
bytes_vector,
414436
]
415437

416438

@@ -645,10 +667,14 @@ def test_hybrid_query_special_characters_in_text():
645667
"BM25STD",
646668
"VSIM",
647669
"@embedding",
648-
bytes_vector,
670+
"$vector",
649671
"LIMIT",
650672
"0",
651673
"10",
674+
"PARAMS",
675+
2,
676+
"vector",
677+
bytes_vector,
652678
]
653679

654680

@@ -672,10 +698,14 @@ def test_hybrid_query_unicode_text():
672698
"BM25STD",
673699
"VSIM",
674700
"@embedding",
675-
bytes_vector,
701+
"$vector",
676702
"LIMIT",
677703
"0",
678704
"10",
705+
"PARAMS",
706+
2,
707+
"vector",
708+
bytes_vector,
679709
]
680710

681711

0 commit comments

Comments
 (0)