@@ -91,6 +91,7 @@ def test_aggregate_hybrid_query():
9191 stopwords = ["the" , "a" , "of" ],
9292 )
9393 assert hybrid_query .stopwords == set (["the" , "a" , "of" ])
94+
9495 hybrid_query = HybridQuery (
9596 sample_text ,
9697 text_field_name ,
@@ -195,7 +196,7 @@ def test_hybrid_query_with_string_filter():
195196 assert "AND" not in query_string_wildcard
196197
197198
198- def test_aggregate_multi_vector_query ():
199+ def test_multi_vector_query ():
199200 # test we require vectors and field names
200201 with pytest .raises (TypeError ):
201202 _ = MultiVectorQuery ()
@@ -265,51 +266,37 @@ def test_aggregate_multi_vector_query():
265266 assert multivector_query ._dialect == 4
266267
267268
268- def test_aggregate_multi_vector_query_broadcasting ():
269- # if a single vector and multiple fields is passed we search with the same vector over all fields
270- multivector_query = MultiVectorQuery (
271- vectors = [sample_vector ],
272- vector_field_names = ["text embedding" , "image embedding" ],
273- )
274- assert multi_vector_query .query == "<raw text here>"
275-
276- # vector being broadcast doesn't need to be in a list
277- multivector_query = MultiVectorQuery (
278- vectors = sample_vector , vector_field_names = ["text embedding" , "image embedding" ]
279- )
280- assert multi_vector_query .query == "<raw text here>"
281-
282- # if multiple vectors are passed and a single field name we search with all vectors on that field
283- multivector_query = MultiVectorQuery (
269+ def test_multi_vector_query_broadcasting ():
270+ # if a single weight is passed it is applied to all similarity scores
271+ field_1 = "text embedding"
272+ field_2 = "image embedding"
273+ weight = 0.2
274+ multi_vector_query = MultiVectorQuery (
284275 vectors = [sample_vector_2 , sample_vector_3 ],
285- vector_field_names = ["text embedding" ],
276+ vector_field_names = [field_1 , field_2 ],
277+ weights = [weight ],
286278 )
287- assert multi_vector_query .query == "<raw text here>"
288279
289- # vector field name does not need to be in a list if only one is provided
290- multivector_query = MultiVectorQuery (
291- vectors = [ sample_vector_2 , sample_vector_3 ], vector_field_names = "text embedding "
280+ assert (
281+ str ( multi_vector_query )
282+ == f"@ { field_1 } :[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @ { field_2 } :[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * { weight } + @score_1 * { weight } AS combined_score SORTBY 2 @combined_score DESC MAX 10 "
292283 )
293- assert multi_vector_query .query == "<raw text here>"
294284
295- # if a single weight is passed it is applied to all similarity scores
296- multivector_query = MultiVectorQuery (
285+ # if a single dtype is passed it is applied to all vectors
286+ multi_vector_query = MultiVectorQuery (
297287 vectors = [sample_vector_2 , sample_vector_3 ],
298288 vector_field_names = ["text embedding" , "image embedding" ],
299- weights = [ 0.2 ],
289+ dtypes = [ "float16" ],
300290 )
301- assert multi_vector_query .query == "<raw text here>"
302291
303- # weight does not need to be in a list if only one is provided
304- multivector_query = MultiVectorQuery (
305- vectors = [sample_vector_2 , sample_vector_3 ],
306- vector_field_names = ["text embedding" , "image embedding" ],
307- weights = 0.2 ,
292+ assert multi_vector_query ._dtypes == ["float16" , "float16" ]
293+ assert (
294+ str (multi_vector_query )
295+ == f"@{ field_1 } :[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{ field_2 } :[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * 1.0 + @score_1 * 1.0 AS combined_score SORTBY 2 @combined_score DESC MAX 10"
308296 )
309- assert multi_vector_query .query == "<raw text here>"
310297
311298
312- def test_aggregate_multi_vector_query_errors ():
299+ def test_multi_vector_query_errors ():
313300 # test an error is raised if the number of vectors and number of fields don't match
314301 with pytest .raises (ValueError ):
315302 _ = MultiVectorQuery (
0 commit comments