@@ -418,6 +418,7 @@ def __init__(
418
418
limit : Optional [int ] = None ,
419
419
page_size : int = DEFAULT_PAGE_SIZE ,
420
420
sort_fields : Optional [List [str ]] = None ,
421
+ return_fields : Optional [List [str ]] = None ,
421
422
nocontent : bool = False ,
422
423
):
423
424
if not has_redisearch (model .db ()):
@@ -442,6 +443,11 @@ def __init__(
442
443
else :
443
444
self .sort_fields = []
444
445
446
+ if return_fields :
447
+ self .return_fields = self .validate_return_fields (return_fields )
448
+ else :
449
+ self .return_fields = []
450
+
445
451
self ._expression = None
446
452
self ._query : Optional [str ] = None
447
453
self ._pagination : List [str ] = []
@@ -499,8 +505,19 @@ def query(self):
499
505
if self ._query .startswith ("(" ) or self ._query == "*"
500
506
else f"({ self ._query } )"
501
507
) + f"=>[{ self .knn } ]"
508
+ if self .return_fields :
509
+ self ._query += f" RETURN { ',' .join (self .return_fields )} "
502
510
return self ._query
503
511
512
+ def validate_return_fields (self , return_fields : List [str ]):
513
+ for field in return_fields :
514
+ if field not in self .model .__fields__ : # type: ignore
515
+ raise QueryNotSupportedError (
516
+ f"You tried to return the field { field } , but that field "
517
+ f"does not exist on the model { self .model } "
518
+ )
519
+ return return_fields
520
+
504
521
@property
505
522
def query_params (self ):
506
523
params : List [Union [str , bytes ]] = []
@@ -949,6 +966,11 @@ def sort_by(self, *fields: str):
949
966
if not fields :
950
967
return self
951
968
return self .copy (sort_fields = list (fields ))
969
+
970
+ def return_fields (self , * fields : str ):
971
+ if not fields :
972
+ return self
973
+ return self .copy (return_fields = list (fields ))
952
974
953
975
async def update (self , use_transaction = True , ** field_values ):
954
976
"""
@@ -1524,7 +1546,9 @@ def find(
1524
1546
* expressions : Union [Any , Expression ],
1525
1547
knn : Optional [KNNExpression ] = None ,
1526
1548
) -> FindQuery :
1527
- return FindQuery (expressions = expressions , knn = knn , model = cls )
1549
+ return FindQuery (
1550
+ expressions = expressions , knn = knn , model = cls
1551
+ )
1528
1552
1529
1553
@classmethod
1530
1554
def from_redis (cls , res : Any , knn : Optional [KNNExpression ] = None ):
0 commit comments