Skip to content

Commit 977b859

Browse files
committed
More tests, fix comparison to use enum
1 parent 493172f commit 977b859

File tree

5 files changed

+165
-25
lines changed

5 files changed

+165
-25
lines changed

redisvl/index/index.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,18 @@
4848
BaseVectorQuery,
4949
CountQuery,
5050
FilterQuery,
51-
HybridQuery,
5251
)
5352
from redisvl.query.filter import FilterExpression
5453
from redisvl.redis.connection import (
5554
RedisConnectionFactory,
5655
convert_index_info_to_schema,
5756
)
58-
from redisvl.redis.utils import convert_bytes
5957
from redisvl.schema import IndexSchema, StorageType
60-
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
58+
from redisvl.schema.fields import (
59+
VECTOR_NORM_MAP,
60+
VectorDistanceMetric,
61+
VectorIndexAlgorithm,
62+
)
6163
from redisvl.utils.log import get_logger
6264

6365
logger = get_logger(__name__)
@@ -196,6 +198,15 @@ def _storage(self) -> BaseStorage:
196198
index_schema=self.schema
197199
)
198200

201+
def _validate_query(self, query: BaseQuery) -> None:
202+
"""Validate a query."""
203+
if isinstance(query, VectorQuery):
204+
field = self.schema.fields[query._vector_field_name]
205+
if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore
206+
raise QueryValidationError(
207+
"Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter."
208+
)
209+
199210
@property
200211
def name(self) -> str:
201212
"""The name of the Redis search index."""
@@ -837,15 +848,6 @@ def batch_query(
837848
all_parsed.append(parsed)
838849
return all_parsed
839850

840-
def _validate_query(self, query: BaseQuery) -> None:
841-
"""Validate a query."""
842-
if isinstance(query, VectorQuery):
843-
field = self.schema.fields[query._vector_field_name]
844-
if query.ef_runtime and field.attrs.algorithm != "hnsw": # type: ignore
845-
raise QueryValidationError(
846-
"Flat index does not support vector queries."
847-
)
848-
849851
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
850852
"""Execute a query and process results."""
851853
try:
@@ -1416,7 +1418,8 @@ async def _aggregate(
14161418
) -> List[Dict[str, Any]]:
14171419
"""Execute an aggregation query and processes the results."""
14181420
results = await self.aggregate(
1419-
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
1421+
aggregation_query,
1422+
query_params=aggregation_query.params, # type: ignore[attr-defined]
14201423
)
14211424
return process_aggregate_results(
14221425
results,
@@ -1542,15 +1545,6 @@ async def batch_query(
15421545

15431546
return all_parsed
15441547

1545-
def _validate_query(self, query: BaseQuery) -> None:
1546-
"""Validate a query."""
1547-
if isinstance(query, VectorQuery):
1548-
field = self.schema.fields[query._vector_field_name]
1549-
if query.ef_runtime and field.attrs.algorithm != "hnsw": # type: ignore
1550-
raise QueryValidationError(
1551-
"Flat index does not support vector queries."
1552-
)
1553-
15541548
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
15551549
"""Asynchronously execute a query and process results."""
15561550
try:

tests/conftest.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,103 @@ def hash_preprocess(item: dict) -> dict:
301301

302302
# clean up
303303
await index.delete(drop=True)
304+
305+
306+
@pytest.fixture
307+
async def async_hnsw_index(sample_data, redis_url):
308+
"""
309+
A fixture that uses the "hnsw" algorithm for its vector field.
310+
"""
311+
index = AsyncSearchIndex.from_dict(
312+
{
313+
"index": {
314+
"name": "user_index",
315+
"prefix": "v1",
316+
"storage_type": "hash",
317+
},
318+
"fields": [
319+
{"name": "description", "type": "text"},
320+
{"name": "credit_score", "type": "tag"},
321+
{"name": "job", "type": "text"},
322+
{"name": "age", "type": "numeric"},
323+
{"name": "last_updated", "type": "numeric"},
324+
{"name": "location", "type": "geo"},
325+
{
326+
"name": "user_embedding",
327+
"type": "vector",
328+
"attrs": {
329+
"dims": 3,
330+
"distance_metric": "cosine",
331+
"algorithm": "hnsw",
332+
"datatype": "float32",
333+
},
334+
},
335+
],
336+
},
337+
redis_url=redis_url,
338+
)
339+
340+
# create the index (no data yet)
341+
await index.create(overwrite=True)
342+
343+
# Prepare and load the data
344+
def hash_preprocess(item: dict) -> dict:
345+
return {
346+
**item,
347+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
348+
}
349+
350+
await index.load(sample_data, preprocess=hash_preprocess)
351+
352+
# run the test
353+
yield index
354+
355+
356+
@pytest.fixture
357+
def hnsw_index(sample_data, redis_url):
358+
"""
359+
A fixture that uses the "hnsw" algorithm for its vector field.
360+
"""
361+
index = SearchIndex.from_dict(
362+
{
363+
"index": {
364+
"name": "user_index",
365+
"prefix": "v1",
366+
"storage_type": "hash",
367+
},
368+
"fields": [
369+
{"name": "description", "type": "text"},
370+
{"name": "credit_score", "type": "tag"},
371+
{"name": "job", "type": "text"},
372+
{"name": "age", "type": "numeric"},
373+
{"name": "last_updated", "type": "numeric"},
374+
{"name": "location", "type": "geo"},
375+
{
376+
"name": "user_embedding",
377+
"type": "vector",
378+
"attrs": {
379+
"dims": 3,
380+
"distance_metric": "cosine",
381+
"algorithm": "hnsw",
382+
"datatype": "float32",
383+
},
384+
},
385+
],
386+
},
387+
redis_url=redis_url,
388+
)
389+
390+
# create the index (no data yet)
391+
index.create(overwrite=True)
392+
393+
# Prepare and load the data
394+
def hash_preprocess(item: dict) -> dict:
395+
return {
396+
**item,
397+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
398+
}
399+
400+
index.load(sample_data, preprocess=hash_preprocess)
401+
402+
# run the test
403+
yield index

tests/integration/test_async_search_index.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from redisvl.query.query import FilterQuery
1717
from redisvl.redis.utils import convert_bytes
1818
from redisvl.schema import IndexSchema, StorageType
19+
from redisvl.schema.fields import VectorIndexAlgorithm
1920

2021
fields = [{"name": "test", "type": "tag"}]
2122

@@ -622,7 +623,13 @@ async def test_async_search_index_expire_keys(async_index):
622623

623624

624625
@pytest.mark.asyncio
625-
async def test_search_index_validates_query(async_flat_index, sample_data):
626+
async def test_search_index_validates_query_with_flat_algorithm(
627+
async_flat_index, sample_data
628+
):
629+
assert (
630+
async_flat_index.schema.fields["user_embedding"].attrs.algorithm
631+
== VectorIndexAlgorithm.FLAT
632+
)
626633
query = VectorQuery(
627634
[0.1, 0.1, 0.5],
628635
"user_embedding",
@@ -632,3 +639,22 @@ async def test_search_index_validates_query(async_flat_index, sample_data):
632639
)
633640
with pytest.raises(QueryValidationError):
634641
await async_flat_index.query(query)
642+
643+
644+
@pytest.mark.asyncio
645+
async def test_search_index_validates_query_with_hnsw_algorithm(
646+
async_hnsw_index, sample_data
647+
):
648+
assert (
649+
async_hnsw_index.schema.fields["user_embedding"].attrs.algorithm
650+
== VectorIndexAlgorithm.HNSW
651+
)
652+
query = VectorQuery(
653+
[0.1, 0.1, 0.5],
654+
"user_embedding",
655+
return_fields=["user", "credit_score", "age", "job", "location"],
656+
num_results=7,
657+
ef_runtime=100,
658+
)
659+
# Should not raise
660+
await async_hnsw_index.query(query)

tests/integration/test_query.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Text,
2222
Timestamp,
2323
)
24-
from redisvl.query.query import VectorRangeQuery
2524
from redisvl.redis.utils import array_to_buffer
2625

2726
# TODO expand to multiple schema types and sync + async

tests/integration/test_search_index.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from redisvl.query.query import FilterQuery
1616
from redisvl.redis.utils import convert_bytes
1717
from redisvl.schema import IndexSchema, StorageType
18+
from redisvl.schema.fields import VectorIndexAlgorithm
1819

1920
fields = [
2021
{"name": "test", "type": "tag"},
@@ -563,7 +564,11 @@ def test_search_index_expire_keys(index):
563564
assert ttl <= 30
564565

565566

566-
def test_search_index_validates_query(flat_index, sample_data):
567+
def test_search_index_validates_query_with_flat_algorithm(flat_index, sample_data):
568+
assert (
569+
flat_index.schema.fields["user_embedding"].attrs.algorithm
570+
== VectorIndexAlgorithm.FLAT
571+
)
567572
query = VectorQuery(
568573
[0.1, 0.1, 0.5],
569574
"user_embedding",
@@ -573,3 +578,19 @@ def test_search_index_validates_query(flat_index, sample_data):
573578
)
574579
with pytest.raises(QueryValidationError):
575580
flat_index.query(query)
581+
582+
583+
def test_search_index_validates_query_with_hnsw_algorithm(hnsw_index, sample_data):
584+
assert (
585+
hnsw_index.schema.fields["user_embedding"].attrs.algorithm
586+
== VectorIndexAlgorithm.HNSW
587+
)
588+
query = VectorQuery(
589+
[0.1, 0.1, 0.5],
590+
"user_embedding",
591+
return_fields=["user", "credit_score", "age", "job", "location"],
592+
num_results=7,
593+
ef_runtime=100,
594+
)
595+
# Should not raise
596+
hnsw_index.query(query)

0 commit comments

Comments
 (0)