Skip to content

Commit ab3f79a

Browse files
committed
Allow field-level decoding for all connections
1 parent f23a2e2 commit ab3f79a

File tree

3 files changed

+15
-42
lines changed

3 files changed

+15
-42
lines changed

redis/commands/search/commands.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
from typing import Dict, List, Optional, Union
44

5-
from redis.client import Pipeline
5+
from redis.client import NEVER_DECODE, Pipeline
66
from redis.utils import deprecated_function
77

88
from ..helpers import get_protocol_version, parse_to_dict
@@ -500,7 +500,8 @@ def search(
500500
""" # noqa
501501
args, query = self._mk_query_args(query, query_params=query_params)
502502
st = time.time()
503-
res = self.execute_command(SEARCH_CMD, *args)
503+
options = {NEVER_DECODE: True}
504+
res = self.execute_command(SEARCH_CMD, *args, **options)
504505

505506
if isinstance(res, Pipeline):
506507
return res

tests/test_asyncio/test_search.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,6 @@ async def decoded_r(create_redis, stack_url):
4545
return await create_redis(decode_responses=True, url=stack_url)
4646

4747

48-
@pytest_asyncio.fixture()
49-
async def binary_client(create_redis, stack_url):
50-
return await create_redis(decode_responses=False, url=stack_url)
51-
52-
5348
async def waitForIndex(env, idx, timeout=None):
5449
delay = 0.1
5550
while True:
@@ -1577,16 +1572,12 @@ async def test_query_timeout(decoded_r: redis.Redis):
15771572

15781573
@pytest.mark.redismod
15791574
@skip_if_resp_version(3)
1580-
async def test_binary_and_text_fields(binary_client):
1581-
assert (
1582-
binary_client.get_connection_kwargs()["decode_responses"] is False
1583-
), "This feature is only available when decode_responses is False"
1584-
1575+
async def test_binary_and_text_fields(client):
15851576
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
15861577

15871578
index_name = "mixed_index"
15881579
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1589-
await binary_client.hset(f"{index_name}:1", mapping=mixed_data)
1580+
await client.hset(f"{index_name}:1", mapping=mixed_data)
15901581

15911582
schema = (
15921583
TagField("first_name"),
@@ -1601,23 +1592,19 @@ async def test_binary_and_text_fields(binary_client):
16011592
),
16021593
)
16031594

1604-
await binary_client.ft(index_name).create_index(
1595+
await client.ft(index_name).create_index(
16051596
fields=schema,
16061597
definition=IndexDefinition(
16071598
prefix=[f"{index_name}:"], index_type=IndexType.HASH
16081599
),
16091600
)
16101601

1611-
bytes_person_1 = await binary_client.hget(f"{index_name}:1", "vector_emb")
1612-
decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
1613-
assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"
1614-
16151602
query = (
16161603
Query("*")
16171604
.return_field("vector_emb", decode_field=False)
1618-
.return_field("first_name", decode_field=True)
1605+
.return_field("first_name")
16191606
)
1620-
result = await binary_client.ft(index_name).search(query=query, query_params={})
1607+
result = await client.ft(index_name).search(query=query, query_params={})
16211608
docs = result.docs
16221609

16231610
decoded_vec_from_search_results = np.frombuffer(
@@ -1630,4 +1617,4 @@ async def test_binary_and_text_fields(binary_client):
16301617

16311618
assert (
16321619
docs[0]["first_name"] == mixed_data["first_name"]
1633-
), "The first is not decoded correctly"
1620+
), "The text field is not decoded correctly"

tests/test_search.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,6 @@ def client(request, stack_url):
115115
return r
116116

117117

118-
@pytest.fixture
119-
def binary_client(request, stack_url):
120-
r = _get_client(redis.Redis, request, decode_responses=False, from_url=stack_url)
121-
r.flushdb()
122-
return r
123-
124-
125118
@pytest.mark.redismod
126119
def test_client(client):
127120
num_docs = 500
@@ -1716,16 +1709,12 @@ def test_search_return_fields(client):
17161709

17171710
@pytest.mark.redismod
17181711
@skip_if_resp_version(3)
1719-
def test_binary_and_text_fields(binary_client):
1720-
assert (
1721-
binary_client.get_connection_kwargs()["decode_responses"] is False
1722-
), "This feature is only available when decode_responses is False"
1723-
1712+
def test_binary_and_text_fields(client):
17241713
fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
17251714

17261715
index_name = "mixed_index"
17271716
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
1728-
binary_client.hset(f"{index_name}:1", mapping=mixed_data)
1717+
client.hset(f"{index_name}:1", mapping=mixed_data)
17291718

17301719
schema = (
17311720
TagField("first_name"),
@@ -1740,23 +1729,19 @@ def test_binary_and_text_fields(binary_client):
17401729
),
17411730
)
17421731

1743-
binary_client.ft(index_name).create_index(
1732+
client.ft(index_name).create_index(
17441733
fields=schema,
17451734
definition=IndexDefinition(
17461735
prefix=[f"{index_name}:"], index_type=IndexType.HASH
17471736
),
17481737
)
17491738

1750-
bytes_person_1 = binary_client.hget(f"{index_name}:1", "vector_emb")
1751-
decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
1752-
assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"
1753-
17541739
query = (
17551740
Query("*")
17561741
.return_field("vector_emb", decode_field=False)
1757-
.return_field("first_name", decode_field=True)
1742+
.return_field("first_name")
17581743
)
1759-
docs = binary_client.ft(index_name).search(query=query, query_params={}).docs
1744+
docs = client.ft(index_name).search(query=query, query_params={}).docs
17601745
decoded_vec_from_search_results = np.frombuffer(
17611746
docs[0]["vector_emb"], dtype=np.float32
17621747
)
@@ -1767,7 +1752,7 @@ def test_binary_and_text_fields(binary_client):
17671752

17681753
assert (
17691754
docs[0]["first_name"] == mixed_data["first_name"]
1770-
), "The first is not decoded correctly"
1755+
), "The text field is not decoded correctly"
17711756

17721757

17731758
@pytest.mark.redismod

0 commit comments

Comments
 (0)