Skip to content

Commit b28656d

Browse files
committed
enum for aggregation scorer
1 parent f617162 commit b28656d

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

redis/commands/search/aggregation.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1+
from enum import Enum
12
from typing import List, Union
23

34
FIELDNAME = object()
45

56

7+
class Scorers(Enum):
8+
TFIDF = "TFIDF"
9+
TFIDF_DOCNORM = "TFIDF.DOCNORM"
10+
BM25 = "BM25"
11+
DISMAX = "DISMAX"
12+
DOCSCORE = "DOCSCORE"
13+
HAMMING = "HAMMING"
14+
15+
616
class Limit:
717
def __init__(self, offset: int = 0, count: int = 0) -> None:
818
self.offset = offset
@@ -112,7 +122,7 @@ def __init__(self, query: str = "*") -> None:
112122
self._cursor = []
113123
self._dialect = None
114124
self._add_scores = False
115-
self._scorer = None
125+
self._scorer = Scorers.TFIDF.value
116126

117127
def load(self, *fields: List[str]) -> "AggregateRequest":
118128
"""
@@ -309,7 +319,7 @@ def scorer(self, scorer: str) -> "AggregateRequest":
309319
:param scorer: The scoring function to use
310320
(e.g. `TFIDF.DOCNORM` or `BM25`)
311321
"""
312-
self._scorer = scorer
322+
self._scorer = Scorers(scorer).value
313323
return self
314324

315325
def verbatim(self) -> "AggregateRequest":

tests/test_asyncio/test_search.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,14 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
16111611
len(row) == 6
16121612

16131613

1614+
@pytest.mark.redismod
1615+
@skip_ifmodversion_lt("2.10.05", "search")
1616+
async def test_invalid_scorer():
1617+
1618+
with pytest.raises(ValueError):
1619+
aggregations.AggregateRequest("*").scorer("blah")
1620+
1621+
16141622
@pytest.mark.redismod
16151623
@skip_if_redis_enterprise()
16161624
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,14 @@ async def test_aggregations_hybrid_scoring(client):
15211521
len(row) == 6
15221522

15231523

1524+
@pytest.mark.redismod
1525+
@skip_ifmodversion_lt("2.10.05", "search")
1526+
async def test_invalid_scorer():
1527+
1528+
with pytest.raises(ValueError):
1529+
aggregations.AggregateRequest("*").scorer("blah")
1530+
1531+
15241532
@pytest.mark.redismod
15251533
@skip_ifmodversion_lt("2.0.0", "search")
15261534
def test_index_definition(client):

0 commit comments

Comments
 (0)