Skip to content

Commit 8b7c0e8

Browse files
rbs333 authored and abrookins committed
add normalize cosine distance flag
1 parent c7c6165 commit 8b7c0e8

File tree

5 files changed

+165
-10
lines changed

5 files changed

+165
-10
lines changed

redisvl/index/index.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
Union,
1919
)
2020

21-
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
21+
from redisvl.utils.utils import (
22+
deprecated_argument,
23+
deprecated_function,
24+
norm_cosine_distance,
25+
sync_wrapper,
26+
)
2227

2328
if TYPE_CHECKING:
2429
from redis.commands.search.aggregation import AggregateResult
@@ -34,14 +39,21 @@
3439

3540
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3641
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
37-
from redisvl.query import BaseQuery, CountQuery, FilterQuery
42+
from redisvl.query import (
43+
BaseQuery,
44+
CountQuery,
45+
FilterQuery,
46+
VectorQuery,
47+
VectorRangeQuery,
48+
)
3849
from redisvl.query.filter import FilterExpression
3950
from redisvl.redis.connection import (
4051
RedisConnectionFactory,
4152
convert_index_info_to_schema,
4253
)
4354
from redisvl.redis.utils import convert_bytes
4455
from redisvl.schema import IndexSchema, StorageType
56+
from redisvl.schema.fields import VectorDistanceMetric
4557
from redisvl.utils.log import get_logger
4658

4759
logger = get_logger(__name__)
@@ -62,7 +74,7 @@
6274

6375

6476
def process_results(
65-
results: "Result", query: BaseQuery, storage_type: StorageType
77+
results: "Result", query: BaseQuery, schema: IndexSchema
6678
) -> List[Dict[str, Any]]:
6779
"""Convert a list of search Result objects into a list of document
6880
dictionaries.
@@ -87,11 +99,18 @@ def process_results(
8799

88100
# Determine if unpacking JSON is needed
89101
unpack_json = (
90-
(storage_type == StorageType.JSON)
102+
(schema.index.storage_type == StorageType.JSON)
91103
and isinstance(query, FilterQuery)
92104
and not query._return_fields # type: ignore
93105
)
94106

107+
normalize_cosine_distance = (
108+
(isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery))
109+
and query._normalize_cosine_distance
110+
and schema.fields[query._vector_field_name].attrs.distance_metric # type: ignore
111+
== VectorDistanceMetric.COSINE
112+
)
113+
95114
# Process records
96115
def _process(doc: "Document") -> Dict[str, Any]:
97116
doc_dict = doc.__dict__
@@ -105,6 +124,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
105124
return {"id": doc_dict.get("id"), **json_data}
106125
raise ValueError(f"Unable to parse json data from Redis {json_data}")
107126

127+
if normalize_cosine_distance:
128+
# convert float back to string to be consistent
129+
doc_dict[query.DISTANCE_ID] = str( # type: ignore
130+
norm_cosine_distance(float(doc_dict[query.DISTANCE_ID])) # type: ignore
131+
)
132+
108133
# Remove 'payload' if present
109134
doc_dict.pop("payload", None)
110135

@@ -771,9 +796,7 @@ def batch_query(
771796
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
772797
"""Execute a query and process results."""
773798
results = self.search(query.query, query_params=query.params)
774-
return process_results(
775-
results, query=query, storage_type=self.schema.index.storage_type
776-
)
799+
return process_results(results, query=query, schema=self.schema)
777800

778801
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
779802
"""Execute a query on the index.
@@ -1415,9 +1438,7 @@ async def batch_query(
14151438
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14161439
"""Asynchronously execute a query and process results."""
14171440
results = await self.search(query.query, query_params=query.params)
1418-
return process_results(
1419-
results, query=query, storage_type=self.schema.index.storage_type
1420-
)
1441+
return process_results(results, query=query, schema=self.schema)
14211442

14221443
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14231444
"""Asynchronously execute a query on the index.

redisvl/query/query.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def __init__(
198198
in_order: bool = False,
199199
hybrid_policy: Optional[str] = None,
200200
batch_size: Optional[int] = None,
201+
normalize_cosine_distance: bool = False,
201202
):
202203
"""A query for running a vector search along with an optional filter
203204
expression.
@@ -233,6 +234,9 @@ def __init__(
233234
of vectors to fetch in each batch. Larger values may improve performance
234235
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
235236
Defaults to None, which lets Redis auto-select an appropriate batch size.
237+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
238+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
239+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
236240
237241
Raises:
238242
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -246,6 +250,7 @@ def __init__(
246250
self._num_results = num_results
247251
self._hybrid_policy: Optional[HybridPolicy] = None
248252
self._batch_size: Optional[int] = None
253+
self._normalize_cosine_distance = normalize_cosine_distance
249254
self.set_filter(filter_expression)
250255
query_string = self._build_query_string()
251256

@@ -394,6 +399,7 @@ def __init__(
394399
in_order: bool = False,
395400
hybrid_policy: Optional[str] = None,
396401
batch_size: Optional[int] = None,
402+
normalize_cosine_distance: bool = False,
397403
):
398404
"""A query for running a filtered vector search based on semantic
399405
distance threshold.
@@ -437,6 +443,16 @@ def __init__(
437443
of vectors to fetch in each batch. Larger values may improve performance
438444
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
439445
Defaults to None, which lets Redis auto-select an appropriate batch size.
446+
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
447+
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
448+
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
449+
450+
Raises:
451+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
452+
453+
Note:
454+
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
455+
440456
"""
441457
self._vector = vector
442458
self._vector_field_name = vector_field_name
@@ -456,6 +472,7 @@ def __init__(
456472
if batch_size is not None:
457473
self.set_batch_size(batch_size)
458474

475+
self._normalize_cosine_distance = normalize_cosine_distance
459476
self.set_distance_threshold(distance_threshold)
460477
self.set_filter(filter_expression)
461478
query_string = self._build_query_string()

redisvl/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,10 @@ def wrapper():
191191
return
192192

193193
return wrapper
194+
195+
196+
def norm_cosine_distance(value: float) -> float:
    """Convert a cosine distance into a similarity score between 0 and 1.

    A distance of 0 (best match) maps to a similarity of 1 and a distance of
    2 (worst match) maps to a similarity of 0, using ``(2 - value) / 2``.

    Args:
        value: Cosine distance, nominally in the range ``[0, 2]``.

    Returns:
        The similarity score, clamped to ``[0.0, 1.0]``.
    """
    similarity = (2 - value) / 2
    # A distance computed in floating point can land marginally outside
    # [0, 2] (e.g. a tiny negative value for identical vectors); clamp so the
    # documented 0..1 range always holds.
    return min(max(similarity, 0.0), 1.0)

tests/integration/test_query.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ def sorted_vector_query():
5353
)
5454

5555

56+
@pytest.fixture
def normalized_vector_query():
    """Vector query on ``user_embedding`` with cosine-distance normalization on."""
    fields_to_return = [
        "user",
        "credit_score",
        "age",
        "job",
        "location",
        "last_updated",
    ]
    return VectorQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        return_fields=fields_to_return,
        return_score=True,
        normalize_cosine_distance=True,
    )
72+
73+
5674
@pytest.fixture
5775
def filter_query():
5876
return FilterQuery(
@@ -84,6 +102,18 @@ def sorted_filter_query():
84102
)
85103

86104

105+
@pytest.fixture
def normalized_range_query():
    """Range query on ``user_embedding`` with cosine-distance normalization on."""
    fields_to_return = ["user", "credit_score", "age", "job", "location"]
    return RangeQuery(
        vector=[0.1, 0.1, 0.5],
        vector_field_name="user_embedding",
        distance_threshold=0.2,
        return_fields=fields_to_return,
        return_score=True,
        normalize_cosine_distance=True,
    )
115+
116+
87117
@pytest.fixture
88118
def range_query():
89119
return RangeQuery(
@@ -155,6 +185,56 @@ def hash_preprocess(item: dict) -> dict:
155185
index.delete(drop=True)
156186

157187

188+
@pytest.fixture
def L2_index(sample_data, redis_url):
    """Yield a hash-backed search index whose vector field uses the L2 metric.

    The index is created, loaded with ``sample_data``, handed to the test,
    and dropped afterwards.
    """
    # The only vector field; note the L2 (not cosine) distance metric.
    embedding_field = {
        "name": "user_embedding",
        "type": "vector",
        "attrs": {
            "dims": 3,
            "distance_metric": "L2",
            "algorithm": "flat",
            "datatype": "float32",
        },
    }
    schema = {
        "index": {
            "name": "L2_index",
            "prefix": "L2_index",
            "storage_type": "hash",
        },
        "fields": [
            {"name": "credit_score", "type": "tag"},
            {"name": "job", "type": "text"},
            {"name": "age", "type": "numeric"},
            {"name": "last_updated", "type": "numeric"},
            {"name": "location", "type": "geo"},
            embedding_field,
        ],
    }

    # Build the index from the schema and create it empty.
    index = SearchIndex.from_dict(schema, redis_url=redis_url)
    index.create(overwrite=True)

    def hash_preprocess(item: dict) -> dict:
        # Hash storage needs the embedding as a float32 byte buffer.
        record = dict(item)
        record["user_embedding"] = array_to_buffer(item["user_embedding"], "float32")
        return record

    index.load(sample_data, preprocess=hash_preprocess)

    # Hand the populated index to the test.
    yield index

    # Teardown: remove the index and its documents.
    index.delete(drop=True)
236+
237+
158238
def test_search_and_query(index):
159239
# *=>[KNN 7 @user_embedding $vector AS vector_distance]
160240
v = VectorQuery(
@@ -659,3 +739,26 @@ def test_range_query_with_filter_and_hybrid_policy(index):
659739
for result in results:
660740
assert result["credit_score"] == "high"
661741
assert float(result["vector_distance"]) <= 0.5
742+
743+
744+
def test_query_normalize_cosine_distance(index, normalized_vector_query):
    """Every ``vector_distance`` returned for the normalized query lies in [0, 1]."""
    results = index.query(normalized_vector_query)

    assert all(0 <= float(doc["vector_distance"]) <= 1 for doc in results)
750+
751+
752+
def test_query_normalize_cosine_distance_ip_distance(L2_index, normalized_vector_query):
    """At least one score from the L2 index exceeds 1 despite the normalize flag.

    NOTE(review): the test name says "ip", but the fixture used is the L2
    index -- confirm which metric was intended.
    """
    results = L2_index.query(normalized_vector_query)

    distances = (float(doc["vector_distance"]) for doc in results)
    assert any(distance > 1 for distance in distances)
757+
758+
759+
def test_range_query_normalize_cosine_distance(index, normalized_range_query):
    """Every ``vector_distance`` returned for the normalized range query lies in [0, 1]."""
    results = index.query(normalized_range_query)

    assert all(0 <= float(doc["vector_distance"]) <= 1 for doc in results)

tests/unit/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@
1818
assert_no_warnings,
1919
deprecated_argument,
2020
deprecated_function,
21+
norm_cosine_distance,
2122
)
2223

2324

25+
def test_norm_cosine_distance():
    """Check the distance-to-similarity mapping at both ends and the midpoint."""
    # (cosine distance, expected similarity). The local is deliberately not
    # called `input`, which would shadow the builtin.
    cases = [(2, 0), (1, 0.5), (0, 1)]
    for distance, expected in cases:
        assert norm_cosine_distance(distance) == expected
29+
30+
2431
def test_even_number_of_elements():
2532
"""Test with an even number of elements"""
2633
values = ["key1", "value1", "key2", "value2"]

0 commit comments

Comments
 (0)