Skip to content

Commit 6a95333

Browse files
committed
Rename return_fields() to values() and add only() as well
1 parent 6e56b60 commit 6a95333

File tree

3 files changed

+366
-29
lines changed

3 files changed

+366
-29
lines changed

aredis_om/model/model.py

Lines changed: 199 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,64 @@
5353
log = logging.getLogger(__name__)
5454
escaper = TokenEscaper()
5555

56+
57+
class PartialModel:
58+
"""A partial model instance that only contains certain fields.
59+
60+
Accessing fields that weren't loaded will raise AttributeError.
61+
This is used for .only() queries to provide partial model instances.
62+
"""
63+
64+
def __init__(self, model_class, data: dict, loaded_fields: set):
65+
self.__dict__["_model_class"] = model_class
66+
self.__dict__["_loaded_fields"] = loaded_fields
67+
self.__dict__["_data"] = data
68+
69+
# Set the loaded field values
70+
for field_name, value in data.items():
71+
self.__dict__[field_name] = value
72+
73+
def __getattribute__(self, name):
74+
# Allow access to internal attributes and methods
75+
if name.startswith("_") or name in (
76+
"model_fields",
77+
"model_config",
78+
"__class__",
79+
"__dict__",
80+
):
81+
return super().__getattribute__(name)
82+
83+
# Get model class to check if this is a model field
84+
model_class = super().__getattribute__("_model_class")
85+
loaded_fields = super().__getattribute__("_loaded_fields")
86+
87+
# If it's a model field that wasn't loaded, raise an error
88+
if hasattr(model_class, "model_fields") and name in model_class.model_fields:
89+
if name not in loaded_fields:
90+
raise AttributeError(
91+
f"Field '{name}' is missing from this query. "
92+
f"Use .only('{name}') or .only({', '.join(repr(field) for field in sorted(loaded_fields.union({name})))}) to include it."
93+
)
94+
95+
return super().__getattribute__(name)
96+
97+
def __setattr__(self, name, value):
98+
# Allow setting internal attributes
99+
if name.startswith("_"):
100+
self.__dict__[name] = value
101+
else:
102+
# For regular fields, check if they were loaded
103+
if name not in self._loaded_fields:
104+
raise AttributeError(
105+
f"Cannot set field '{name}' - it is missing from this query."
106+
)
107+
self.__dict__[name] = value
108+
109+
def __repr__(self):
110+
loaded_data = {k: v for k, v in self._data.items() if k in self._loaded_fields}
111+
return f"Partial{self._model_class.__name__}({loaded_data})"
112+
113+
56114
# For basic exact-match field types like an indexed string, we create a TAG
57115
# field in the RediSearch index. TAG is designed for multi-value fields
58116
# separated by a "separator" character. We're using the field for single values
@@ -503,7 +561,7 @@ def query(self):
503561
"""
504562
if self._query:
505563
return self._query
506-
self._query = self.resolve_redisearch_query(self.expression)
564+
self._query = self._resolve_redisearch_query(self.expression)
507565
if self.knn:
508566
self._query = (
509567
self._query
@@ -541,15 +599,98 @@ def to_string(s):
541599
if res[i + offset] is None:
542600
continue
543601
# When using RETURN, we get flat key-value pairs
544-
fields: Dict[str, str] = dict(
602+
raw_fields: Dict[str, str] = dict(
545603
zip(
546604
map(to_string, res[i + offset][::2]),
547605
map(to_string, res[i + offset][1::2]),
548606
)
549607
)
550-
docs.append(fields)
608+
# Convert raw Redis strings to properly typed values
609+
converted_fields = self._convert_projected_fields(raw_fields)
610+
docs.append(converted_fields)
551611
return docs
552612

613+
def _convert_projected_fields(self, raw_data: Dict[str, str]) -> Dict[str, Any]:
614+
"""Convert raw Redis string values to properly typed values using model field info."""
615+
616+
# Fast path: Try creating a single model instance with all projected fields
617+
# This is more efficient and handles field interdependencies
618+
try:
619+
# Use model_validate instead of model_construct to ensure type conversion
620+
temp_model = self.model.model_validate(raw_data, strict=False)
621+
622+
# Use model_dump() to efficiently extract all converted values
623+
all_converted = temp_model.model_dump()
624+
625+
# Filter to only the fields we actually projected
626+
converted_data = {
627+
k: all_converted[k] for k in raw_data.keys() if k in all_converted
628+
}
629+
630+
return converted_data
631+
632+
except Exception: # nosec B110
633+
# If validation fails (due to missing required fields), fall back to individual conversion
634+
# This is expected for partial field sets
635+
pass
636+
637+
# Fallback path: Convert each field individually using type information
638+
converted_data = {}
639+
for field_name, raw_value in raw_data.items():
640+
if field_name not in self.model.model_fields:
641+
# Unknown field, keep as string
642+
converted_data[field_name] = raw_value
643+
continue
644+
645+
try:
646+
field_info = self.model.model_fields[field_name]
647+
648+
# Get the field type annotation
649+
if hasattr(field_info, "annotation"):
650+
field_type = field_info.annotation
651+
else:
652+
field_type = getattr(field_info, "type_", str)
653+
654+
# Handle common type conversions directly for efficiency
655+
if field_type == int:
656+
converted_data[field_name] = int(raw_value)
657+
elif field_type == float:
658+
converted_data[field_name] = float(raw_value)
659+
elif field_type == bool:
660+
# Redis may store bool as "True"/"False" or "1"/"0"
661+
converted_data[field_name] = raw_value.lower() in (
662+
"true",
663+
"1",
664+
"yes",
665+
)
666+
elif field_type == str:
667+
converted_data[field_name] = raw_value
668+
else:
669+
# For complex types, keep as string (could be enhanced later)
670+
converted_data[field_name] = raw_value
671+
672+
except (ValueError, TypeError):
673+
# If conversion fails, keep the raw string value
674+
converted_data[field_name] = raw_value
675+
676+
return converted_data
677+
678+
def _parse_projected_models(self, res: Any) -> List[PartialModel]:
679+
"""Parse results when using RETURN clause to create partial model instances."""
680+
projected_dicts = self._parse_projected_results(res)
681+
682+
# Create partial model instances that will raise errors for missing fields
683+
partial_models = []
684+
for data in projected_dicts:
685+
partial_model = PartialModel(
686+
model_class=self.model,
687+
data=data,
688+
loaded_fields=set(self.projected_fields),
689+
)
690+
partial_models.append(partial_model)
691+
692+
return partial_models
693+
553694
@property
554695
def query_params(self):
555696
params: List[Union[str, bytes]] = []
@@ -669,6 +810,7 @@ def resolve_value(
669810
op: Operators,
670811
value: Any,
671812
parents: List[Tuple[str, "RedisModel"]],
813+
model_class: Optional[Type["RedisModel"]] = None,
672814
) -> str:
673815
# The 'field_name' should already include the correct prefix
674816
result = ""
@@ -724,8 +866,18 @@ def resolve_value(
724866
)
725867
return ""
726868
if isinstance(value, bool):
869+
# For HashModel, convert boolean to "1"/"0" to match storage format
870+
# For JsonModel, keep as boolean since JSON supports native booleans
871+
if model_class:
872+
# Check if this is a HashModel by checking the class hierarchy
873+
is_hash_model = any(
874+
base.__name__ == "HashModel" for base in model_class.__mro__
875+
)
876+
bool_value = ("1" if value else "0") if is_hash_model else value
877+
else:
878+
bool_value = value
727879
result = "@{field_name}:{{{value}}}".format(
728-
field_name=field_name, value=value
880+
field_name=field_name, value=bool_value
729881
)
730882
elif isinstance(value, int):
731883
# This if will hit only if the field is a primary key of type int
@@ -803,8 +955,7 @@ def resolve_redisearch_sort_fields(self):
803955
if self.sort_fields:
804956
return ["SORTBY", *fields]
805957

806-
@classmethod
807-
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
958+
def _resolve_redisearch_query(self, expression: ExpressionOrNegated) -> str:
808959
"""
809960
Resolve an arbitrarily deep expression into a single RediSearch query string.
810961
@@ -848,9 +999,11 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
848999
if isinstance(expression.left, Expression) or isinstance(
8491000
expression.left, NegatedExpression
8501001
):
851-
result += f"({cls.resolve_redisearch_query(expression.left)})"
1002+
result += f"({self._resolve_redisearch_query(expression.left)})"
8521003
elif isinstance(expression.left, FieldInfo):
853-
field_type = cls.resolve_field_type(expression.left, expression.op)
1004+
field_type = self.__class__.resolve_field_type(
1005+
expression.left, expression.op
1006+
)
8541007
field_name = expression.left.name
8551008
field_info = expression.left
8561009
if not field_info or not getattr(field_info, "index", None):
@@ -881,7 +1034,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
8811034
result += "-"
8821035
right = right.expression
8831036

884-
result += f"({cls.resolve_redisearch_query(right)})"
1037+
result += f"({self._resolve_redisearch_query(right)})"
8851038
else:
8861039
if not field_name:
8871040
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
@@ -890,13 +1043,14 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
8901043
elif not field_info:
8911044
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
8921045
else:
893-
result += cls.resolve_value(
1046+
result += self.__class__.resolve_value(
8941047
field_name,
8951048
field_type,
8961049
field_info,
8971050
expression.op,
8981051
right,
8991052
expression.parents,
1053+
self.model,
9001054
)
9011055

9021056
if encompassing_expression_is_negated:
@@ -951,16 +1105,19 @@ async def execute(
9511105
return raw_result
9521106
count = raw_result[0]
9531107

954-
# If we're using field projection or explicitly requesting dict output,
955-
# return dictionaries instead of model instances
956-
if self.projected_fields or self.return_as_dict:
957-
if self.projected_fields:
958-
results = self._parse_projected_results(raw_result)
959-
else:
960-
# Return all fields as dicts - need to convert from model instances
961-
model_results = self.model.from_redis(raw_result, self.knn)
962-
results = [model.model_dump() for model in model_results]
1108+
# Handle different result processing based on what was requested
1109+
if self.projected_fields and self.return_as_dict:
1110+
# .values('field1', 'field2') - specific fields as dicts
1111+
results = self._parse_projected_results(raw_result)
1112+
elif self.projected_fields and not self.return_as_dict:
1113+
# .only('field1', 'field2') - partial model instances
1114+
results = self._parse_projected_models(raw_result)
1115+
elif self.return_as_dict and not self.projected_fields:
1116+
# .values() - all fields as dicts
1117+
model_results = self.model.from_redis(raw_result, self.knn)
1118+
results = [model.model_dump() for model in model_results]
9631119
else:
1120+
# Normal query - full model instances
9641121
results = self.model.from_redis(raw_result, self.knn)
9651122
self._model_cache += results
9661123

@@ -1019,10 +1176,10 @@ def sort_by(self, *fields: str):
10191176
def values(self, *fields: str):
10201177
"""
10211178
Return query results as dictionaries instead of model instances.
1022-
1179+
10231180
If no fields are specified, returns all fields.
10241181
If fields are specified, returns only those fields.
1025-
1182+
10261183
Usage:
10271184
await Model.find().values() # All fields as dicts
10281185
await Model.find().values('name', 'email') # Only specified fields
@@ -1034,6 +1191,20 @@ def values(self, *fields: str):
10341191
# Return specific fields as dicts
10351192
return self.copy(return_as_dict=True, projected_fields=list(fields))
10361193

1194+
def only(self, *fields: str):
1195+
"""
1196+
Return query results as model instances with only the specified fields loaded.
1197+
1198+
Accessing fields that weren't loaded will raise an AttributeError.
1199+
Uses Redis RETURN clause for efficient field projection.
1200+
1201+
Usage:
1202+
await Model.find().only('name', 'email').all() # Partial model instances
1203+
"""
1204+
if not fields:
1205+
raise ValueError("only() requires at least one field name")
1206+
return self.copy(projected_fields=list(fields))
1207+
10371208
async def update(self, use_transaction=True, **field_values):
10381209
"""
10391210
Update models that match this query to the given field-value pairs.
@@ -1766,6 +1937,13 @@ async def save(
17661937

17671938
# filter out values which are `None` because they are not valid in a HSET
17681939
document = {k: v for k, v in document.items() if v is not None}
1940+
1941+
# Convert boolean values to "1"/"0" for storage efficiency (Redis HSET doesn't support booleans)
1942+
document = {
1943+
k: ("1" if v else "0") if isinstance(v, bool) else v
1944+
for k, v in document.items()
1945+
}
1946+
17691947
# TODO: Wrap any Redis response errors in a custom exception?
17701948
await db.hset(self.key(), mapping=document)
17711949
return self

0 commit comments

Comments
 (0)