44from redis .commands .search .result import Result
55
66from redisvl .index import SearchIndex
7- from redisvl .query import CountQuery , FilterQuery , RangeQuery , VectorQuery
7+ from redisvl .query import CountQuery , FilterQuery , VectorQuery , VectorRangeQuery
88from redisvl .query .filter import (
99 FilterExpression ,
1010 Geo ,
@@ -104,7 +104,7 @@ def sorted_filter_query():
104104
105105@pytest .fixture
106106def normalized_range_query ():
107- return RangeQuery (
107+ return VectorRangeQuery (
108108 vector = [0.1 , 0.1 , 0.5 ],
109109 vector_field_name = "user_embedding" ,
110110 normalize_vector_distance = True ,
@@ -116,7 +116,7 @@ def normalized_range_query():
116116
117117@pytest .fixture
118118def range_query ():
119- return RangeQuery (
119+ return VectorRangeQuery (
120120 vector = [0.1 , 0.1 , 0.5 ],
121121 vector_field_name = "user_embedding" ,
122122 return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
@@ -126,7 +126,7 @@ def range_query():
126126
127127@pytest .fixture
128128def sorted_range_query ():
129- return RangeQuery (
129+ return VectorRangeQuery (
130130 vector = [0.1 , 0.1 , 0.5 ],
131131 vector_field_name = "user_embedding" ,
132132 return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
@@ -271,7 +271,7 @@ def test_search_and_query(index):
271271
272272
273273def test_range_query (index ):
274- r = RangeQuery (
274+ r = VectorRangeQuery (
275275 vector = [0.1 , 0.1 , 0.5 ],
276276 vector_field_name = "user_embedding" ,
277277 return_fields = ["user" , "credit_score" , "age" , "job" ],
@@ -342,7 +342,7 @@ def search(
342342 assert doc .location == location
343343
344344 # if range query, test results by distance threshold
345- if isinstance (query , RangeQuery ):
345+ if isinstance (query , VectorRangeQuery ):
346346 for doc in results .docs :
347347 print (doc .vector_distance )
348348 assert float (doc .vector_distance ) <= distance_threshold
@@ -353,7 +353,7 @@ def search(
353353
354354 # check results are in sorted order
355355 if sort :
356- if isinstance (query , RangeQuery ):
356+ if isinstance (query , VectorRangeQuery ):
357357 assert [int (doc .age ) for doc in results .docs ] == [12 , 14 , 18 , 100 ]
358358 else :
359359 assert [int (doc .age ) for doc in results .docs ] == [
@@ -369,7 +369,7 @@ def search(
369369
370370@pytest .fixture (
371371 params = ["vector_query" , "filter_query" , "range_query" ],
372- ids = ["VectorQuery" , "FilterQuery" , "RangeQuery " ],
372+ ids = ["VectorQuery" , "FilterQuery" , "VectorRangeQuery " ],
373373)
374374def query (request ):
375375 return request .getfixturevalue (request .param )
@@ -650,3 +650,15 @@ def test_range_query_normalize_cosine_distance(index, normalized_range_query):
650650
651651 for r in res :
652652 assert 0 <= float (r ["vector_distance" ]) <= 1
653+
654+
655+ def test_range_query_normalize_bad_input (index ):
656+ with pytest .raises (ValueError ):
657+ VectorRangeQuery (
658+ vector = [0.1 , 0.1 , 0.5 ],
659+ vector_field_name = "user_embedding" ,
660+ normalize_vector_distance = True ,
661+ return_score = True ,
662+ return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
663+ distance_threshold = 1.2 ,
664+ )
0 commit comments