Skip to content

Commit 6493304

Browse files
tests hnsw multi vector indices only on supported search module versions
1 parent 7c85122 commit 6493304

File tree

2 files changed

+61
-86
lines changed

2 files changed

+61
-86
lines changed

tests/conftest.py

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -232,89 +232,6 @@ def sample_datetimes():
232232
}
233233

234234

235-
@pytest.fixture
236-
def OG(sample_datetimes):
237-
return [
238-
{
239-
"user": "john",
240-
"age": 18,
241-
"job": "engineer",
242-
"description": "engineers conduct trains that ride on train tracks",
243-
"last_updated": sample_datetimes["low"].timestamp(),
244-
"credit_score": "high",
245-
"location": "-122.4194,37.7749",
246-
"user_embedding": [0.1, 0.1, 0.5],
247-
"image_embedding": [0.1, 0.1, 0.1, 0.1, 0.1],
248-
},
249-
{
250-
"user": "mary",
251-
"age": 14,
252-
"job": "doctor",
253-
"description": "a medical professional who treats diseases and helps people stay healthy",
254-
"last_updated": sample_datetimes["low"].timestamp(),
255-
"credit_score": "low",
256-
"location": "-122.4194,37.7749",
257-
"user_embedding": [0.1, 0.1, 0.5],
258-
"image_embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
259-
},
260-
{
261-
"user": "nancy",
262-
"age": 94,
263-
"job": "doctor",
264-
"description": "a research scientist specializing in cancers and diseases of the lungs",
265-
"last_updated": sample_datetimes["mid"].timestamp(),
266-
"credit_score": "high",
267-
"location": "-122.4194,37.7749",
268-
"user_embedding": [0.7, 0.1, 0.5],
269-
"image_embedding": [0.1, 0.1, 0.3, 0.3, 0.5],
270-
},
271-
{
272-
"user": "tyler",
273-
"age": 100,
274-
"job": "engineer",
275-
"description": "a software developer with expertise in mathematics and computer science",
276-
"last_updated": sample_datetimes["mid"].timestamp(),
277-
"credit_score": "high",
278-
"location": "-110.0839,37.3861",
279-
"user_embedding": [0.1, 0.4, 0.5],
280-
"image_embedding": [-0.1, -0.2, -0.3, -0.4, -0.5],
281-
},
282-
{
283-
"user": "tim",
284-
"age": 12,
285-
"job": "dermatologist",
286-
"description": "a medical professional specializing in diseases of the skin",
287-
"last_updated": sample_datetimes["mid"].timestamp(),
288-
"credit_score": "high",
289-
"location": "-110.0839,37.3861",
290-
"user_embedding": [0.4, 0.4, 0.5],
291-
"image_embedding": [-0.1, 0.0, 0.6, 0.0, -0.9],
292-
},
293-
{
294-
"user": "taimur",
295-
"age": 15,
296-
"job": "CEO",
297-
"description": "high stress, but financially rewarding position at the head of a company",
298-
"last_updated": sample_datetimes["high"].timestamp(),
299-
"credit_score": "low",
300-
"location": "-110.0839,37.3861",
301-
"user_embedding": [0.6, 0.1, 0.5],
302-
"image_embedding": [1.1, 1.2, -0.3, -4.1, 5.0],
303-
},
304-
{
305-
"user": "joe",
306-
"age": 35,
307-
"job": "dentist",
308-
"description": "like the tooth fairy because they'll take your teeth, but you have to pay them!",
309-
"last_updated": sample_datetimes["high"].timestamp(),
310-
"credit_score": "medium",
311-
"location": "-110.0839,37.3861",
312-
"user_embedding": [-0.1, -0.1, -0.5],
313-
"image_embedding": [-0.8, 2.0, 3.1, 1.5, -1.6],
314-
},
315-
]
316-
317-
318235
@pytest.fixture
319236
def sample_data(sample_datetimes):
320237
return [

tests/integration/test_aggregation.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def index(multi_vector_data, redis_url, worker_id):
5050
"attrs": {
5151
"dims": 6,
5252
"distance_metric": "cosine",
53-
"algorithm": "hnsw",
53+
"algorithm": "flat",
5454
"datatype": "bfloat16",
5555
},
5656
},
@@ -581,6 +581,64 @@ def test_multivector_query_datatypes(index):
581581
)
582582

583583

584-
def test_multivector_query_broadcasting(index):
584+
def test_multivector_query_mixed_index(index):
585+
# test that we can do multi vector queries on indices with both a 'flat' and 'hnsw' index
585586
skip_if_redis_version_below(index.client, "7.2.0")
586-
pass
587+
try:
588+
index.schema.remove_field("audio_embedding")
589+
index.schema.add_field(
590+
{
591+
"name": "audio_embedding",
592+
"type": "vector",
593+
"attrs": {
594+
"dims": 6,
595+
"distance_metric": "cosine",
596+
"algorithm": "hnsw",
597+
"datatype": "bfloat16",
598+
},
599+
},
600+
)
601+
602+
except:
603+
pytest.skip("Required Redis modules not available or version too low")
604+
605+
vectors = [[0.1, 0.2, 0.5], [1.2, 0.3, -0.4, 0.7, 0.2, -0.3]]
606+
vector_fields = ["user_embedding", "audio_embedding"]
607+
return_fields = [
608+
"distance_0",
609+
"distance_1",
610+
"score_0",
611+
"score_1",
612+
"user_embedding",
613+
"audio_embedding",
614+
]
615+
616+
# changing the weights does indeed change the result order
617+
multi_query = MultiVectorQuery(
618+
vectors=vectors,
619+
vector_field_names=vector_fields,
620+
return_fields=return_fields,
621+
dtypes=["float32", "bfloat16"],
622+
)
623+
results = index.query(multi_query)
624+
625+
for i in range(1, len(results)):
626+
assert results[i]["combined_score"] <= results[i - 1]["combined_score"]
627+
628+
# verify we're doing the combined score math correctly
629+
weights = [-1.322, 0.851]
630+
multi_query = MultiVectorQuery(
631+
vectors=vectors,
632+
vector_field_names=vector_fields,
633+
return_fields=return_fields,
634+
dtypes=["float32", "bfloat16"],
635+
weights=weights,
636+
)
637+
638+
results = index.query(multi_query)
639+
assert results
640+
for r in results:
641+
score = float(r["score_0"]) * weights[0] + float(r["score_1"]) * weights[1]
642+
assert (
643+
float(r["combined_score"]) - score <= 0.0001
644+
) # allow for small floating point error

0 commit comments

Comments
 (0)