4
4
import time
5
5
from io import TextIOWrapper
6
6
7
+ import numpy as np
7
8
import pytest
8
9
import pytest_asyncio
9
10
import redis .asyncio as redis
10
11
import redis .commands .search
11
12
import redis .commands .search .aggregation as aggregations
12
13
import redis .commands .search .reducers as reducers
13
14
from redis .commands .search import AsyncSearch
14
- from redis .commands .search .field import GeoField , NumericField , TagField , TextField
15
- from redis .commands .search .indexDefinition import IndexDefinition
15
+ from redis .commands .search .field import (
16
+ GeoField ,
17
+ NumericField ,
18
+ TagField ,
19
+ TextField ,
20
+ VectorField ,
21
+ )
22
+ from redis .commands .search .indexDefinition import IndexDefinition , IndexType
16
23
from redis .commands .search .query import GeoFilter , NumericFilter , Query
17
24
from redis .commands .search .result import Result
18
25
from redis .commands .search .suggestion import Suggestion
21
28
is_resp2_connection ,
22
29
skip_if_redis_enterprise ,
23
30
skip_ifmodversion_lt ,
31
+ skip_if_resp_version ,
24
32
)
25
33
26
34
WILL_PLAY_TEXT = os .path .abspath (
@@ -37,6 +45,11 @@ async def decoded_r(create_redis, stack_url):
37
45
return await create_redis (decode_responses = True , url = stack_url )
38
46
39
47
48
+ @pytest_asyncio .fixture ()
49
+ async def binary_client (create_redis , stack_url ):
50
+ return await create_redis (decode_responses = False , url = stack_url )
51
+
52
+
40
53
async def waitForIndex (env , idx , timeout = None ):
41
54
delay = 0.1
42
55
while True :
@@ -1560,3 +1573,61 @@ async def test_query_timeout(decoded_r: redis.Redis):
1560
1573
q2 = Query ("foo" ).timeout ("not_a_number" )
1561
1574
with pytest .raises (redis .ResponseError ):
1562
1575
await decoded_r .ft ().search (q2 )
1576
+
1577
+
1578
+ @pytest .mark .redismod
1579
+ @skip_if_resp_version (3 )
1580
+ async def test_binary_and_text_fields (binary_client ):
1581
+ assert (
1582
+ binary_client .get_connection_kwargs ()["decode_responses" ] is False
1583
+ ), "This feature is only available when decode_responses is False"
1584
+
1585
+ fake_vec = np .array ([0.1 , 0.2 , 0.3 , 0.4 ], dtype = np .float32 )
1586
+
1587
+ index_name = "mixed_index"
1588
+ mixed_data = {"first_name" : "🐍python" , "vector_emb" : fake_vec .tobytes ()}
1589
+ await binary_client .hset (f"{ index_name } :1" , mapping = mixed_data )
1590
+
1591
+ schema = (
1592
+ TagField ("first_name" ),
1593
+ VectorField (
1594
+ "embeddings_bio" ,
1595
+ algorithm = "HNSW" ,
1596
+ attributes = {
1597
+ "TYPE" : "FLOAT32" ,
1598
+ "DIM" : 4 ,
1599
+ "DISTANCE_METRIC" : "COSINE" ,
1600
+ },
1601
+ ),
1602
+ )
1603
+
1604
+ await binary_client .ft (index_name ).create_index (
1605
+ fields = schema ,
1606
+ definition = IndexDefinition (
1607
+ prefix = [f"{ index_name } :" ], index_type = IndexType .HASH
1608
+ ),
1609
+ )
1610
+
1611
+ bytes_person_1 = await binary_client .hget (f"{ index_name } :1" , "vector_emb" )
1612
+ decoded_vec_from_hash = np .frombuffer (bytes_person_1 , dtype = np .float32 )
1613
+ assert np .array_equal (decoded_vec_from_hash , fake_vec ), "The vectors are not equal"
1614
+
1615
+ query = (
1616
+ Query ("*" )
1617
+ .return_field ("vector_emb" , decode_field = False )
1618
+ .return_field ("first_name" , decode_field = True )
1619
+ )
1620
+ result = await binary_client .ft (index_name ).search (query = query , query_params = {})
1621
+ docs = result .docs
1622
+
1623
+ decoded_vec_from_search_results = np .frombuffer (
1624
+ docs [0 ]["vector_emb" ], dtype = np .float32
1625
+ )
1626
+
1627
+ assert np .array_equal (
1628
+ decoded_vec_from_search_results , fake_vec
1629
+ ), "The vectors are not equal"
1630
+
1631
+ assert (
1632
+ docs [0 ]["first_name" ] == mixed_data ["first_name" ]
1633
+ ), "The first is not decoded correctly"
0 commit comments