@@ -418,7 +418,7 @@ def __init__(
418
418
limit : Optional [int ] = None ,
419
419
page_size : int = DEFAULT_PAGE_SIZE ,
420
420
sort_fields : Optional [List [str ]] = None ,
421
- return_fields : Optional [List [str ]] = None ,
421
+ projected_fields : Optional [List [str ]] = None ,
422
422
nocontent : bool = False ,
423
423
):
424
424
if not has_redisearch (model .db ()):
@@ -443,10 +443,10 @@ def __init__(
443
443
else :
444
444
self .sort_fields = []
445
445
446
- if return_fields :
447
- self .return_fields = self .validate_return_fields ( return_fields )
446
+ if projected_fields :
447
+ self .projected_fields = self .validate_projected_fields ( projected_fields )
448
448
else :
449
- self .return_fields = []
449
+ self .projected_fields = []
450
450
451
451
self ._expression = None
452
452
self ._query : Optional [str ] = None
@@ -505,18 +505,45 @@ def query(self):
505
505
if self ._query .startswith ("(" ) or self ._query == "*"
506
506
else f"({ self ._query } )"
507
507
) + f"=>[{ self .knn } ]"
508
- if self .return_fields :
509
- self ._query += f" RETURN { ',' .join (self .return_fields )} "
508
+ # RETURN clause should be added to args, not to the query string
510
509
return self ._query
511
510
512
- def validate_return_fields (self , return_fields : List [str ]):
513
- for field in return_fields :
514
- if field not in self .model .__fields__ : # type: ignore
511
+ def validate_projected_fields (self , projected_fields : List [str ]):
512
+ for field in projected_fields :
513
+ if field not in self .model .model_fields : # type: ignore
515
514
raise QueryNotSupportedError (
516
515
f"You tried to return the field { field } , but that field "
517
516
f"does not exist on the model { self .model } "
518
517
)
519
- return return_fields
518
+ return projected_fields
519
+
520
+ def _parse_projected_results (self , res : Any ) -> List [Dict [str , Any ]]:
521
+ """Parse results when using RETURN clause with specific fields."""
522
+
523
+ def to_string (s ):
524
+ if isinstance (s , (str ,)):
525
+ return s
526
+ elif isinstance (s , bytes ):
527
+ return s .decode (errors = "ignore" )
528
+ else :
529
+ return s
530
+
531
+ docs = []
532
+ step = 2 # Because the result has content
533
+ offset = 1 # The first item is the count of total matches.
534
+
535
+ for i in range (1 , len (res ), step ):
536
+ if res [i + offset ] is None :
537
+ continue
538
+ # When using RETURN, we get flat key-value pairs
539
+ fields : Dict [str , str ] = dict (
540
+ zip (
541
+ map (to_string , res [i + offset ][::2 ]),
542
+ map (to_string , res [i + offset ][1 ::2 ]),
543
+ )
544
+ )
545
+ docs .append (fields )
546
+ return docs
520
547
521
548
@property
522
549
def query_params (self ):
@@ -899,6 +926,12 @@ async def execute(
899
926
if self .nocontent :
900
927
args .append ("NOCONTENT" )
901
928
929
+ # Add RETURN clause to the args list, not to the query string
930
+ if self .projected_fields :
931
+ args .extend (
932
+ ["RETURN" , str (len (self .projected_fields ))] + self .projected_fields
933
+ )
934
+
902
935
if return_query_args :
903
936
return self .model .Meta .index_name , args
904
937
@@ -912,7 +945,12 @@ async def execute(
912
945
if return_raw_result :
913
946
return raw_result
914
947
count = raw_result [0 ]
915
- results = self .model .from_redis (raw_result , self .knn )
948
+
949
+ # If we're using field projection, return dictionaries instead of model instances
950
+ if self .projected_fields :
951
+ results = self ._parse_projected_results (raw_result )
952
+ else :
953
+ results = self .model .from_redis (raw_result , self .knn )
916
954
self ._model_cache += results
917
955
918
956
if not exhaust_results :
@@ -966,11 +1004,11 @@ def sort_by(self, *fields: str):
966
1004
if not fields :
967
1005
return self
968
1006
return self .copy (sort_fields = list (fields ))
969
-
1007
+
970
1008
def return_fields (self , * fields : str ):
971
1009
if not fields :
972
1010
return self
973
- return self .copy (return_fields = list (fields ))
1011
+ return self .copy (projected_fields = list (fields ))
974
1012
975
1013
async def update (self , use_transaction = True , ** field_values ):
976
1014
"""
@@ -1546,9 +1584,7 @@ def find(
1546
1584
* expressions : Union [Any , Expression ],
1547
1585
knn : Optional [KNNExpression ] = None ,
1548
1586
) -> FindQuery :
1549
- return FindQuery (
1550
- expressions = expressions , knn = knn , model = cls
1551
- )
1587
+ return FindQuery (expressions = expressions , knn = knn , model = cls )
1552
1588
1553
1589
@classmethod
1554
1590
def from_redis (cls , res : Any , knn : Optional [KNNExpression ] = None ):
0 commit comments