Skip to content

Commit 0164bbd

Browse files
authored
Merge branch 'master' into ps_fix_readthedocs_yml_format
2 parents cee3ce4 + f0f22db commit 0164bbd

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

redis/commands/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
GroupT,
3636
KeysT,
3737
KeyT,
38+
Number,
3839
PatternT,
3940
ResponseT,
4041
ScriptTextT,
@@ -2567,7 +2568,7 @@ class ListCommands(CommandsProtocol):
25672568
"""
25682569

25692570
def blpop(
2570-
self, keys: List, timeout: Optional[int] = 0
2571+
self, keys: List, timeout: Optional[Number] = 0
25712572
) -> Union[Awaitable[list], list]:
25722573
"""
25732574
LPOP a value off of the first non-empty list
@@ -2588,7 +2589,7 @@ def blpop(
25882589
return self.execute_command("BLPOP", *keys)
25892590

25902591
def brpop(
2591-
self, keys: List, timeout: Optional[int] = 0
2592+
self, keys: List, timeout: Optional[Number] = 0
25922593
) -> Union[Awaitable[list], list]:
25932594
"""
25942595
RPOP a value off of the first non-empty list
@@ -2609,7 +2610,7 @@ def brpop(
26092610
return self.execute_command("BRPOP", *keys)
26102611

26112612
def brpoplpush(
2612-
self, src: str, dst: str, timeout: Optional[int] = 0
2613+
self, src: str, dst: str, timeout: Optional[Number] = 0
26132614
) -> Union[Awaitable[Optional[str]], Optional[str]]:
26142615
"""
26152616
Pop a value off the tail of ``src``, push it on the head of ``dst``

redis/commands/search/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def _get_aggregate_result(
586586

587587
def profile(
588588
self,
589-
query: Union[str, Query, AggregateRequest],
589+
query: Union[Query, AggregateRequest],
590590
limited: bool = False,
591591
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
592592
):
@@ -596,7 +596,7 @@ def profile(
596596
597597
### Parameters
598598
599-
**query**: This can be either an `AggregateRequest`, `Query` or string.
599+
**query**: This can be either an `AggregateRequest` or `Query`.
600600
**limited**: If set to True, removes details of reader iterator.
601601
**query_params**: Define one or more value parameters.
602602
Each parameter has a name and a value.

tests/test_search.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2856,6 +2856,64 @@ def test_vector_search_with_default_dialect(client):
28562856
assert res["total_results"] == 2
28572857

28582858

2859+
@pytest.mark.redismod
2860+
@skip_if_server_version_lt("7.9.0")
2861+
def test_vector_search_with_int8_type(client):
2862+
client.ft().create_index(
2863+
(VectorField("v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"}),)
2864+
)
2865+
2866+
a = [1.5, 10]
2867+
b = [123, 100]
2868+
c = [1, 1]
2869+
2870+
client.hset("a", "v", np.array(a, dtype=np.int8).tobytes())
2871+
client.hset("b", "v", np.array(b, dtype=np.int8).tobytes())
2872+
client.hset("c", "v", np.array(c, dtype=np.int8).tobytes())
2873+
2874+
query = Query("*=>[KNN 2 @v $vec as score]")
2875+
query_params = {"vec": np.array(a, dtype=np.int8).tobytes()}
2876+
2877+
assert 2 in query.get_args()
2878+
2879+
res = client.ft().search(query, query_params=query_params)
2880+
if is_resp2_connection(client):
2881+
assert res.total == 2
2882+
else:
2883+
assert res["total_results"] == 2
2884+
2885+
2886+
@pytest.mark.redismod
2887+
@skip_if_server_version_lt("7.9.0")
2888+
def test_vector_search_with_uint8_type(client):
2889+
client.ft().create_index(
2890+
(
2891+
VectorField(
2892+
"v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"}
2893+
),
2894+
)
2895+
)
2896+
2897+
a = [1.5, 10]
2898+
b = [123, 100]
2899+
c = [1, 1]
2900+
2901+
client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes())
2902+
client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes())
2903+
client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes())
2904+
2905+
query = Query("*=>[KNN 2 @v $vec as score]")
2906+
query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()}
2907+
2908+
assert 2 in query.get_args()
2909+
2910+
res = client.ft().search(query, query_params=query_params)
2911+
if is_resp2_connection(client):
2912+
assert res.total == 2
2913+
else:
2914+
assert res["total_results"] == 2
2915+
2916+
28592917
@pytest.mark.redismod
28602918
@skip_ifmodversion_lt("2.4.3", "search")
28612919
def test_search_query_with_different_dialects(client):

0 commit comments

Comments
 (0)