Skip to content

Commit 0518dc4

Browse files
cleans up unit test
1 parent 6df30aa commit 0518dc4

File tree

1 file changed

+21 -34 lines changed

1 file changed

+21 -34 lines changed

tests/unit/test_aggregation_types.py

Lines changed: 21 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ def test_aggregate_hybrid_query():
9191
stopwords=["the", "a", "of"],
9292
)
9393
assert hybrid_query.stopwords == set(["the", "a", "of"])
94+
9495
hybrid_query = HybridQuery(
9596
sample_text,
9697
text_field_name,
@@ -195,7 +196,7 @@ def test_hybrid_query_with_string_filter():
195196
assert "AND" not in query_string_wildcard
196197

197198

198-
def test_aggregate_multi_vector_query():
199+
def test_multi_vector_query():
199200
# test we require vectors and field names
200201
with pytest.raises(TypeError):
201202
_ = MultiVectorQuery()
@@ -265,51 +266,37 @@ def test_aggregate_multi_vector_query():
265266
assert multivector_query._dialect == 4
266267

267268

268-
def test_aggregate_multi_vector_query_broadcasting():
269-
# if a single vector and multiple fields is passed we search with the same vector over all fields
270-
multivector_query = MultiVectorQuery(
271-
vectors=[sample_vector],
272-
vector_field_names=["text embedding", "image embedding"],
273-
)
274-
assert multi_vector_query.query == "<raw text here>"
275-
276-
# vector being broadcast doesn't need to be in a list
277-
multivector_query = MultiVectorQuery(
278-
vectors=sample_vector, vector_field_names=["text embedding", "image embedding"]
279-
)
280-
assert multi_vector_query.query == "<raw text here>"
281-
282-
# if multiple vectors are passed and a single field name we search with all vectors on that field
283-
multivector_query = MultiVectorQuery(
269+
def test_multi_vector_query_broadcasting():
270+
# if a single weight is passed it is applied to all similarity scores
271+
field_1 = "text embedding"
272+
field_2 = "image embedding"
273+
weight = 0.2
274+
multi_vector_query = MultiVectorQuery(
284275
vectors=[sample_vector_2, sample_vector_3],
285-
vector_field_names=["text embedding"],
276+
vector_field_names=[field_1, field_2],
277+
weights=[weight],
286278
)
287-
assert multi_vector_query.query == "<raw text here>"
288279

289-
# vector field name does not need to be in a list if only one is provided
290-
multivector_query = MultiVectorQuery(
291-
vectors=[sample_vector_2, sample_vector_3], vector_field_names="text embedding"
280+
assert (
281+
str(multi_vector_query)
282+
== f"@{field_1}:[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight} + @score_1 * {weight} AS combined_score SORTBY 2 @combined_score DESC MAX 10"
292283
)
293-
assert multi_vector_query.query == "<raw text here>"
294284

295-
# if a single weight is passed it is applied to all similarity scores
296-
multivector_query = MultiVectorQuery(
285+
# if a single dtype is passed it is applied to all vectors
286+
multi_vector_query = MultiVectorQuery(
297287
vectors=[sample_vector_2, sample_vector_3],
298288
vector_field_names=["text embedding", "image embedding"],
299-
weights=[0.2],
289+
dtypes=["float16"],
300290
)
301-
assert multi_vector_query.query == "<raw text here>"
302291

303-
# weight does not need to be in a list if only one is provided
304-
multivector_query = MultiVectorQuery(
305-
vectors=[sample_vector_2, sample_vector_3],
306-
vector_field_names=["text embedding", "image embedding"],
307-
weights=0.2,
292+
assert multi_vector_query._dtypes == ["float16", "float16"]
293+
assert (
294+
str(multi_vector_query)
295+
== f"@{field_1}:[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * 1.0 + @score_1 * 1.0 AS combined_score SORTBY 2 @combined_score DESC MAX 10"
308296
)
309-
assert multi_vector_query.query == "<raw text here>"
310297

311298

312-
def test_aggregate_multi_vector_query_errors():
299+
def test_multi_vector_query_errors():
313300
# test an error is raised if the number of vectors and number of fields don't match
314301
with pytest.raises(ValueError):
315302
_ = MultiVectorQuery(

0 commit comments

Comments (0)