Skip to content

Commit 3b54ce5

Browse files
committed
add normalize cosine distance flag
1 parent 0b3a5ce commit 3b54ce5

File tree

5 files changed

+165
-10
lines changed

5 files changed

+165
-10
lines changed

redisvl/index/index.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
Union,
1717
)
1818

19-
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
19+
from redisvl.utils.utils import (
20+
deprecated_argument,
21+
deprecated_function,
22+
norm_cosine_distance,
23+
sync_wrapper,
24+
)
2025

2126
if TYPE_CHECKING:
2227
from redis.commands.search.aggregation import AggregateResult
@@ -30,14 +35,21 @@
3035

3136
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3237
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
33-
from redisvl.query import BaseQuery, CountQuery, FilterQuery
38+
from redisvl.query import (
39+
BaseQuery,
40+
CountQuery,
41+
FilterQuery,
42+
VectorQuery,
43+
VectorRangeQuery,
44+
)
3445
from redisvl.query.filter import FilterExpression
3546
from redisvl.redis.connection import (
3647
RedisConnectionFactory,
3748
convert_index_info_to_schema,
3849
)
3950
from redisvl.redis.utils import convert_bytes
4051
from redisvl.schema import IndexSchema, StorageType
52+
from redisvl.schema.fields import VectorDistanceMetric
4153
from redisvl.utils.log import get_logger
4254

4355
logger = get_logger(__name__)
@@ -50,7 +62,7 @@
5062

5163

5264
def process_results(
53-
results: "Result", query: BaseQuery, storage_type: StorageType
65+
results: "Result", query: BaseQuery, schema: IndexSchema
5466
) -> List[Dict[str, Any]]:
5567
"""Convert a list of search Result objects into a list of document
5668
dictionaries.
@@ -75,11 +87,18 @@ def process_results(
7587

7688
# Determine if unpacking JSON is needed
7789
unpack_json = (
78-
(storage_type == StorageType.JSON)
90+
(schema.index.storage_type == StorageType.JSON)
7991
and isinstance(query, FilterQuery)
8092
and not query._return_fields # type: ignore
8193
)
8294

95+
normalize_cosine_distance = (
96+
(isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery))
97+
and query._normalize_cosine_distance
98+
and schema.fields[query._vector_field_name].attrs.distance_metric # type: ignore
99+
== VectorDistanceMetric.COSINE
100+
)
101+
83102
# Process records
84103
def _process(doc: "Document") -> Dict[str, Any]:
85104
doc_dict = doc.__dict__
@@ -93,6 +112,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
93112
return {"id": doc_dict.get("id"), **json_data}
94113
raise ValueError(f"Unable to parse json data from Redis {json_data}")
95114

115+
if normalize_cosine_distance:
116+
# convert float back to string to be consistent
117+
doc_dict[query.DISTANCE_ID] = str( # type: ignore
118+
norm_cosine_distance(float(doc_dict[query.DISTANCE_ID])) # type: ignore
119+
)
120+
96121
# Remove 'payload' if present
97122
doc_dict.pop("payload", None)
98123

@@ -665,9 +690,7 @@ def search(self, *args, **kwargs) -> "Result":
665690
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
666691
"""Execute a query and process results."""
667692
results = self.search(query.query, query_params=query.params)
668-
return process_results(
669-
results, query=query, storage_type=self.schema.index.storage_type
670-
)
693+
return process_results(results, query=query, schema=self.schema)
671694

672695
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
673696
"""Execute a query on the index.
@@ -1219,9 +1242,7 @@ async def search(self, *args, **kwargs) -> "Result":
12191242
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
12201243
"""Asynchronously execute a query and process results."""
12211244
results = await self.search(query.query, query_params=query.params)
1222-
return process_results(
1223-
results, query=query, storage_type=self.schema.index.storage_type
1224-
)
1245+
return process_results(results, query=query, schema=self.schema)
12251246

12261247
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
12271248
"""Asynchronously execute a query on the index.

redisvl/query/query.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def __init__(
198198
in_order: bool = False,
199199
hybrid_policy: Optional[str] = None,
200200
batch_size: Optional[int] = None,
201+
normalize_cosine_distance: bool = False,
201202
):
202203
"""A query for running a vector search along with an optional filter
203204
expression.
@@ -233,6 +234,9 @@ def __init__(
233234
of vectors to fetch in each batch. Larger values may improve performance
234235
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
235236
Defaults to None, which lets Redis auto-select an appropriate batch size.
237+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
238+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
239+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
236240
237241
Raises:
238242
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -246,6 +250,7 @@ def __init__(
246250
self._num_results = num_results
247251
self._hybrid_policy: Optional[HybridPolicy] = None
248252
self._batch_size: Optional[int] = None
253+
self._normalize_cosine_distance = normalize_cosine_distance
249254
self.set_filter(filter_expression)
250255
query_string = self._build_query_string()
251256

@@ -394,6 +399,7 @@ def __init__(
394399
in_order: bool = False,
395400
hybrid_policy: Optional[str] = None,
396401
batch_size: Optional[int] = None,
402+
normalize_cosine_distance: bool = False,
397403
):
398404
"""A query for running a filtered vector search based on semantic
399405
distance threshold.
@@ -437,6 +443,16 @@ def __init__(
437443
of vectors to fetch in each batch. Larger values may improve performance
438444
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
439445
Defaults to None, which lets Redis auto-select an appropriate batch size.
446+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
447+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
448+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
449+
450+
Raises:
451+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
452+
453+
Note:
454+
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
455+
440456
"""
441457
self._vector = vector
442458
self._vector_field_name = vector_field_name
@@ -456,6 +472,7 @@ def __init__(
456472
if batch_size is not None:
457473
self.set_batch_size(batch_size)
458474

475+
self._normalize_cosine_distance = normalize_cosine_distance
459476
self.set_distance_threshold(distance_threshold)
460477
self.set_filter(filter_expression)
461478
query_string = self._build_query_string()

redisvl/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,10 @@ def wrapper():
191191
return
192192

193193
return wrapper
194+
195+
196+
def norm_cosine_distance(value: float) -> float:
197+
"""
198+
Normalize the cosine distance to a similarity score between 0 and 1.
199+
"""
200+
return (2 - value) / 2

tests/integration/test_query.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ def sorted_vector_query():
5353
)
5454

5555

56+
@pytest.fixture
57+
def normalized_vector_query():
58+
return VectorQuery(
59+
vector=[0.1, 0.1, 0.5],
60+
vector_field_name="user_embedding",
61+
normalize_cosine_distance=True,
62+
return_score=True,
63+
return_fields=[
64+
"user",
65+
"credit_score",
66+
"age",
67+
"job",
68+
"location",
69+
"last_updated",
70+
],
71+
)
72+
73+
5674
@pytest.fixture
5775
def filter_query():
5876
return FilterQuery(
@@ -84,6 +102,18 @@ def sorted_filter_query():
84102
)
85103

86104

105+
@pytest.fixture
106+
def normalized_range_query():
107+
return RangeQuery(
108+
vector=[0.1, 0.1, 0.5],
109+
vector_field_name="user_embedding",
110+
normalize_cosine_distance=True,
111+
return_score=True,
112+
return_fields=["user", "credit_score", "age", "job", "location"],
113+
distance_threshold=0.2,
114+
)
115+
116+
87117
@pytest.fixture
88118
def range_query():
89119
return RangeQuery(
@@ -155,6 +185,56 @@ def hash_preprocess(item: dict) -> dict:
155185
index.delete(drop=True)
156186

157187

188+
@pytest.fixture
189+
def L2_index(sample_data, redis_url):
190+
# construct a search index from the schema
191+
index = SearchIndex.from_dict(
192+
{
193+
"index": {
194+
"name": "L2_index",
195+
"prefix": "L2_index",
196+
"storage_type": "hash",
197+
},
198+
"fields": [
199+
{"name": "credit_score", "type": "tag"},
200+
{"name": "job", "type": "text"},
201+
{"name": "age", "type": "numeric"},
202+
{"name": "last_updated", "type": "numeric"},
203+
{"name": "location", "type": "geo"},
204+
{
205+
"name": "user_embedding",
206+
"type": "vector",
207+
"attrs": {
208+
"dims": 3,
209+
"distance_metric": "L2",
210+
"algorithm": "flat",
211+
"datatype": "float32",
212+
},
213+
},
214+
],
215+
},
216+
redis_url=redis_url,
217+
)
218+
219+
# create the index (no data yet)
220+
index.create(overwrite=True)
221+
222+
# Prepare and load the data
223+
def hash_preprocess(item: dict) -> dict:
224+
return {
225+
**item,
226+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
227+
}
228+
229+
index.load(sample_data, preprocess=hash_preprocess)
230+
231+
# run the test
232+
yield index
233+
234+
# clean up
235+
index.delete(drop=True)
236+
237+
158238
def test_search_and_query(index):
159239
# *=>[KNN 7 @user_embedding $vector AS vector_distance]
160240
v = VectorQuery(
@@ -659,3 +739,26 @@ def test_range_query_with_filter_and_hybrid_policy(index):
659739
for result in results:
660740
assert result["credit_score"] == "high"
661741
assert float(result["vector_distance"]) <= 0.5
742+
743+
744+
def test_query_normalize_cosine_distance(index, normalized_vector_query):
745+
746+
res = index.query(normalized_vector_query)
747+
748+
for r in res:
749+
assert 0 <= float(r["vector_distance"]) <= 1
750+
751+
752+
def test_query_normalize_cosine_distance_ip_distance(L2_index, normalized_vector_query):
753+
754+
res = L2_index.query(normalized_vector_query)
755+
756+
assert any(float(r["vector_distance"]) > 1 for r in res)
757+
758+
759+
def test_range_query_normalize_cosine_distance(index, normalized_range_query):
760+
761+
res = index.query(normalized_range_query)
762+
763+
for r in res:
764+
assert 0 <= float(r["vector_distance"]) <= 1

tests/unit/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@
1818
assert_no_warnings,
1919
deprecated_argument,
2020
deprecated_function,
21+
norm_cosine_distance,
2122
)
2223

2324

25+
def test_norm_cosine_distance():
26+
input = 2
27+
expected = 0
28+
assert norm_cosine_distance(input) == expected
29+
30+
2431
def test_even_number_of_elements():
2532
"""Test with an even number of elements"""
2633
values = ["key1", "value1", "key2", "value2"]

0 commit comments

Comments
 (0)