From e82dc8e1374c6668ee32b31138dc644129e6f30d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Jul 2024 12:26:41 +0300 Subject: [PATCH 1/6] Added support for ADDSCORES modifier --- redis/commands/search/aggregation.py | 11 +++++++++++ tests/test_asyncio/test_search.py | 17 +++++++++++++++++ tests/test_search.py | 17 +++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 50d18f476a..42c3547b0b 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None: self._verbatim = False self._cursor = [] self._dialect = None + self._add_scores = False def load(self, *fields: List[str]) -> "AggregateRequest": """ @@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest": self._with_schema = True return self + def add_scores(self) -> "AggregateRequest": + """ + If set, includes the score as an ordinary field of the row. + """ + self._add_scores = True + return self + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self @@ -315,6 +323,9 @@ def build_args(self) -> List[str]: if self._verbatim: ret.append("VERBATIM") + if self._add_scores: + ret.append("ADDSCORES") + if self._cursor: ret += self._cursor diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 68560d1f2a..5e5b30702e 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1530,6 +1530,23 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +async def test_aggregations_add_scores(decoded_r: redis.Redis): + assert await decoded_r.ft().create_index( + (TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True)) + ) + + assert await decoded_r.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) + assert await decoded_r.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = (aggregations.AggregateRequest("*").add_scores()) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod @skip_if_redis_enterprise() async def test_search_commands_in_pipeline(decoded_r: redis.Redis): diff --git a/tests/test_search.py b/tests/test_search.py index e84f03c0e4..9ce25b7bf1 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1440,6 +1440,23 @@ def test_aggregations_filter(client): assert res["results"][1]["extra_attributes"] == {"age": "25"} +@pytest.mark.redismod +@skip_ifmodversion_lt("2.10.05", "search") +def test_aggregations_add_scores(client): + client.ft().create_index( + (TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True)) + ) + + client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + + req = (aggregations.AggregateRequest("*").add_scores()) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] + + @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") def test_index_definition(client): From d08f9cdef8d3d6cfa48c3265104bdcd8a8845691 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Jul 2024 12:29:59 +0300 Subject: [PATCH 2/6] Fixed codestyle issues --- tests/test_asyncio/test_search.py | 13 ++++++++++--- tests/test_search.py | 5 ++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 5e5b30702e..f6b43c4b10 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1534,11 +1534,18 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): @skip_ifmodversion_lt("2.10.05", "search") async def test_aggregations_add_scores(decoded_r: redis.Redis): assert await decoded_r.ft().create_index( - (TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True)) + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True) + ) ) - assert await decoded_r.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) - assert await decoded_r.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + assert await decoded_r.ft().client.hset( + "doc1", mapping={"name": "bar", "age": "25"} + ) + assert await decoded_r.ft().client.hset( + "doc2", mapping={"name": "foo", "age": "19"} + ) req = (aggregations.AggregateRequest("*").add_scores()) res = await decoded_r.ft().aggregate(req) diff --git a/tests/test_search.py b/tests/test_search.py index 9ce25b7bf1..c918e6d940 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1444,7 +1444,10 @@ def test_aggregations_filter(client): @skip_ifmodversion_lt("2.10.05", "search") def test_aggregations_add_scores(client): client.ft().create_index( - (TextField("name", sortable=True, weight=5.0), NumericField("age", sortable=True)) + ( + TextField("name", sortable=True, weight=5.0), + NumericField("age", sortable=True) + ) ) client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) From e53f22cccb96a50298785fd92629eb44f14cd344 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 22 Jul 2024 12:39:10 +0300 Subject: [PATCH 3/6] More codestyle fixes --- tests/test_asyncio/test_search.py | 4 ++-- tests/test_search.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index f6b43c4b10..a352f9abe2 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1536,7 +1536,7 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): assert await decoded_r.ft().create_index( ( TextField("name", sortable=True, weight=5.0), - NumericField("age", sortable=True) + NumericField("age", sortable=True), ) ) @@ -1547,7 +1547,7 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): "doc2", mapping={"name": "foo", "age": "19"} ) - req = (aggregations.AggregateRequest("*").add_scores()) + req = aggregations.AggregateRequest("*").add_scores() res = await decoded_r.ft().aggregate(req) assert len(res.rows) == 2 assert res.rows[0] == ["__score", "0.2"] diff --git a/tests/test_search.py b/tests/test_search.py index c918e6d940..a3b847a259 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1446,14 +1446,14 @@ def test_aggregations_add_scores(client): client.ft().create_index( ( TextField("name", sortable=True, weight=5.0), - NumericField("age", sortable=True) + NumericField("age", sortable=True), ) ) client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) - req = (aggregations.AggregateRequest("*").add_scores()) + req = aggregations.AggregateRequest("*").add_scores() res = client.ft().aggregate(req) assert len(res.rows) == 2 assert res.rows[0] == ["__score", "0.2"] From ef97b3d64184af4f542e586abe457c6da2805305 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 30 Jul 2024 13:23:40 +0300 Subject: [PATCH 4/6] Updated test cases and testing image to represent latest --- .github/workflows/integration.yaml | 2 +- tests/test_asyncio/test_search.py | 4 ++-- tests/test_search.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 94fe8f35b6..5342238dd3 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -28,7 +28,7 @@ env: # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon REDIS_IMAGE: redis:7.4-rc2 - REDIS_STACK_IMAGE: redis/redis-stack-server:7.4.0-rc2 + REDIS_STACK_IMAGE: redis/redis-stack-server:latest jobs: dependency-audit: diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index a352f9abe2..4713ec0963 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1540,10 +1540,10 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): ) ) - assert await decoded_r.ft().client.hset( + assert await decoded_r.hset( "doc1", mapping={"name": "bar", "age": "25"} ) - assert await decoded_r.ft().client.hset( + assert await decoded_r.hset( "doc2", mapping={"name": "foo", "age": "19"} ) diff --git a/tests/test_search.py b/tests/test_search.py index a3b847a259..073e607dbc 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1450,8 +1450,8 @@ def test_aggregations_add_scores(client): ) ) - client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) - client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) + client.hset("doc1", mapping={"name": "bar", "age": "25"}) + client.hset("doc2", mapping={"name": "foo", "age": "19"}) req = aggregations.AggregateRequest("*").add_scores() res = client.ft().aggregate(req) From 4719d73385c5c156fda73e886d746949f8a0c760 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 30 Jul 2024 13:26:52 +0300 Subject: [PATCH 5/6] Codestyle issues --- tests/test_asyncio/test_search.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 4713ec0963..9f5d4fc76b 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1540,12 +1540,8 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): ) ) - assert await decoded_r.hset( - "doc1", mapping={"name": "bar", "age": "25"} - ) - assert await decoded_r.hset( - "doc2", mapping={"name": "foo", "age": "19"} - ) + assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"}) + assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"}) req = aggregations.AggregateRequest("*").add_scores() res = await decoded_r.ft().aggregate(req) From a90d571be5356ba131b84b6e9061038a937901f1 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 30 Jul 2024 14:07:19 +0300 Subject: [PATCH 6/6] Added handling for dict responses --- tests/test_asyncio/test_search.py | 12 +++++++++--- tests/test_search.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 9f5d4fc76b..0e6fe22131 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1545,9 +1545,15 @@ async def test_aggregations_add_scores(decoded_r: redis.Redis): req = aggregations.AggregateRequest("*").add_scores() res = await decoded_r.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["__score", "0.2"] - assert res.rows[1] == ["__score", "0.2"] + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] @pytest.mark.redismod diff --git a/tests/test_search.py b/tests/test_search.py index 073e607dbc..dde59f0f87 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1455,9 +1455,15 @@ def test_aggregations_add_scores(client): req = aggregations.AggregateRequest("*").add_scores() res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["__score", "0.2"] - assert res.rows[1] == ["__score", "0.2"] + + if isinstance(res, dict): + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"__score": "0.2"} + assert res["results"][1]["extra_attributes"] == {"__score": "0.2"} + else: + assert len(res.rows) == 2 + assert res.rows[0] == ["__score", "0.2"] + assert res.rows[1] == ["__score", "0.2"] @pytest.mark.redismod