Skip to content

Commit d24c9c9

Browse files
authored
Adding WITHATTRIBS option to vector set's vsim command. (#3746)
1 parent 78fb85e commit d24c9c9

File tree

5 files changed

+293
-17
lines changed

5 files changed

+293
-17
lines changed

redis/commands/vectorset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ def __init__(self, client, **kwargs):
2424
# Set the module commands' callbacks
2525
self._MODULE_CALLBACKS = {
2626
VEMB_CMD: parse_vemb_result,
27+
VSIM_CMD: parse_vsim_result,
2728
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
2829
}
2930

3031
self._RESP2_MODULE_CALLBACKS = {
3132
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
32-
VSIM_CMD: parse_vsim_result,
3333
VLINKS_CMD: parse_vlinks_result,
3434
}
3535
self._RESP3_MODULE_CALLBACKS = {}

redis/commands/vectorset/commands.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from enum import Enum
3-
from typing import Awaitable, Dict, List, Optional, Union
3+
from typing import Any, Awaitable, Dict, List, Optional, Union
44

55
from redis.client import NEVER_DECODE
66
from redis.commands.helpers import get_protocol_version
@@ -19,6 +19,15 @@
1919
VGETATTR_CMD = "VGETATTR"
2020
VRANDMEMBER_CMD = "VRANDMEMBER"
2121

22+
# Return type for vsim command
23+
VSimResult = Optional[
24+
List[
25+
Union[
26+
List[EncodableT], Dict[EncodableT, Number], Dict[EncodableT, Dict[str, Any]]
27+
]
28+
]
29+
]
30+
2231

2332
class QuantizationOptions(Enum):
2433
"""Quantization options for the VADD command."""
@@ -33,6 +42,7 @@ class CallbacksOptions(Enum):
3342

3443
RAW = "RAW"
3544
WITHSCORES = "WITHSCORES"
45+
WITHATTRIBS = "WITHATTRIBS"
3646
ALLOW_DECODING = "ALLOW_DECODING"
3747
RESP3 = "RESP3"
3848

@@ -123,22 +133,22 @@ def vsim(
123133
key: KeyT,
124134
input: Union[List[float], bytes, str],
125135
with_scores: Optional[bool] = False,
136+
with_attribs: Optional[bool] = False,
126137
count: Optional[int] = None,
127138
ef: Optional[Number] = None,
128139
filter: Optional[str] = None,
129140
filter_ef: Optional[str] = None,
130141
truth: Optional[bool] = False,
131142
no_thread: Optional[bool] = False,
132143
epsilon: Optional[Number] = None,
133-
) -> Union[
134-
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
135-
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
136-
]:
144+
) -> Union[Awaitable[VSimResult], VSimResult]:
137145
"""
138146
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
139147
140-
``with_scores`` sets if the results should be returned with the
141-
similarity scores of the elements in the result.
148+
``with_scores`` sets if similarity scores should be returned for each element in the result.
149+
150+
``with_attribs`` sets if the results should be returned with the
151+
attributes of the elements in the result, or None when no attributes are present.
142152
143153
``count`` sets the number of results to return.
144154
@@ -173,9 +183,17 @@ def vsim(
173183
else:
174184
pieces.extend(["ELE", input])
175185

176-
if with_scores:
177-
pieces.append("WITHSCORES")
178-
options[CallbacksOptions.WITHSCORES.value] = True
186+
if with_scores or with_attribs:
187+
if get_protocol_version(self.client) in ["3", 3]:
188+
options[CallbacksOptions.RESP3.value] = True
189+
190+
if with_scores:
191+
pieces.append("WITHSCORES")
192+
options[CallbacksOptions.WITHSCORES.value] = True
193+
194+
if with_attribs:
195+
pieces.append("WITHATTRIBS")
196+
options[CallbacksOptions.WITHATTRIBS.value] = True
179197

180198
if count:
181199
pieces.extend(["COUNT", count])

redis/commands/vectorset/utils.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
from redis._parsers.helpers import pairs_to_dict
24
from redis.commands.vectorset.commands import CallbacksOptions
35

@@ -75,19 +77,53 @@ def parse_vsim_result(response, **options):
7577
structures depending on input options.
7678
Parsing VSIM result into:
7779
- List[List[str]]
78-
- List[Dict[str, Number]]
80+
- List[Dict[str, Number]] - when with_scores is used (without attributes)
81+
- List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores)
82+
- List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used
83+
7984
"""
8085
if response is None:
8186
return response
8287

83-
if options.get(CallbacksOptions.WITHSCORES.value):
88+
withscores = bool(options.get(CallbacksOptions.WITHSCORES.value))
89+
withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value))
90+
91+
# Exactly one of withscores or withattribs is True
92+
if (withscores and not withattribs) or (not withscores and withattribs):
8493
# Redis will return a list of list of pairs.
8594
# This list has to be transformed into a dict
8695
result_dict = {}
87-
for key, value in pairs_to_dict(response).items():
88-
value = float(value)
96+
if options.get(CallbacksOptions.RESP3.value):
97+
resp_dict = response
98+
else:
99+
resp_dict = pairs_to_dict(response)
100+
for key, value in resp_dict.items():
101+
if withscores:
102+
value = float(value)
103+
else:
104+
value = json.loads(value) if value else None
105+
89106
result_dict[key] = value
90107
return result_dict
108+
elif withscores and withattribs:
109+
it = iter(response)
110+
result_dict = {}
111+
if options.get(CallbacksOptions.RESP3.value):
112+
for elem, data in response.items():
113+
if data[1] is not None:
114+
attribs_dict = json.loads(data[1])
115+
else:
116+
attribs_dict = None
117+
result_dict[elem] = {"score": data[0], "attributes": attribs_dict}
118+
else:
119+
for elem, score, attribs in zip(it, it, it):
120+
if attribs is not None:
121+
attribs_dict = json.loads(attribs)
122+
else:
123+
attribs_dict = None
124+
125+
result_dict[elem] = {"score": float(score), "attributes": attribs_dict}
126+
return result_dict
91127
else:
92128
# return the list of elements for each level
93129
# list of lists

tests/test_asyncio/test_vsets.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,80 @@ async def test_vsim_with_scores(d_client):
262262
assert 0 <= vsim["elem1"] <= 1
263263

264264

265+
@skip_if_server_version_lt("8.2.0")
266+
async def test_vsim_with_attribs_attribs_set(d_client):
267+
elements_count = 5
268+
vector_dim = 10
269+
attrs_dict = {"key1": "value1", "key2": "value2"}
270+
for i in range(elements_count):
271+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
272+
await d_client.vset().vadd(
273+
"myset",
274+
float_array,
275+
f"elem{i}",
276+
numlinks=64,
277+
attributes=attrs_dict if i % 2 == 0 else None,
278+
)
279+
280+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
281+
assert len(vsim) == 5
282+
assert isinstance(vsim, dict)
283+
assert vsim["elem1"] is None
284+
assert vsim["elem2"] == attrs_dict
285+
286+
287+
@skip_if_server_version_lt("8.2.0")
288+
async def test_vsim_with_scores_and_attribs_attribs_set(d_client):
289+
elements_count = 5
290+
vector_dim = 10
291+
attrs_dict = {"key1": "value1", "key2": "value2"}
292+
for i in range(elements_count):
293+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
294+
await d_client.vset().vadd(
295+
"myset",
296+
float_array,
297+
f"elem{i}",
298+
numlinks=64,
299+
attributes=attrs_dict if i % 2 == 0 else None,
300+
)
301+
302+
vsim = await d_client.vset().vsim(
303+
"myset", input="elem1", with_scores=True, with_attribs=True
304+
)
305+
assert len(vsim) == 5
306+
assert isinstance(vsim, dict)
307+
assert isinstance(vsim["elem1"], dict)
308+
assert "score" in vsim["elem1"]
309+
assert "attributes" in vsim["elem1"]
310+
assert isinstance(vsim["elem1"]["score"], float)
311+
assert vsim["elem1"]["attributes"] is None
312+
313+
assert isinstance(vsim["elem2"], dict)
314+
assert "score" in vsim["elem2"]
315+
assert "attributes" in vsim["elem2"]
316+
assert isinstance(vsim["elem2"]["score"], float)
317+
assert vsim["elem2"]["attributes"] == attrs_dict
318+
319+
320+
@skip_if_server_version_lt("8.2.0")
321+
async def test_vsim_with_attribs_attribs_not_set(d_client):
322+
elements_count = 20
323+
vector_dim = 50
324+
for i in range(elements_count):
325+
float_array = [random.uniform(0, 10) for x in range(vector_dim)]
326+
await d_client.vset().vadd(
327+
"myset",
328+
float_array,
329+
f"elem{i}",
330+
numlinks=64,
331+
)
332+
333+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
334+
assert len(vsim) == 10
335+
assert isinstance(vsim, dict)
336+
assert vsim["elem1"] is None
337+
338+
265339
@skip_if_server_version_lt("7.9.0")
266340
async def test_vsim_with_different_vector_input_types(d_client):
267341
elements_count = 10
@@ -785,13 +859,51 @@ async def test_vrandmember(d_client):
785859
assert members_list == []
786860

787861

862+
@skip_if_server_version_lt("8.2.0")
863+
async def test_8_2_new_vset_features_without_decoding_responces(client):
864+
# test vadd
865+
elements = ["elem1", "elem2", "elem3"]
866+
attrs_dict = {"key1": "value1", "key2": "value2"}
867+
for elem in elements:
868+
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
869+
resp = await client.vset().vadd(
870+
"myset", float_array, element=elem, attributes=attrs_dict
871+
)
872+
assert resp == 1
873+
874+
# test vsim with attributes
875+
vsim_with_attribs = await client.vset().vsim(
876+
"myset", input="elem1", with_attribs=True
877+
)
878+
assert len(vsim_with_attribs) == 3
879+
assert isinstance(vsim_with_attribs, dict)
880+
assert isinstance(vsim_with_attribs[b"elem1"], dict)
881+
assert vsim_with_attribs[b"elem1"] == attrs_dict
882+
883+
# test vsim with score and attributes
884+
vsim_with_scores_and_attribs = await client.vset().vsim(
885+
"myset", input="elem1", with_scores=True, with_attribs=True
886+
)
887+
assert len(vsim_with_scores_and_attribs) == 3
888+
assert isinstance(vsim_with_scores_and_attribs, dict)
889+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict)
890+
assert "score" in vsim_with_scores_and_attribs[b"elem1"]
891+
assert "attributes" in vsim_with_scores_and_attribs[b"elem1"]
892+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float)
893+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict)
894+
assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict
895+
896+
788897
@skip_if_server_version_lt("7.9.0")
789898
async def test_vset_commands_without_decoding_responces(client):
790899
# test vadd
791900
elements = ["elem1", "elem2", "elem3"]
901+
attrs_dict = {"key1": "value1", "key2": "value2"}
792902
for elem in elements:
793903
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
794-
resp = await client.vset().vadd("myset", float_array, element=elem)
904+
resp = await client.vset().vadd(
905+
"myset", float_array, element=elem, attributes=attrs_dict
906+
)
795907
assert resp == 1
796908

797909
# test vemb

0 commit comments

Comments
 (0)