Skip to content

Commit 976063d

Browse files
Merge branch 'master' into DOC-4345-json-intro
2 parents 78f96f4 + 00f5be4 commit 976063d

File tree

5 files changed

+251
-1
lines changed

5 files changed

+251
-1
lines changed

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packaging>=20.4
99
pytest
1010
pytest-asyncio>=0.23.0,<0.24.0
1111
pytest-cov
12-
pytest-profiling
12+
pytest-profiling==1.7.0
1313
pytest-timeout
1414
ujson>=4.2.0
1515
uvloop

doctests/query_combined.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# EXAMPLE: query_combined
2+
# HIDE_START
3+
import json
4+
import numpy as np
5+
import redis
6+
import warnings
7+
from redis.commands.json.path import Path
8+
from redis.commands.search.field import NumericField, TagField, TextField, VectorField
9+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
10+
from redis.commands.search.query import Query
11+
from sentence_transformers import SentenceTransformer
12+
13+
14+
def embed_text(model, text):
    """Encode *text* with *model* and return the embedding as raw float32 bytes.

    The byte string is the format RediSearch expects for vector query
    parameters (e.g. the ``$query_vector`` KNN parameter below).
    """
    vector = model.encode(text)
    return np.asarray(vector, dtype=np.float32).tobytes()
16+
17+
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=r".*clean_up_tokenization_spaces.*",
)
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
query = "Bike for small kids"
query_vector = embed_text(model, query)

r = redis.Redis(decode_responses=True)

# Build a search index over JSON documents stored under the "bicycle:" prefix.
schema = (
    TextField("$.description", no_stem=True, as_name="model"),
    TagField("$.condition", as_name="condition"),
    NumericField("$.price", as_name="price"),
    VectorField(
        "$.description_embeddings",
        "FLAT",
        {
            "TYPE": "FLOAT32",
            # 384 matches the output dimension of all-MiniLM-L6-v2.
            "DIM": 384,
            "DISTANCE_METRIC": "COSINE",
        },
        as_name="vector",
    ),
)

index = r.ft("idx:bicycle")
index.create_index(
    schema,
    definition=IndexDefinition(prefix=["bicycle:"], index_type=IndexType.JSON),
)

# Load the sample documents and write them with one non-transactional pipeline.
with open("data/query_vector.json") as f:
    bicycles = json.load(f)

pipeline = r.pipeline(transaction=False)
for bid, bicycle in enumerate(bicycles):
    pipeline.json().set(f"bicycle:{bid}", Path.root_path(), bicycle)
pipeline.execute()
# HIDE_END

# STEP_START combined1
q = Query("@price:[500 1000] @condition:{new}")
res = index.search(q)
print(res.total)  # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined2
q = Query("kids @price:[500 1000] @condition:{used}")
res = index.search(q)
print(res.total)  # >>> 1
# REMOVE_START
assert res.total == 1
# REMOVE_END
# STEP_END

# STEP_START combined3
q = Query("(kids | small) @condition:{used}")
res = index.search(q)
print(res.total)  # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined4
q = Query("@description:(kids | small) @condition:{used}")
res = index.search(q)
print(res.total)  # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined5
q = Query("@description:(kids | small) @condition:{new | used}")
res = index.search(q)
print(res.total)  # >>> 0
# REMOVE_START
assert res.total == 0
# REMOVE_END
# STEP_END

# STEP_START combined6
q = Query("@price:[500 1000] -@condition:{new}")
res = index.search(q)
print(res.total)  # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# STEP_START combined7
# Hybrid query: the filter expression feeds the KNN vector search.
q = Query(
    "(@price:[500 1000] -@condition:{new})=>[KNN 3 @vector $query_vector]"
).dialect(2)
res = index.search(q, {"query_vector": query_vector})
print(res.total)  # >>> 2
# REMOVE_START
assert res.total == 2
# REMOVE_END
# STEP_END

# REMOVE_START
# destroy index and data
r.ft("idx:bicycle").dropindex(delete_documents=True)
# REMOVE_END

redis/commands/search/aggregation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(self, query: str = "*") -> None:
112112
self._cursor = []
113113
self._dialect = None
114114
self._add_scores = False
115+
self._scorer = "TFIDF"
115116

116117
def load(self, *fields: List[str]) -> "AggregateRequest":
117118
"""
@@ -300,6 +301,17 @@ def add_scores(self) -> "AggregateRequest":
300301
self._add_scores = True
301302
return self
302303

304+
def scorer(self, scorer: str) -> "AggregateRequest":
    """Select the scoring function used to rank documents.

    When not set explicitly, `TFIDF` is used.

    :param scorer: Name of the scoring function, e.g. `TFIDF.DOCNORM`
        or `BM25`.
    """
    self._scorer = scorer
    return self
314+
303315
def verbatim(self) -> "AggregateRequest":
304316
self._verbatim = True
305317
return self
@@ -323,6 +335,9 @@ def build_args(self) -> List[str]:
323335
if self._verbatim:
324336
ret.append("VERBATIM")
325337

338+
if self._scorer:
339+
ret.extend(["SCORER", self._scorer])
340+
326341
if self._add_scores:
327342
ret.append("ADDSCORES")
328343

@@ -332,6 +347,7 @@ def build_args(self) -> List[str]:
332347
if self._loadall:
333348
ret.append("LOAD")
334349
ret.append("*")
350+
335351
elif self._loadfields:
336352
ret.append("LOAD")
337353
ret.append(str(len(self._loadfields)))

tests/test_asyncio/test_search.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,61 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis):
15561556
assert res.rows[1] == ["__score", "0.2"]
15571557

15581558

1559+
@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
    """Hybrid (full-text + vector KNN) aggregation with a non-default scorer.

    Verifies that ``AggregateRequest.scorer("BM25")`` combined with
    ``add_scores()`` exposes ``@__score`` for use in an APPLY expression.
    """
    assert await decoded_r.ft().create_index(
        (
            TextField("name", sortable=True, weight=5.0),
            TextField("description", sortable=True, weight=5.0),
            VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
            ),
        )
    )

    assert await decoded_r.hset(
        "doc1",
        mapping={
            "name": "cat book",
            "description": "an animal book about cats",
            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
        },
    )
    assert await decoded_r.hset(
        "doc2",
        mapping={
            "name": "dog book",
            "description": "an animal book about dogs",
            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
        },
    )

    query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
    req = (
        aggregations.AggregateRequest(query_string)
        .scorer("BM25")
        .add_scores()
        .apply(hybrid_score="@__score + @dist")
        .load("*")
        .dialect(4)
    )

    res = await decoded_r.ft().aggregate(
        req,
        query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
    )

    if isinstance(res, dict):  # RESP3 reply shape
        assert len(res["results"]) == 2
    else:  # RESP2 reply shape
        assert len(res.rows) == 2
        for row in res.rows:
            # Bug fix: the comparison result was previously discarded
            # (bare `len(row) == 6` asserted nothing).
            assert len(row) == 6
1612+
1613+
15591614
@pytest.mark.redismod
15601615
@skip_if_redis_enterprise()
15611616
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,61 @@ def test_aggregations_add_scores(client):
14661466
assert res.rows[1] == ["__score", "0.2"]
14671467

14681468

1469+
@pytest.mark.redismod
@skip_ifmodversion_lt("2.10.05", "search")
def test_aggregations_hybrid_scoring(client):
    """Hybrid (full-text + vector KNN) aggregation with a non-default scorer.

    Verifies that ``AggregateRequest.scorer("BM25")`` combined with
    ``add_scores()`` exposes ``@__score`` for use in an APPLY expression.

    Bug fix: this test was declared ``async def`` in the synchronous test
    module, although it contains no ``await`` and uses the sync client —
    without an asyncio marker it would never execute its body as intended.
    """
    client.ft().create_index(
        (
            TextField("name", sortable=True, weight=5.0),
            TextField("description", sortable=True, weight=5.0),
            VectorField(
                "vector",
                "HNSW",
                {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "COSINE"},
            ),
        )
    )

    client.hset(
        "doc1",
        mapping={
            "name": "cat book",
            "description": "an animal book about cats",
            "vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
        },
    )
    client.hset(
        "doc2",
        mapping={
            "name": "dog book",
            "description": "an animal book about dogs",
            "vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
        },
    )

    query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
    req = (
        aggregations.AggregateRequest(query_string)
        .scorer("BM25")
        .add_scores()
        .apply(hybrid_score="@__score + @dist")
        .load("*")
        .dialect(4)
    )

    res = client.ft().aggregate(
        req,
        query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()},
    )

    if isinstance(res, dict):  # RESP3 reply shape
        assert len(res["results"]) == 2
    else:  # RESP2 reply shape
        assert len(res.rows) == 2
        for row in res.rows:
            # Bug fix: the comparison result was previously discarded
            # (bare `len(row) == 6` asserted nothing).
            assert len(row) == 6
1522+
1523+
14691524
@pytest.mark.redismod
14701525
@skip_ifmodversion_lt("2.0.0", "search")
14711526
def test_index_definition(client):

0 commit comments

Comments
 (0)