Skip to content

Commit 0cc39a3

Browse files
committed
Adding WITHATTRIBS option to vector set's vsim command.
1 parent 7d57d09 commit 0cc39a3

File tree

4 files changed

+256
-10
lines changed

4 files changed

+256
-10
lines changed

redis/commands/vectorset/commands.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from enum import Enum
3-
from typing import Awaitable, Dict, List, Optional, Union
3+
from typing import Any, Awaitable, Dict, List, Optional, Union
44

55
from redis.client import NEVER_DECODE
66
from redis.commands.helpers import get_protocol_version
@@ -33,6 +33,7 @@ class CallbacksOptions(Enum):
3333

3434
RAW = "RAW"
3535
WITHSCORES = "WITHSCORES"
36+
WITHATTRIBS = "WITHATTRIBS"
3637
ALLOW_DECODING = "ALLOW_DECODING"
3738
RESP3 = "RESP3"
3839

@@ -123,6 +124,7 @@ def vsim(
123124
key: KeyT,
124125
input: Union[List[float], bytes, str],
125126
with_scores: Optional[bool] = False,
127+
with_attribs: Optional[bool] = False,
126128
count: Optional[int] = None,
127129
ef: Optional[Number] = None,
128130
filter: Optional[str] = None,
@@ -131,14 +133,35 @@ def vsim(
131133
no_thread: Optional[bool] = False,
132134
epsilon: Optional[Number] = None,
133135
) -> Union[
134-
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
135-
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
136+
Awaitable[
137+
Optional[
138+
List[
139+
Union[
140+
List[EncodableT],
141+
Dict[EncodableT, Number],
142+
Dict[EncodableT, Dict[str, Any]],
143+
]
144+
]
145+
]
146+
],
147+
Optional[
148+
List[
149+
Union[
150+
List[EncodableT],
151+
Dict[EncodableT, Number],
152+
Dict[EncodableT, Dict[str, Any]],
153+
]
154+
]
155+
],
136156
]:
137157
"""
138158
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
139159
140-
``with_scores`` sets if the results should be returned with the
141-
similarity scores of the elements in the result.
160+
``with_scores`` sets if returns, for each element, the JSON attribute associated
161+
with the element or None when no attributes are present.
162+
163+
``with_attribs`` sets if the results should be returned with the
164+
attributes of the elements in the result.
142165
143166
``count`` sets the number of results to return.
144167
@@ -177,6 +200,10 @@ def vsim(
177200
pieces.append("WITHSCORES")
178201
options[CallbacksOptions.WITHSCORES.value] = True
179202

203+
if with_attribs:
204+
pieces.append("WITHATTRIBS")
205+
options[CallbacksOptions.WITHATTRIBS.value] = True
206+
180207
if count:
181208
pieces.extend(["COUNT", count])
182209

redis/commands/vectorset/utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
from redis._parsers.helpers import pairs_to_dict
24
from redis.commands.vectorset.commands import CallbacksOptions
35

@@ -75,19 +77,40 @@ def parse_vsim_result(response, **options):
7577
structures depending on input options.
7678
Parsing VSIM result into:
7779
- List[List[str]]
78-
- List[Dict[str, Number]]
80+
- List[Dict[str, Number]] - when with_scores is used (without attributes)
81+
- List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores)
82+
- List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used
83+
7984
"""
8085
if response is None:
8186
return response
8287

83-
if options.get(CallbacksOptions.WITHSCORES.value):
88+
withscores = bool(options.get(CallbacksOptions.WITHSCORES.value))
89+
withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value))
90+
91+
if withscores ^ withattribs:
8492
# Redis will return a list of list of pairs.
8593
# This list have to be transformed to dict
8694
result_dict = {}
8795
for key, value in pairs_to_dict(response).items():
88-
value = float(value)
96+
if withscores:
97+
value = float(value)
98+
else:
99+
value = json.loads(value) if value else None
100+
89101
result_dict[key] = value
90102
return result_dict
103+
elif withscores and withattribs:
104+
it = iter(response)
105+
result_dict = {}
106+
for elem, score, attribs in zip(it, it, it):
107+
if attribs is not None:
108+
attribs_dict = json.loads(attribs)
109+
else:
110+
attribs_dict = None
111+
112+
result_dict[elem] = {"score": float(score), "attributes": attribs_dict}
113+
return result_dict
91114
else:
92115
# return the list of elements for each level
93116
# list of lists

tests/test_asyncio/test_vsets.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,80 @@ async def test_vsim_with_scores(d_client):
262262
assert 0 <= vsim["elem1"] <= 1
263263

264264

265+
@skip_if_server_version_lt("7.9.0")
266+
async def test_vsim_with_attribs_attribs_set(d_client):
267+
elements_count = 5
268+
vector_dim = 10
269+
attrs_dict = {"key1": "value1", "key2": "value2"}
270+
for i in range(elements_count):
271+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
272+
await d_client.vset().vadd(
273+
"myset",
274+
float_array,
275+
f"elem{i}",
276+
numlinks=64,
277+
attributes=attrs_dict if i % 2 == 0 else None,
278+
)
279+
280+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
281+
assert len(vsim) == 5
282+
assert isinstance(vsim, dict)
283+
assert vsim["elem1"] is None
284+
assert vsim["elem2"] == attrs_dict
285+
286+
287+
@skip_if_server_version_lt("7.9.0")
288+
async def test_vsim_with_scores_and_attribs_attribs_set(d_client):
289+
elements_count = 5
290+
vector_dim = 10
291+
attrs_dict = {"key1": "value1", "key2": "value2"}
292+
for i in range(elements_count):
293+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
294+
await d_client.vset().vadd(
295+
"myset",
296+
float_array,
297+
f"elem{i}",
298+
numlinks=64,
299+
attributes=attrs_dict if i % 2 == 0 else None,
300+
)
301+
302+
vsim = await d_client.vset().vsim(
303+
"myset", input="elem1", with_scores=True, with_attribs=True
304+
)
305+
assert len(vsim) == 5
306+
assert isinstance(vsim, dict)
307+
assert isinstance(vsim["elem1"], dict)
308+
assert "score" in vsim["elem1"]
309+
assert "attributes" in vsim["elem1"]
310+
assert isinstance(vsim["elem1"]["score"], float)
311+
assert vsim["elem1"]["attributes"] is None
312+
313+
assert isinstance(vsim["elem2"], dict)
314+
assert "score" in vsim["elem2"]
315+
assert "attributes" in vsim["elem2"]
316+
assert isinstance(vsim["elem2"]["score"], float)
317+
assert vsim["elem2"]["attributes"] == attrs_dict
318+
319+
320+
@skip_if_server_version_lt("7.9.0")
321+
async def test_vsim_with_attribs_attribs_not_set(d_client):
322+
elements_count = 20
323+
vector_dim = 50
324+
for i in range(elements_count):
325+
float_array = [random.uniform(0, 10) for x in range(vector_dim)]
326+
await d_client.vset().vadd(
327+
"myset",
328+
float_array,
329+
f"elem{i}",
330+
numlinks=64,
331+
)
332+
333+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
334+
assert len(vsim) == 10
335+
assert isinstance(vsim, dict)
336+
assert vsim["elem1"] is None
337+
338+
265339
@skip_if_server_version_lt("7.9.0")
266340
async def test_vsim_with_different_vector_input_types(d_client):
267341
elements_count = 10
@@ -789,9 +863,12 @@ async def test_vrandmember(d_client):
789863
async def test_vset_commands_without_decoding_responces(client):
790864
# test vadd
791865
elements = ["elem1", "elem2", "elem3"]
866+
attrs_dict = {"key1": "value1", "key2": "value2"}
792867
for elem in elements:
793868
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
794-
resp = await client.vset().vadd("myset", float_array, element=elem)
869+
resp = await client.vset().vadd(
870+
"myset", float_array, element=elem, attributes=attrs_dict
871+
)
795872
assert resp == 1
796873

797874
# test vemb
@@ -820,6 +897,28 @@ async def test_vset_commands_without_decoding_responces(client):
820897
assert isinstance(vsim_with_scores, dict)
821898
assert isinstance(vsim_with_scores[b"elem1"], float)
822899

900+
# test vsim with attributes
901+
vsim_with_attribs = await client.vset().vsim(
902+
"myset", input="elem1", with_attribs=True
903+
)
904+
assert len(vsim_with_attribs) == 3
905+
assert isinstance(vsim_with_attribs, dict)
906+
assert isinstance(vsim_with_attribs[b"elem1"], dict)
907+
assert vsim_with_attribs[b"elem1"] == attrs_dict
908+
909+
# test vsim with score and attributes
910+
vsim_with_scores_and_attribs = await client.vset().vsim(
911+
"myset", input="elem1", with_scores=True, with_attribs=True
912+
)
913+
assert len(vsim_with_scores_and_attribs) == 3
914+
assert isinstance(vsim_with_scores_and_attribs, dict)
915+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict)
916+
assert "score" in vsim_with_scores_and_attribs[b"elem1"]
917+
assert "attributes" in vsim_with_scores_and_attribs[b"elem1"]
918+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float)
919+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict)
920+
assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict
921+
823922
# test vlinks - no scores
824923
element_links_all_layers = await client.vset().vlinks("myset", "elem1")
825924
assert len(element_links_all_layers) >= 1

tests/test_vsets.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,80 @@ def test_vsim_with_scores(d_client):
264264
assert 0 <= vsim["elem1"] <= 1
265265

266266

267+
@skip_if_server_version_lt("7.9.0")
268+
def test_vsim_with_attribs_attribs_set(d_client):
269+
elements_count = 5
270+
vector_dim = 10
271+
attrs_dict = {"key1": "value1", "key2": "value2"}
272+
for i in range(elements_count):
273+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
274+
d_client.vset().vadd(
275+
"myset",
276+
float_array,
277+
f"elem{i}",
278+
numlinks=64,
279+
attributes=attrs_dict if i % 2 == 0 else None,
280+
)
281+
282+
vsim = d_client.vset().vsim("myset", input="elem1", with_attribs=True)
283+
assert len(vsim) == 5
284+
assert isinstance(vsim, dict)
285+
assert vsim["elem1"] is None
286+
assert vsim["elem2"] == attrs_dict
287+
288+
289+
@skip_if_server_version_lt("7.9.0")
290+
def test_vsim_with_scores_and_attribs_attribs_set(d_client):
291+
elements_count = 5
292+
vector_dim = 10
293+
attrs_dict = {"key1": "value1", "key2": "value2"}
294+
for i in range(elements_count):
295+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
296+
d_client.vset().vadd(
297+
"myset",
298+
float_array,
299+
f"elem{i}",
300+
numlinks=64,
301+
attributes=attrs_dict if i % 2 == 0 else None,
302+
)
303+
304+
vsim = d_client.vset().vsim(
305+
"myset", input="elem1", with_scores=True, with_attribs=True
306+
)
307+
assert len(vsim) == 5
308+
assert isinstance(vsim, dict)
309+
assert isinstance(vsim["elem1"], dict)
310+
assert "score" in vsim["elem1"]
311+
assert "attributes" in vsim["elem1"]
312+
assert isinstance(vsim["elem1"]["score"], float)
313+
assert vsim["elem1"]["attributes"] is None
314+
315+
assert isinstance(vsim["elem2"], dict)
316+
assert "score" in vsim["elem2"]
317+
assert "attributes" in vsim["elem2"]
318+
assert isinstance(vsim["elem2"]["score"], float)
319+
assert vsim["elem2"]["attributes"] == attrs_dict
320+
321+
322+
@skip_if_server_version_lt("7.9.0")
323+
def test_vsim_with_attribs_attribs_not_set(d_client):
324+
elements_count = 20
325+
vector_dim = 50
326+
for i in range(elements_count):
327+
float_array = [random.uniform(0, 10) for x in range(vector_dim)]
328+
d_client.vset().vadd(
329+
"myset",
330+
float_array,
331+
f"elem{i}",
332+
numlinks=64,
333+
)
334+
335+
vsim = d_client.vset().vsim("myset", input="elem1", with_attribs=True)
336+
assert len(vsim) == 10
337+
assert isinstance(vsim, dict)
338+
assert vsim["elem1"] is None
339+
340+
267341
@skip_if_server_version_lt("7.9.0")
268342
def test_vsim_with_different_vector_input_types(d_client):
269343
elements_count = 10
@@ -789,9 +863,12 @@ def test_vrandmember(d_client):
789863
def test_vset_commands_without_decoding_responces(client):
790864
# test vadd
791865
elements = ["elem1", "elem2", "elem3"]
866+
attrs_dict = {"key1": "value1", "key2": "value2"}
792867
for elem in elements:
793868
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
794-
resp = client.vset().vadd("myset", float_array, element=elem)
869+
resp = client.vset().vadd(
870+
"myset", float_array, element=elem, attributes=attrs_dict
871+
)
795872
assert resp == 1
796873

797874
# test vemb
@@ -818,6 +895,26 @@ def test_vset_commands_without_decoding_responces(client):
818895
assert isinstance(vsim_with_scores, dict)
819896
assert isinstance(vsim_with_scores[b"elem1"], float)
820897

898+
# test vsim with attributes
899+
vsim_with_attribs = client.vset().vsim("myset", input="elem1", with_attribs=True)
900+
assert len(vsim_with_attribs) == 3
901+
assert isinstance(vsim_with_attribs, dict)
902+
assert isinstance(vsim_with_attribs[b"elem1"], dict)
903+
assert vsim_with_attribs[b"elem1"] == attrs_dict
904+
905+
# test vsim with score and attributes
906+
vsim_with_scores_and_attribs = client.vset().vsim(
907+
"myset", input="elem1", with_scores=True, with_attribs=True
908+
)
909+
assert len(vsim_with_scores_and_attribs) == 3
910+
assert isinstance(vsim_with_scores_and_attribs, dict)
911+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict)
912+
assert "score" in vsim_with_scores_and_attribs[b"elem1"]
913+
assert "attributes" in vsim_with_scores_and_attribs[b"elem1"]
914+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float)
915+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict)
916+
assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict
917+
821918
# test vlinks - no scores
822919
element_links_all_layers = client.vset().vlinks("myset", "elem1")
823920
assert len(element_links_all_layers) >= 1

0 commit comments

Comments
 (0)