53
53
log = logging .getLogger (__name__ )
54
54
escaper = TokenEscaper ()
55
55
56
+
57
+ class PartialModel :
58
+ """A partial model instance that only contains certain fields.
59
+
60
+ Accessing fields that weren't loaded will raise AttributeError.
61
+ This is used for .only() queries to provide partial model instances.
62
+ """
63
+
64
+ def __init__ (self , model_class , data : dict , loaded_fields : set ):
65
+ self .__dict__ ["_model_class" ] = model_class
66
+ self .__dict__ ["_loaded_fields" ] = loaded_fields
67
+ self .__dict__ ["_data" ] = data
68
+
69
+ # Set the loaded field values
70
+ for field_name , value in data .items ():
71
+ self .__dict__ [field_name ] = value
72
+
73
+ def __getattribute__ (self , name ):
74
+ # Allow access to internal attributes and methods
75
+ if name .startswith ("_" ) or name in (
76
+ "model_fields" ,
77
+ "model_config" ,
78
+ "__class__" ,
79
+ "__dict__" ,
80
+ ):
81
+ return super ().__getattribute__ (name )
82
+
83
+ # Get model class to check if this is a model field
84
+ model_class = super ().__getattribute__ ("_model_class" )
85
+ loaded_fields = super ().__getattribute__ ("_loaded_fields" )
86
+
87
+ # If it's a model field that wasn't loaded, raise an error
88
+ if hasattr (model_class , "model_fields" ) and name in model_class .model_fields :
89
+ if name not in loaded_fields :
90
+ raise AttributeError (
91
+ f"Field '{ name } ' is missing from this query. "
92
+ f"Use .only('{ name } ') or .only({ ', ' .join (repr (field ) for field in sorted (loaded_fields .union ({name })))} ) to include it."
93
+ )
94
+
95
+ return super ().__getattribute__ (name )
96
+
97
+ def __setattr__ (self , name , value ):
98
+ # Allow setting internal attributes
99
+ if name .startswith ("_" ):
100
+ self .__dict__ [name ] = value
101
+ else :
102
+ # For regular fields, check if they were loaded
103
+ if name not in self ._loaded_fields :
104
+ raise AttributeError (
105
+ f"Cannot set field '{ name } ' - it is missing from this query."
106
+ )
107
+ self .__dict__ [name ] = value
108
+
109
+ def __repr__ (self ):
110
+ loaded_data = {k : v for k , v in self ._data .items () if k in self ._loaded_fields }
111
+ return f"Partial{ self ._model_class .__name__ } ({ loaded_data } )"
112
+
113
+
56
114
# For basic exact-match field types like an indexed string, we create a TAG
57
115
# field in the RediSearch index. TAG is designed for multi-value fields
58
116
# separated by a "separator" character. We're using the field for single values
@@ -503,7 +561,7 @@ def query(self):
503
561
"""
504
562
if self ._query :
505
563
return self ._query
506
- self ._query = self .resolve_redisearch_query (self .expression )
564
+ self ._query = self ._resolve_redisearch_query (self .expression )
507
565
if self .knn :
508
566
self ._query = (
509
567
self ._query
@@ -541,15 +599,98 @@ def to_string(s):
541
599
if res [i + offset ] is None :
542
600
continue
543
601
# When using RETURN, we get flat key-value pairs
544
- fields : Dict [str , str ] = dict (
602
+ raw_fields : Dict [str , str ] = dict (
545
603
zip (
546
604
map (to_string , res [i + offset ][::2 ]),
547
605
map (to_string , res [i + offset ][1 ::2 ]),
548
606
)
549
607
)
550
- docs .append (fields )
608
+ # Convert raw Redis strings to properly typed values
609
+ converted_fields = self ._convert_projected_fields (raw_fields )
610
+ docs .append (converted_fields )
551
611
return docs
552
612
613
+ def _convert_projected_fields (self , raw_data : Dict [str , str ]) -> Dict [str , Any ]:
614
+ """Convert raw Redis string values to properly typed values using model field info."""
615
+
616
+ # Fast path: Try creating a single model instance with all projected fields
617
+ # This is more efficient and handles field interdependencies
618
+ try :
619
+ # Use model_validate instead of model_construct to ensure type conversion
620
+ temp_model = self .model .model_validate (raw_data , strict = False )
621
+
622
+ # Use model_dump() to efficiently extract all converted values
623
+ all_converted = temp_model .model_dump ()
624
+
625
+ # Filter to only the fields we actually projected
626
+ converted_data = {
627
+ k : all_converted [k ] for k in raw_data .keys () if k in all_converted
628
+ }
629
+
630
+ return converted_data
631
+
632
+ except Exception : # nosec B110
633
+ # If validation fails (due to missing required fields), fall back to individual conversion
634
+ # This is expected for partial field sets
635
+ pass
636
+
637
+ # Fallback path: Convert each field individually using type information
638
+ converted_data = {}
639
+ for field_name , raw_value in raw_data .items ():
640
+ if field_name not in self .model .model_fields :
641
+ # Unknown field, keep as string
642
+ converted_data [field_name ] = raw_value
643
+ continue
644
+
645
+ try :
646
+ field_info = self .model .model_fields [field_name ]
647
+
648
+ # Get the field type annotation
649
+ if hasattr (field_info , "annotation" ):
650
+ field_type = field_info .annotation
651
+ else :
652
+ field_type = getattr (field_info , "type_" , str )
653
+
654
+ # Handle common type conversions directly for efficiency
655
+ if field_type == int :
656
+ converted_data [field_name ] = int (raw_value )
657
+ elif field_type == float :
658
+ converted_data [field_name ] = float (raw_value )
659
+ elif field_type == bool :
660
+ # Redis may store bool as "True"/"False" or "1"/"0"
661
+ converted_data [field_name ] = raw_value .lower () in (
662
+ "true" ,
663
+ "1" ,
664
+ "yes" ,
665
+ )
666
+ elif field_type == str :
667
+ converted_data [field_name ] = raw_value
668
+ else :
669
+ # For complex types, keep as string (could be enhanced later)
670
+ converted_data [field_name ] = raw_value
671
+
672
+ except (ValueError , TypeError ):
673
+ # If conversion fails, keep the raw string value
674
+ converted_data [field_name ] = raw_value
675
+
676
+ return converted_data
677
+
678
+ def _parse_projected_models (self , res : Any ) -> List [PartialModel ]:
679
+ """Parse results when using RETURN clause to create partial model instances."""
680
+ projected_dicts = self ._parse_projected_results (res )
681
+
682
+ # Create partial model instances that will raise errors for missing fields
683
+ partial_models = []
684
+ for data in projected_dicts :
685
+ partial_model = PartialModel (
686
+ model_class = self .model ,
687
+ data = data ,
688
+ loaded_fields = set (self .projected_fields ),
689
+ )
690
+ partial_models .append (partial_model )
691
+
692
+ return partial_models
693
+
553
694
@property
554
695
def query_params (self ):
555
696
params : List [Union [str , bytes ]] = []
@@ -669,6 +810,7 @@ def resolve_value(
669
810
op : Operators ,
670
811
value : Any ,
671
812
parents : List [Tuple [str , "RedisModel" ]],
813
+ model_class : Optional [Type ["RedisModel" ]] = None ,
672
814
) -> str :
673
815
# The 'field_name' should already include the correct prefix
674
816
result = ""
@@ -724,8 +866,18 @@ def resolve_value(
724
866
)
725
867
return ""
726
868
if isinstance (value , bool ):
869
+ # For HashModel, convert boolean to "1"/"0" to match storage format
870
+ # For JsonModel, keep as boolean since JSON supports native booleans
871
+ if model_class :
872
+ # Check if this is a HashModel by checking the class hierarchy
873
+ is_hash_model = any (
874
+ base .__name__ == "HashModel" for base in model_class .__mro__
875
+ )
876
+ bool_value = ("1" if value else "0" ) if is_hash_model else value
877
+ else :
878
+ bool_value = value
727
879
result = "@{field_name}:{{{value}}}" .format (
728
- field_name = field_name , value = value
880
+ field_name = field_name , value = bool_value
729
881
)
730
882
elif isinstance (value , int ):
731
883
# This if will hit only if the field is a primary key of type int
@@ -803,8 +955,7 @@ def resolve_redisearch_sort_fields(self):
803
955
if self .sort_fields :
804
956
return ["SORTBY" , * fields ]
805
957
806
- @classmethod
807
- def resolve_redisearch_query (cls , expression : ExpressionOrNegated ) -> str :
958
+ def _resolve_redisearch_query (self , expression : ExpressionOrNegated ) -> str :
808
959
"""
809
960
Resolve an arbitrarily deep expression into a single RediSearch query string.
810
961
@@ -848,9 +999,11 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
848
999
if isinstance (expression .left , Expression ) or isinstance (
849
1000
expression .left , NegatedExpression
850
1001
):
851
- result += f"({ cls . resolve_redisearch_query (expression .left )} )"
1002
+ result += f"({ self . _resolve_redisearch_query (expression .left )} )"
852
1003
elif isinstance (expression .left , FieldInfo ):
853
- field_type = cls .resolve_field_type (expression .left , expression .op )
1004
+ field_type = self .__class__ .resolve_field_type (
1005
+ expression .left , expression .op
1006
+ )
854
1007
field_name = expression .left .name
855
1008
field_info = expression .left
856
1009
if not field_info or not getattr (field_info , "index" , None ):
@@ -881,7 +1034,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
881
1034
result += "-"
882
1035
right = right .expression
883
1036
884
- result += f"({ cls . resolve_redisearch_query (right )} )"
1037
+ result += f"({ self . _resolve_redisearch_query (right )} )"
885
1038
else :
886
1039
if not field_name :
887
1040
raise QuerySyntaxError ("Could not resolve field name. See docs: TODO" )
@@ -890,13 +1043,14 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
890
1043
elif not field_info :
891
1044
raise QuerySyntaxError ("Could not resolve field info. See docs: TODO" )
892
1045
else :
893
- result += cls .resolve_value (
1046
+ result += self . __class__ .resolve_value (
894
1047
field_name ,
895
1048
field_type ,
896
1049
field_info ,
897
1050
expression .op ,
898
1051
right ,
899
1052
expression .parents ,
1053
+ self .model ,
900
1054
)
901
1055
902
1056
if encompassing_expression_is_negated :
@@ -951,16 +1105,19 @@ async def execute(
951
1105
return raw_result
952
1106
count = raw_result [0 ]
953
1107
954
- # If we're using field projection or explicitly requesting dict output,
955
- # return dictionaries instead of model instances
956
- if self .projected_fields or self .return_as_dict :
957
- if self .projected_fields :
958
- results = self ._parse_projected_results (raw_result )
959
- else :
960
- # Return all fields as dicts - need to convert from model instances
961
- model_results = self .model .from_redis (raw_result , self .knn )
962
- results = [model .model_dump () for model in model_results ]
1108
+ # Handle different result processing based on what was requested
1109
+ if self .projected_fields and self .return_as_dict :
1110
+ # .values('field1', 'field2') - specific fields as dicts
1111
+ results = self ._parse_projected_results (raw_result )
1112
+ elif self .projected_fields and not self .return_as_dict :
1113
+ # .only('field1', 'field2') - partial model instances
1114
+ results = self ._parse_projected_models (raw_result )
1115
+ elif self .return_as_dict and not self .projected_fields :
1116
+ # .values() - all fields as dicts
1117
+ model_results = self .model .from_redis (raw_result , self .knn )
1118
+ results = [model .model_dump () for model in model_results ]
963
1119
else :
1120
+ # Normal query - full model instances
964
1121
results = self .model .from_redis (raw_result , self .knn )
965
1122
self ._model_cache += results
966
1123
@@ -1019,10 +1176,10 @@ def sort_by(self, *fields: str):
1019
1176
def values (self , * fields : str ):
1020
1177
"""
1021
1178
Return query results as dictionaries instead of model instances.
1022
-
1179
+
1023
1180
If no fields are specified, returns all fields.
1024
1181
If fields are specified, returns only those fields.
1025
-
1182
+
1026
1183
Usage:
1027
1184
await Model.find().values() # All fields as dicts
1028
1185
await Model.find().values('name', 'email') # Only specified fields
@@ -1034,6 +1191,20 @@ def values(self, *fields: str):
1034
1191
# Return specific fields as dicts
1035
1192
return self .copy (return_as_dict = True , projected_fields = list (fields ))
1036
1193
1194
+ def only (self , * fields : str ):
1195
+ """
1196
+ Return query results as model instances with only the specified fields loaded.
1197
+
1198
+ Accessing fields that weren't loaded will raise an AttributeError.
1199
+ Uses Redis RETURN clause for efficient field projection.
1200
+
1201
+ Usage:
1202
+ await Model.find().only('name', 'email').all() # Partial model instances
1203
+ """
1204
+ if not fields :
1205
+ raise ValueError ("only() requires at least one field name" )
1206
+ return self .copy (projected_fields = list (fields ))
1207
+
1037
1208
async def update (self , use_transaction = True , ** field_values ):
1038
1209
"""
1039
1210
Update models that match this query to the given field-value pairs.
@@ -1766,6 +1937,13 @@ async def save(
1766
1937
1767
1938
# filter out values which are `None` because they are not valid in a HSET
1768
1939
document = {k : v for k , v in document .items () if v is not None }
1940
+
1941
+ # Convert boolean values to "1"/"0" for storage efficiency (Redis HSET doesn't support booleans)
1942
+ document = {
1943
+ k : ("1" if v else "0" ) if isinstance (v , bool ) else v
1944
+ for k , v in document .items ()
1945
+ }
1946
+
1769
1947
# TODO: Wrap any Redis response errors in a custom exception?
1770
1948
await db .hset (self .key (), mapping = document )
1771
1949
return self
0 commit comments