Skip to content

Commit 493172f

Browse files
committed
Add query validation to indexes and validate EF_RUNTIME
1 parent a1ed87f commit 493172f

File tree

6 files changed

+202
-2
lines changed

6 files changed

+202
-2
lines changed

redisvl/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ def __init__(self, message, index=None):
3030
if index is not None:
3131
message = f"Validation failed for object at index {index}: {message}"
3232
super().__init__(message)
33+
34+
35+
class QueryValidationError(RedisVLError):
36+
"""Error when validating a query."""
37+
38+
pass

redisvl/index/index.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Union,
1919
)
2020

21+
from redisvl.query.query import VectorQuery
2122
from redisvl.redis.utils import convert_bytes, make_dict
2223
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2324

@@ -34,6 +35,7 @@
3435
from redis.commands.search.indexDefinition import IndexDefinition
3536

3637
from redisvl.exceptions import (
38+
QueryValidationError,
3739
RedisModuleVersionError,
3840
RedisSearchError,
3941
RedisVLError,
@@ -835,8 +837,21 @@ def batch_query(
835837
all_parsed.append(parsed)
836838
return all_parsed
837839

840+
def _validate_query(self, query: BaseQuery) -> None:
841+
"""Validate a query."""
842+
if isinstance(query, VectorQuery):
843+
field = self.schema.fields[query._vector_field_name]
844+
if query.ef_runtime and field.attrs.algorithm != "hnsw": # type: ignore
845+
raise QueryValidationError(
846+
"Flat index does not support vector queries."
847+
)
848+
838849
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
839850
"""Execute a query and process results."""
851+
try:
852+
self._validate_query(query)
853+
except QueryValidationError as e:
854+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
840855
results = self.search(query.query, query_params=query.params)
841856
return process_results(results, query=query, schema=self.schema)
842857

@@ -1527,8 +1542,21 @@ async def batch_query(
15271542

15281543
return all_parsed
15291544

1545+
def _validate_query(self, query: BaseQuery) -> None:
1546+
"""Validate a query."""
1547+
if isinstance(query, VectorQuery):
1548+
field = self.schema.fields[query._vector_field_name]
1549+
if query.ef_runtime and field.attrs.algorithm != "hnsw": # type: ignore
1550+
raise QueryValidationError(
1551+
"Flat index does not support vector queries."
1552+
)
1553+
15301554
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
15311555
"""Asynchronously execute a query and process results."""
1556+
try:
1557+
self._validate_query(query)
1558+
except QueryValidationError as e:
1559+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
15321560
results = await self.search(query.query, query_params=query.params)
15331561
return process_results(results, query=query, schema=self.schema)
15341562

tests/conftest.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55
from testcontainers.compose import DockerCompose
66

7+
from redisvl.index.index import AsyncSearchIndex, SearchIndex
78
from redisvl.redis.connection import RedisConnectionFactory
9+
from redisvl.redis.utils import array_to_buffer
810
from redisvl.utils.vectorize import HFTextVectorizer
911

1012

@@ -191,3 +193,111 @@ def pytest_collection_modifyitems(
191193
for item in items:
192194
if item.get_closest_marker("requires_api_keys"):
193195
item.add_marker(skip_api)
196+
197+
198+
@pytest.fixture
199+
def flat_index(sample_data, redis_url):
200+
"""
201+
A fixture that uses the "flat" algorithm for its vector field.
202+
"""
203+
# construct a search index from the schema
204+
index = SearchIndex.from_dict(
205+
{
206+
"index": {
207+
"name": "user_index",
208+
"prefix": "v1",
209+
"storage_type": "hash",
210+
},
211+
"fields": [
212+
{"name": "description", "type": "text"},
213+
{"name": "credit_score", "type": "tag"},
214+
{"name": "job", "type": "text"},
215+
{"name": "age", "type": "numeric"},
216+
{"name": "last_updated", "type": "numeric"},
217+
{"name": "location", "type": "geo"},
218+
{
219+
"name": "user_embedding",
220+
"type": "vector",
221+
"attrs": {
222+
"dims": 3,
223+
"distance_metric": "cosine",
224+
"algorithm": "flat",
225+
"datatype": "float32",
226+
},
227+
},
228+
],
229+
},
230+
redis_url=redis_url,
231+
)
232+
233+
# create the index (no data yet)
234+
index.create(overwrite=True)
235+
236+
# Prepare and load the data
237+
def hash_preprocess(item: dict) -> dict:
238+
return {
239+
**item,
240+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
241+
}
242+
243+
index.load(sample_data, preprocess=hash_preprocess)
244+
245+
# run the test
246+
yield index
247+
248+
# clean up
249+
index.delete(drop=True)
250+
251+
252+
@pytest.fixture
253+
async def async_flat_index(sample_data, redis_url):
254+
"""
255+
A fixture that uses the "flat" algorithm for its vector field.
256+
"""
257+
# construct a search index from the schema
258+
index = AsyncSearchIndex.from_dict(
259+
{
260+
"index": {
261+
"name": "user_index",
262+
"prefix": "v1",
263+
"storage_type": "hash",
264+
},
265+
"fields": [
266+
{"name": "description", "type": "text"},
267+
{"name": "credit_score", "type": "tag"},
268+
{"name": "job", "type": "text"},
269+
{"name": "age", "type": "numeric"},
270+
{"name": "last_updated", "type": "numeric"},
271+
{"name": "location", "type": "geo"},
272+
{
273+
"name": "user_embedding",
274+
"type": "vector",
275+
"attrs": {
276+
"dims": 3,
277+
"distance_metric": "cosine",
278+
"algorithm": "flat",
279+
"datatype": "float32",
280+
},
281+
},
282+
],
283+
},
284+
redis_url=redis_url,
285+
)
286+
287+
# create the index (no data yet)
288+
await index.create(overwrite=True)
289+
290+
# Prepare and load the data
291+
def hash_preprocess(item: dict) -> dict:
292+
return {
293+
**item,
294+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
295+
}
296+
297+
await index.load(sample_data, preprocess=hash_preprocess)
298+
299+
# run the test
300+
yield index
301+
302+
# clean up
303+
await index.delete(drop=True)

tests/integration/test_async_search_index.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
from redis import Redis as SyncRedis
66
from redis.asyncio import Redis as AsyncRedis
77

8-
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError, RedisVLError
8+
from redisvl.exceptions import (
9+
QueryValidationError,
10+
RedisModuleVersionError,
11+
RedisSearchError,
12+
RedisVLError,
13+
)
914
from redisvl.index import AsyncSearchIndex
1015
from redisvl.query import VectorQuery
1116
from redisvl.query.query import FilterQuery
@@ -614,3 +619,16 @@ async def test_async_search_index_expire_keys(async_index):
614619
ttl = await client.ttl(key)
615620
assert ttl > 0
616621
assert ttl <= 30
622+
623+
624+
@pytest.mark.asyncio
625+
async def test_search_index_validates_query(async_flat_index, sample_data):
626+
query = VectorQuery(
627+
[0.1, 0.1, 0.5],
628+
"user_embedding",
629+
return_fields=["user", "credit_score", "age", "job", "location"],
630+
num_results=7,
631+
ef_runtime=100,
632+
)
633+
with pytest.raises(QueryValidationError):
634+
await async_flat_index.query(query)

tests/integration/test_query.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from redis.commands.search.result import Result
55

6+
from redisvl.exceptions import QueryValidationError
67
from redisvl.index import SearchIndex
78
from redisvl.query import (
89
CountQuery,
@@ -898,3 +899,23 @@ def test_vector_query_with_ef_runtime(index, vector_query, sample_data):
898899
assert len(results) > 0
899900
for result in results:
900901
assert "vector_distance" in result
902+
903+
904+
def test_vector_query_with_ef_runtime_flat_index(flat_index, vector_query, sample_data):
905+
"""
906+
Integration test: Verify that the index rejects EF_RUNTIME on a query if the
907+
algo is "flat." EF_RUNTIME is only valid with the "hnsw" algorithm.
908+
"""
909+
vector_query.set_ef_runtime(100)
910+
911+
# The vector query does not know if the index field supports EF_RUNTIME,
912+
# so it should include this param in the query string if asked.
913+
query_string = str(vector_query)
914+
assert (
915+
f"{vector_query.__class__.EF_RUNTIME} ${vector_query.__class__.EF_RUNTIME_PARAM}"
916+
in query_string
917+
), "EF_RUNTIME should be in query string"
918+
919+
# However, the index should raise an error if EF_RUNTIME is set on a flat index.
920+
with pytest.raises(QueryValidationError): # noqa: F821
921+
flat_index.query(vector_query)

tests/integration/test_search_index.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
import pytest
55
from redis import Redis
66

7-
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError, RedisVLError
7+
from redisvl.exceptions import (
8+
QueryValidationError,
9+
RedisModuleVersionError,
10+
RedisSearchError,
11+
RedisVLError,
12+
)
813
from redisvl.index import SearchIndex
914
from redisvl.query import VectorQuery
1015
from redisvl.query.query import FilterQuery
@@ -556,3 +561,15 @@ def test_search_index_expire_keys(index):
556561
ttl = index.client.ttl(key)
557562
assert ttl > 0
558563
assert ttl <= 30
564+
565+
566+
def test_search_index_validates_query(flat_index, sample_data):
567+
query = VectorQuery(
568+
[0.1, 0.1, 0.5],
569+
"user_embedding",
570+
return_fields=["user", "credit_score", "age", "job", "location"],
571+
num_results=7,
572+
ef_runtime=100,
573+
)
574+
with pytest.raises(QueryValidationError):
575+
flat_index.query(query)

0 commit comments

Comments
 (0)