Skip to content

Commit 5af2e8e

Browse files
savynorembsbodden
authored andcommitted
added return_fields function, attempting to optionally limit fields returned by find
1 parent b00c9e0 commit 5af2e8e

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

aredis_om/model/model.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def __init__(
418418
limit: Optional[int] = None,
419419
page_size: int = DEFAULT_PAGE_SIZE,
420420
sort_fields: Optional[List[str]] = None,
421+
return_fields: Optional[List[str]] = None,
421422
nocontent: bool = False,
422423
):
423424
if not has_redisearch(model.db()):
@@ -442,6 +443,11 @@ def __init__(
442443
else:
443444
self.sort_fields = []
444445

446+
if return_fields:
447+
self.return_fields = self.validate_return_fields(return_fields)
448+
else:
449+
self.return_fields = []
450+
445451
self._expression = None
446452
self._query: Optional[str] = None
447453
self._pagination: List[str] = []
@@ -499,8 +505,19 @@ def query(self):
499505
if self._query.startswith("(") or self._query == "*"
500506
else f"({self._query})"
501507
) + f"=>[{self.knn}]"
508+
if self.return_fields:
509+
self._query += f" RETURN {','.join(self.return_fields)}"
502510
return self._query
503511

512+
def validate_return_fields(self, return_fields: List[str]):
513+
for field in return_fields:
514+
if field not in self.model.__fields__: # type: ignore
515+
raise QueryNotSupportedError(
516+
f"You tried to return the field {field}, but that field "
517+
f"does not exist on the model {self.model}"
518+
)
519+
return return_fields
520+
504521
@property
505522
def query_params(self):
506523
params: List[Union[str, bytes]] = []
@@ -949,6 +966,11 @@ def sort_by(self, *fields: str):
949966
if not fields:
950967
return self
951968
return self.copy(sort_fields=list(fields))
969+
970+
def return_fields(self, *fields: str):
971+
if not fields:
972+
return self
973+
return self.copy(return_fields=list(fields))
952974

953975
async def update(self, use_transaction=True, **field_values):
954976
"""
@@ -1524,7 +1546,9 @@ def find(
15241546
*expressions: Union[Any, Expression],
15251547
knn: Optional[KNNExpression] = None,
15261548
) -> FindQuery:
1527-
return FindQuery(expressions=expressions, knn=knn, model=cls)
1549+
return FindQuery(
1550+
expressions=expressions, knn=knn, model=cls
1551+
)
15281552

15291553
@classmethod
15301554
def from_redis(cls, res: Any, knn: Optional[KNNExpression] = None):

tests/test_json_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,18 @@ class TypeWithUuid(JsonModel, index=True):
955955

956956
await item.save()
957957

958+
@py_test_mark_asyncio
959+
async def test_return_specified_fields(members, m):
960+
member1, member2, member3 = members
961+
actual = await m.Member.find(
962+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
963+
| (m.Member.last_name == "Smith")
964+
).all()
965+
assert actual == [
966+
{"first_name": "Andrew", "last_name": "Brookins"},
967+
{"first_name": "Andrew", "last_name": "Smith"},
968+
]
969+
958970

959971
@py_test_mark_asyncio
960972
async def test_type_with_enum():

0 commit comments

Comments
 (0)