Skip to content

Commit c210c95

Browse files
committed
Skip test for RESP3 and add async test
1 parent 7ca2f29 commit c210c95

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

tests/test_asyncio/test_search.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,22 @@
44
import time
55
from io import TextIOWrapper
66

7+
import numpy as np
78
import pytest
89
import pytest_asyncio
910
import redis.asyncio as redis
1011
import redis.commands.search
1112
import redis.commands.search.aggregation as aggregations
1213
import redis.commands.search.reducers as reducers
1314
from redis.commands.search import AsyncSearch
14-
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
15-
from redis.commands.search.indexDefinition import IndexDefinition
15+
from redis.commands.search.field import (
16+
GeoField,
17+
NumericField,
18+
TagField,
19+
TextField,
20+
VectorField,
21+
)
22+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
1623
from redis.commands.search.query import GeoFilter, NumericFilter, Query
1724
from redis.commands.search.result import Result
1825
from redis.commands.search.suggestion import Suggestion
@@ -21,6 +28,7 @@
2128
is_resp2_connection,
2229
skip_if_redis_enterprise,
2330
skip_ifmodversion_lt,
31+
skip_if_resp_version,
2432
)
2533

2634
WILL_PLAY_TEXT = os.path.abspath(
@@ -37,6 +45,11 @@ async def decoded_r(create_redis, stack_url):
3745
return await create_redis(decode_responses=True, url=stack_url)
3846

3947

48+
@pytest_asyncio.fixture()
49+
async def binary_client(create_redis, stack_url):
50+
return await create_redis(decode_responses=False, url=stack_url)
51+
52+
4053
async def waitForIndex(env, idx, timeout=None):
4154
delay = 0.1
4255
while True:
@@ -1560,3 +1573,61 @@ async def test_query_timeout(decoded_r: redis.Redis):
15601573
q2 = Query("foo").timeout("not_a_number")
15611574
with pytest.raises(redis.ResponseError):
15621575
await decoded_r.ft().search(q2)
1576+
1577+
1578+
@pytest.mark.redismod
1579+
@skip_if_resp_version(3)
1580+
async def test_binary_and_text_fields(binary_client):
1581+
assert (
1582+
binary_client.get_connection_kwargs()["decode_responses"] is False
1583+
), "This feature is only available when decode_responses is False"
1584+
1585+
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
1586+
1587+
index_name = "mixed_index"
1588+
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1589+
await binary_client.hset(f"{index_name}:1", mapping=mixed_data)
1590+
1591+
schema = (
1592+
TagField("first_name"),
1593+
VectorField(
1594+
"embeddings_bio",
1595+
algorithm="HNSW",
1596+
attributes={
1597+
"TYPE": "FLOAT32",
1598+
"DIM": 4,
1599+
"DISTANCE_METRIC": "COSINE",
1600+
},
1601+
),
1602+
)
1603+
1604+
await binary_client.ft(index_name).create_index(
1605+
fields=schema,
1606+
definition=IndexDefinition(
1607+
prefix=[f"{index_name}:"], index_type=IndexType.HASH
1608+
),
1609+
)
1610+
1611+
bytes_person_1 = await binary_client.hget(f"{index_name}:1", "vector_emb")
1612+
decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
1613+
assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"
1614+
1615+
query = (
1616+
Query("*")
1617+
.return_field("vector_emb", decode_field=False)
1618+
.return_field("first_name", decode_field=True)
1619+
)
1620+
result = await binary_client.ft(index_name).search(query=query, query_params={})
1621+
docs = result.docs
1622+
1623+
decoded_vec_from_search_results = np.frombuffer(
1624+
docs[0]["vector_emb"], dtype=np.float32
1625+
)
1626+
1627+
assert np.array_equal(
1628+
decoded_vec_from_search_results, fake_vec
1629+
), "The vectors are not equal"
1630+
1631+
assert (
1632+
docs[0]["first_name"] == mixed_data["first_name"]
1633+
), "The first is not decoded correctly"

tests/test_search.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
is_resp2_connection,
3232
skip_if_redis_enterprise,
3333
skip_ifmodversion_lt,
34+
skip_if_resp_version,
3435
)
3536

3637
WILL_PLAY_TEXT = os.path.abspath(
@@ -116,7 +117,9 @@ def client(request, stack_url):
116117

117118
@pytest.fixture
118119
def binary_client(request, stack_url):
119-
r = _get_client(redis.Redis, request, decode_responses=False, from_url=stack_url)
120+
r = _get_client(
121+
redis.Redis, request, decode_responses=False, from_url=stack_url, protocol=3
122+
)
120123
r.flushdb()
121124
return r
122125

@@ -1714,6 +1717,7 @@ def test_search_return_fields(client):
17141717

17151718

17161719
@pytest.mark.redismod
1720+
@skip_if_resp_version(3)
17171721
def test_binary_and_text_fields(binary_client):
17181722
assert (
17191723
binary_client.get_connection_kwargs()["decode_responses"] is False

0 commit comments

Comments
 (0)