Skip to content

Commit 4b71d01

Browse files
committed
update tests for BM25
1 parent 2172c58 commit 4b71d01

File tree

2 files changed

+24
-34
lines changed

2 files changed

+24
-34
lines changed

tests/test_asyncio/test_search.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,20 +1575,20 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
15751575
"doc1",
15761576
mapping={
15771577
"name": "cat book",
1578-
"description": "a book about cats",
1578+
"description": "an animal book about cats",
15791579
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
15801580
},
15811581
)
15821582
assert await decoded_r.hset(
15831583
"doc2",
15841584
mapping={
15851585
"name": "dog book",
1586-
"description": "a book about dogs",
1586+
"description": "an animal book about dogs",
15871587
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
15881588
},
15891589
)
15901590

1591-
query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
1591+
query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
15921592
req = (
15931593
aggregations.AggregateRequest(query_string)
15941594
.scorer("BM25")
@@ -1598,22 +1598,17 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
15981598
.dialect(4)
15991599
)
16001600

1601-
res = (
1602-
await decoded_r.ft()
1603-
.aggregate(
1604-
req,
1605-
query_params={
1606-
"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
1607-
},
1608-
)
1609-
.rows[0]
1601+
res = await decoded_r.ft().aggregate(
1602+
req,
1603+
query_params={"vec_param": np.array([0.11, 0.22]).astype(np.float32).tobytes()},
16101604
)
16111605

1612-
assert len(res) == 6
1613-
assert b"hybrid_score" in res
1614-
assert b"__score" in res
1615-
assert b"__dist" in res
1616-
assert float(res[1]) + float(res[3]) == float(res[5])
1606+
if isinstance(res, dict):
1607+
assert len(res["results"]) == 2
1608+
else:
1609+
assert len(res.rows) == 2
1610+
for row in res.rows:
1611+
len(row) == 6
16171612

16181613

16191614
@pytest.mark.redismod

tests/test_search.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,20 +1485,20 @@ async def test_aggregations_hybrid_scoring(client):
14851485
"doc1",
14861486
mapping={
14871487
"name": "cat book",
1488-
"description": "a book about cats",
1488+
"description": "an animal book about cats",
14891489
"vector": np.array([0.1, 0.2]).astype(np.float32).tobytes(),
14901490
},
14911491
)
14921492
client.hset(
14931493
"doc2",
14941494
mapping={
14951495
"name": "dog book",
1496-
"description": "a book about dogs",
1496+
"description": "an animal book about dogs",
14971497
"vector": np.array([0.2, 0.1]).astype(np.float32).tobytes(),
14981498
},
14991499
)
15001500

1501-
query_string = "(@description:cat)=>[KNN 3 @vector $vec_param AS dist]"
1501+
query_string = "(@description:animal)=>[KNN 3 @vector $vec_param AS dist]"
15021502
req = (
15031503
aggregations.AggregateRequest(query_string)
15041504
.scorer("BM25")
@@ -1508,22 +1508,17 @@ async def test_aggregations_hybrid_scoring(client):
15081508
.dialect(4)
15091509
)
15101510

1511-
res = (
1512-
client.ft()
1513-
.aggregate(
1514-
req,
1515-
query_params={
1516-
"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()
1517-
},
1518-
)
1519-
.rows[0]
1511+
res = client.ft().aggregate(
1512+
req,
1513+
query_params={"vec_param": np.array([0.11, 0.21]).astype(np.float32).tobytes()},
15201514
)
15211515

1522-
assert len(res) == 6
1523-
assert b"hybrid_score" in res
1524-
assert b"__score" in res
1525-
assert b"__dist" in res
1526-
assert float(res[1]) + float(res[3]) == float(res[5])
1516+
if isinstance(res, dict):
1517+
assert len(res["results"]) == 2
1518+
else:
1519+
assert len(res.rows) == 2
1520+
for row in res.rows:
1521+
len(row) == 6
15271522

15281523

15291524
@pytest.mark.redismod

0 commit comments

Comments
 (0)