Skip to content

Commit b997389

Browse files
committed
fix: correctly implement field projection with return_fields
- Move RETURN clause from query string to args list (fixes RediSearch syntax) - Rename internal parameter from return_fields to projected_fields to avoid conflict with method name - Return dictionaries instead of model instances when using field projection - Add proper parsing for projected results that returns flat key-value pairs - Add tests for both HashModel and JsonModel field projection - Update validation to use model_fields instead of deprecated __fields__ This addresses all review comments from PR #633 and implements field projection correctly.
1 parent 5af2e8e commit b997389

File tree

3 files changed

+78
-20
lines changed

3 files changed

+78
-20
lines changed

aredis_om/model/model.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def __init__(
418418
limit: Optional[int] = None,
419419
page_size: int = DEFAULT_PAGE_SIZE,
420420
sort_fields: Optional[List[str]] = None,
421-
return_fields: Optional[List[str]] = None,
421+
projected_fields: Optional[List[str]] = None,
422422
nocontent: bool = False,
423423
):
424424
if not has_redisearch(model.db()):
@@ -443,10 +443,10 @@ def __init__(
443443
else:
444444
self.sort_fields = []
445445

446-
if return_fields:
447-
self.return_fields = self.validate_return_fields(return_fields)
446+
if projected_fields:
447+
self.projected_fields = self.validate_projected_fields(projected_fields)
448448
else:
449-
self.return_fields = []
449+
self.projected_fields = []
450450

451451
self._expression = None
452452
self._query: Optional[str] = None
@@ -505,18 +505,45 @@ def query(self):
505505
if self._query.startswith("(") or self._query == "*"
506506
else f"({self._query})"
507507
) + f"=>[{self.knn}]"
508-
if self.return_fields:
509-
self._query += f" RETURN {','.join(self.return_fields)}"
508+
# RETURN clause should be added to args, not to the query string
510509
return self._query
511510

512-
def validate_return_fields(self, return_fields: List[str]):
513-
for field in return_fields:
514-
if field not in self.model.__fields__: # type: ignore
511+
def validate_projected_fields(self, projected_fields: List[str]):
512+
for field in projected_fields:
513+
if field not in self.model.model_fields: # type: ignore
515514
raise QueryNotSupportedError(
516515
f"You tried to return the field {field}, but that field "
517516
f"does not exist on the model {self.model}"
518517
)
519-
return return_fields
518+
return projected_fields
519+
520+
def _parse_projected_results(self, res: Any) -> List[Dict[str, Any]]:
521+
"""Parse results when using RETURN clause with specific fields."""
522+
523+
def to_string(s):
524+
if isinstance(s, (str,)):
525+
return s
526+
elif isinstance(s, bytes):
527+
return s.decode(errors="ignore")
528+
else:
529+
return s
530+
531+
docs = []
532+
step = 2 # Because the result has content
533+
offset = 1 # The first item is the count of total matches.
534+
535+
for i in range(1, len(res), step):
536+
if res[i + offset] is None:
537+
continue
538+
# When using RETURN, we get flat key-value pairs
539+
fields: Dict[str, str] = dict(
540+
zip(
541+
map(to_string, res[i + offset][::2]),
542+
map(to_string, res[i + offset][1::2]),
543+
)
544+
)
545+
docs.append(fields)
546+
return docs
520547

521548
@property
522549
def query_params(self):
@@ -899,6 +926,12 @@ async def execute(
899926
if self.nocontent:
900927
args.append("NOCONTENT")
901928

929+
# Add RETURN clause to the args list, not to the query string
930+
if self.projected_fields:
931+
args.extend(
932+
["RETURN", str(len(self.projected_fields))] + self.projected_fields
933+
)
934+
902935
if return_query_args:
903936
return self.model.Meta.index_name, args
904937

@@ -912,7 +945,12 @@ async def execute(
912945
if return_raw_result:
913946
return raw_result
914947
count = raw_result[0]
915-
results = self.model.from_redis(raw_result, self.knn)
948+
949+
# If we're using field projection, return dictionaries instead of model instances
950+
if self.projected_fields:
951+
results = self._parse_projected_results(raw_result)
952+
else:
953+
results = self.model.from_redis(raw_result, self.knn)
916954
self._model_cache += results
917955

918956
if not exhaust_results:
@@ -966,11 +1004,11 @@ def sort_by(self, *fields: str):
9661004
if not fields:
9671005
return self
9681006
return self.copy(sort_fields=list(fields))
969-
1007+
9701008
def return_fields(self, *fields: str):
9711009
if not fields:
9721010
return self
973-
return self.copy(return_fields=list(fields))
1011+
return self.copy(projected_fields=list(fields))
9741012

9751013
async def update(self, use_transaction=True, **field_values):
9761014
"""
@@ -1546,9 +1584,7 @@ def find(
15461584
*expressions: Union[Any, Expression],
15471585
knn: Optional[KNNExpression] = None,
15481586
) -> FindQuery:
1549-
return FindQuery(
1550-
expressions=expressions, knn=knn, model=cls
1551-
)
1587+
return FindQuery(expressions=expressions, knn=knn, model=cls)
15521588

15531589
@classmethod
15541590
def from_redis(cls, res: Any, knn: Optional[KNNExpression] = None):

tests/test_hash_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,23 @@ class Meta:
11331133
).first()
11341134

11351135

1136+
@py_test_mark_asyncio
1137+
async def test_return_specified_fields(members, m):
1138+
member1, member2, member3 = members
1139+
actual = (
1140+
await m.Member.find(
1141+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
1142+
| (m.Member.last_name == "Smith")
1143+
)
1144+
.return_fields("first_name", "last_name")
1145+
.all()
1146+
)
1147+
assert actual == [
1148+
{"first_name": "Andrew", "last_name": "Brookins"},
1149+
{"first_name": "Andrew", "last_name": "Smith"},
1150+
]
1151+
1152+
11361153
@py_test_mark_asyncio
11371154
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
11381155
class Location(HashModel, index=True):

tests/test_json_model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -955,13 +955,18 @@ class TypeWithUuid(JsonModel, index=True):
955955

956956
await item.save()
957957

958+
958959
@py_test_mark_asyncio
959960
async def test_return_specified_fields(members, m):
960961
member1, member2, member3 = members
961-
actual = await m.Member.find(
962-
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
963-
| (m.Member.last_name == "Smith")
964-
).all()
962+
actual = (
963+
await m.Member.find(
964+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
965+
| (m.Member.last_name == "Smith")
966+
)
967+
.return_fields("first_name", "last_name")
968+
.all()
969+
)
965970
assert actual == [
966971
{"first_name": "Andrew", "last_name": "Brookins"},
967972
{"first_name": "Andrew", "last_name": "Smith"},

0 commit comments

Comments
 (0)