class PartialModel:
    """A view of a model that holds only the fields a query actually loaded.

    Instances are created for ``.only()`` queries. Reading a model field that
    was not part of the projection raises ``AttributeError`` instead of
    silently yielding a default value.

    Args:
        model_class: The model class the projection was taken from. Its
            ``model_fields`` mapping decides which names count as fields.
        data: Mapping of field name to loaded value. Deep projections are
            expected in nested form (``{"address": {"city": ...}}``).
        loaded_fields: Names that were requested, including deep paths such
            as ``"address__city"``.
    """

    def __init__(self, model_class, data: dict, loaded_fields: set):
        # Write straight into __dict__: the custom __setattr__ consults
        # _loaded_fields, which does not exist yet at this point.
        self.__dict__["_model_class"] = model_class
        self.__dict__["_loaded_fields"] = loaded_fields
        self.__dict__["_data"] = data

        for field_name, value in data.items():
            self.__dict__[field_name] = self._wrap_nested(field_name, value)

    def _wrap_nested(self, field_name, value):
        """Wrap *value* in a nested PartialModel when it is the dict form of
        an embedded model that has deep fields loaded; otherwise return it
        unchanged."""
        model_class = self._model_class
        if not isinstance(value, dict) or field_name not in model_class.model_fields:
            return value

        prefix = f"{field_name}__"
        nested_loaded_fields = {
            field[len(prefix):]
            for field in self._loaded_fields
            if field.startswith(prefix)
        }
        if not nested_loaded_fields:
            # No deep fields for this name: keep the plain dict.
            return value

        field_info = model_class.model_fields[field_name]
        field_type = getattr(field_info, "annotation", None)
        try:
            is_embedded_model = isinstance(field_type, type) and issubclass(
                field_type, RedisModel
            )
        except TypeError:
            # Some generic aliases pass isinstance(..., type) but are rejected
            # by issubclass(); they are not model classes.
            is_embedded_model = False

        if not is_embedded_model:
            return value
        return PartialModel(
            model_class=field_type,
            data=value,
            loaded_fields=nested_loaded_fields,
        )

    @staticmethod
    def _only_hint(name, loaded_fields):
        """Build the ``.only(...)`` suggestion used in error messages."""
        wanted = ", ".join(repr(field) for field in sorted(loaded_fields.union({name})))
        return f"Use .only('{name}') or .only({wanted}) to include it."

    def __getattribute__(self, name):
        # Internal attributes and methods bypass the field checks.
        if name.startswith("_") or name in ("model_fields", "model_config"):
            return super().__getattribute__(name)

        model_class = super().__getattribute__("_model_class")
        loaded_fields = super().__getattribute__("_loaded_fields")

        # A model field is readable when it, or any deep path below it, was loaded.
        if hasattr(model_class, "model_fields") and name in model_class.model_fields:
            prefix = f"{name}__"
            loaded = name in loaded_fields or any(
                field.startswith(prefix) for field in loaded_fields
            )
            if not loaded:
                raise AttributeError(
                    f"Field '{name}' is missing from this query. "
                    + self._only_hint(name, loaded_fields)
                )

        return super().__getattribute__(name)

    def __getattr__(self, name):
        """Fallback lookup; supports flat deep-field access such as
        ``getattr(obj, "address__city")``.

        Raises:
            AttributeError: for every name that cannot be resolved.
        """
        # Private and dunder names are never projected fields. Failing fast
        # also keeps probes on a half-built instance (copy/pickle call
        # hasattr(obj, "__setstate__") before __dict__ is restored) from
        # recursing through self._loaded_fields forever.
        if name.startswith("_"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )

        loaded_fields = self._loaded_fields
        model_class = self._model_class

        if "__" in name and name in loaded_fields:
            return self._extract_nested_value(self._data, name)

        if hasattr(model_class, "model_fields") and name in model_class.model_fields:
            if name in loaded_fields:
                # The field was requested, but the result carried no value
                # for it; suggesting .only() here would be misleading.
                raise AttributeError(
                    f"Field '{name}' was requested but no value was returned for it."
                )
            raise AttributeError(
                f"Field '{name}' was not loaded from this query. "
                + self._only_hint(name, loaded_fields)
            )

        raise AttributeError(
            f"'{model_class.__name__}' object has no attribute '{name}'"
        )

    def _extract_nested_value(self, data: dict, field_path: str):
        """Follow a ``a__b__c`` path through nested dicts; return None when
        any segment is missing."""
        current = data
        for part in field_path.split("__"):
            if not isinstance(current, dict) or part not in current:
                return None
            current = current[part]
        return current

    def __setattr__(self, name, value):
        """Allow assignment to internal names and to loaded fields only.

        Raises:
            AttributeError: when *name* is a public name that was not loaded.
        """
        if not name.startswith("_") and name not in self._loaded_fields:
            raise AttributeError(
                f"Cannot set field '{name}' - it is missing from this query."
            )
        self.__dict__[name] = value

    def __repr__(self):
        loaded_fields = self._loaded_fields
        # Deep paths ("address__city") are stored under their root key
        # ("address"), so match on the root as well as on the full name.
        roots = {field.split("__", 1)[0] for field in loaded_fields}
        loaded_data = {
            key: value
            for key, value in self._data.items()
            if key in loaded_fields or key in roots
        }
        return f"Partial{self._model_class.__name__}({loaded_data})"


# For basic exact-match field types like an indexed string, we create a TAG
# field in the RediSearch index. TAG is designed for multi-value fields
# separated by a "separator" character.
# --- FindQuery: field-projection helpers ------------------------------------


def validate_projected_fields(self, projected_fields: List[str]):
    """Check that every projected field name exists on the model.

    Plain names must be keys of ``self.model.model_fields``; names containing
    ``__`` are treated as deep paths and validated segment by segment.

    Returns:
        The input list, unchanged.

    Raises:
        QueryNotSupportedError: if a name or path does not exist.
    """
    for field in projected_fields:
        if "__" in field:
            self._validate_deep_field_path(field)
        elif field not in self.model.model_fields:  # type: ignore
            raise QueryNotSupportedError(
                f"You tried to return the field {field}, but that field "
                f"does not exist on the model {self.model}"
            )
    return projected_fields


def _validate_deep_field_path(self, field_path: str):
    """Validate that a deep field path like 'address__city' exists in the model.

    Every segment except the last must be an embedded model or a dict.
    Validation stops at the first dict, because dict contents are not
    described by the model.

    Raises:
        QueryNotSupportedError: if a segment is unknown or cannot be descended into.
    """
    parts = field_path.split("__")
    current_model = self.model
    root_name = parts[0]

    if root_name not in current_model.model_fields:
        raise QueryNotSupportedError(
            f"You tried to return the field {field_path}, but the root field "
            f"{root_name} does not exist on the model {current_model}"
        )

    last_index = len(parts) - 1
    for i, field_name in enumerate(parts):
        if i > 0 and (
            not hasattr(current_model, "model_fields")
            or field_name not in current_model.model_fields
        ):
            raise QueryNotSupportedError(
                f"You tried to return the field {field_path}, but the nested field "
                f"{field_name} does not exist on the embedded model {current_model}"
            )

        # The final segment is the value being projected; nothing to descend into.
        if 0 < i == last_index:
            break

        field_info = current_model.model_fields[field_name]
        field_type = getattr(field_info, "annotation", None)

        # Plain ``dict`` and parametrised ``Dict[K, V]`` both hold free-form
        # data, so deeper segments cannot be validated.
        if field_type is dict or get_origin(field_type) is dict:
            return

        try:
            is_embedded_model = isinstance(field_type, type) and issubclass(
                field_type, RedisModel
            )
        except TypeError:
            is_embedded_model = False

        if is_embedded_model:
            current_model = field_type
        elif i == 0:
            raise QueryNotSupportedError(
                f"Deep field path {field_path} requires {field_name} to be an "
                f"embedded model or dict, but it is {field_type}"
            )
        else:
            raise QueryNotSupportedError(
                f"Deep field path {field_path} requires {field_name} to be an "
                f"embedded model or dict for further nesting"
            )


def _parse_projected_results(self, res: Any) -> List[Dict[str, Any]]:
    """Parse results when using RETURN clause with specific fields.

    ``res[0]`` is the total match count; after it come (document key, flat
    ``[name, value, name, value, ...]`` list) pairs. Documents whose field
    list is None are skipped.
    """

    def to_string(s):
        if isinstance(s, bytes):
            return s.decode(errors="ignore")
        return s

    docs = []
    # Stop one short of the end so a trailing key without a field list
    # cannot cause an IndexError.
    for i in range(1, len(res) - 1, 2):
        flat_fields = res[i + 1]
        if flat_fields is None:
            continue
        raw_fields: Dict[str, str] = dict(
            zip(
                map(to_string, flat_fields[::2]),
                map(to_string, flat_fields[1::2]),
            )
        )
        docs.append(self._convert_projected_fields(raw_fields))
    return docs


def _convert_projected_fields(self, raw_data: Dict[str, str]) -> Dict[str, Any]:
    """Convert raw Redis string values to typed values using model field info.

    First tries to validate the whole mapping through the model. When that
    fails (typically because required fields were not projected), each field
    is converted on its own; values that cannot be converted stay as
    returned by Redis.
    """
    try:
        temp_model = self.model.model_validate(raw_data, strict=False)
        all_converted = temp_model.model_dump()
        return {k: all_converted[k] for k in raw_data if k in all_converted}
    except Exception:  # nosec B110
        # Expected for partial field sets; fall through to per-field conversion.
        pass

    converted_data: Dict[str, Any] = {}
    for field_name, raw_value in raw_data.items():
        # Default: keep the raw value (unknown field, complex type, or failed cast).
        converted_data[field_name] = raw_value
        if field_name not in self.model.model_fields:
            continue

        field_info = self.model.model_fields[field_name]
        if hasattr(field_info, "annotation"):
            field_type = field_info.annotation
        else:
            field_type = getattr(field_info, "type_", str)

        try:
            if field_type is int:
                converted_data[field_name] = int(raw_value)
            elif field_type is float:
                converted_data[field_name] = float(raw_value)
            elif field_type is bool:
                # Stored as "True"/"False" or "1"/"0"; str() guards against
                # a client that already returned a non-string value.
                converted_data[field_name] = str(raw_value).lower() in (
                    "true",
                    "1",
                    "yes",
                )
        except (ValueError, TypeError):
            pass

    return converted_data


def _parse_projected_models(self, res: Any) -> "List[PartialModel]":
    """Parse results of a RETURN query into partial model instances."""
    return [
        PartialModel(
            model_class=self.model,
            data=data,
            loaded_fields=set(self.projected_fields),
        )
        for data in self._parse_projected_results(res)
    ]


def _has_complex_projected_fields(self) -> bool:
    """Report whether any projected field needs handling beyond RETURN.

    Only models with a base class named ``JsonModel`` are inspected. A field
    is complex when it is a deep path, a dict, an embedded model, or a
    list/dict/tuple generic.
    """
    if not any(base.__name__ == "JsonModel" for base in self.model.__mro__):
        return False

    for field_name in self.projected_fields:
        if "__" in field_name:
            return True
        if field_name not in self.model.model_fields:
            continue

        field_info = self.model.model_fields[field_name]
        field_type = getattr(field_info, "annotation", None)

        if field_type is dict or get_origin(field_type) in (list, dict, tuple):
            return True

        try:
            if isinstance(field_type, type) and issubclass(field_type, RedisModel):
                return True
        except TypeError:
            pass

    return False


async def _parse_full_document_projection_as_dict(
    self, res: Any
) -> List[Dict[str, Any]]:
    """Project complex fields, choosing the strategy by model kind.

    Models with a ``JsonModel`` base use JSON.GET with JSONPath; anything
    else is parsed into full models and filtered in Python.
    """
    if any(base.__name__ == "JsonModel" for base in self.model.__mro__):
        return await self._parse_json_path_projection_as_dict(res)
    return await self._parse_fallback_projection_as_dict(res)


async def _parse_json_path_projection_as_dict(
    self, res: Any
) -> List[Dict[str, Any]]:
    """Fetch the projected fields of every matched document with JSON.GET.

    One JSON.GET call is issued per document, with all JSONPath expressions
    in that call. Documents for which the call or the parsing fails are
    logged and skipped.
    """
    # Document keys sit at the odd positions of the search reply.
    doc_keys = [
        key.decode("utf-8") if isinstance(key, bytes) else key for key in res[1::2]
    ]
    if not doc_keys:
        return []

    # address__city -> $.address.city; remember the mapping so results can be
    # translated back without reverse-engineering the path string.
    path_to_field = {
        "$." + field_name.replace("__", "."): field_name
        for field_name in self.projected_fields
    }
    json_paths = list(path_to_field)

    projected_results = []
    db = self.model.db()

    for doc_key in doc_keys:
        try:
            result = await db.json().get(doc_key, *json_paths)
            if result is None:
                continue

            projected_data = {}
            if isinstance(result, dict):
                # Several paths: {path: [matches]}
                for json_path, values in result.items():
                    field_name = path_to_field.get(
                        json_path, json_path[2:].replace(".", "__")
                    )
                    if values:
                        projected_data[field_name] = values[0]
            elif len(json_paths) == 1 and isinstance(result, list) and result:
                # Single path: a bare list of matches
                projected_data[path_to_field[json_paths[0]]] = result[0]

            projected_results.append(projected_data)
        except Exception:  # nosec B112
            # Best effort: skip this document, but leave a trace.
            log.warning(
                "JSON.GET projection failed for %s; skipping document",
                doc_key,
                exc_info=True,
            )
            continue

    return projected_results


async def _parse_fallback_projection_as_dict(
    self, res: Any
) -> List[Dict[str, Any]]:
    """Parse full model instances and keep only the projected fields.

    Deep fields whose path is missing (or resolves to None) are left out.
    """
    projected_results = []
    for model in self.model.from_redis(res, self.knn):
        model_data = model.model_dump()
        projected_data = {}
        for field_name in self.projected_fields:
            if "__" in field_name:
                nested_value = self._extract_nested_value(model_data, field_name)
                if nested_value is not None:
                    projected_data[field_name] = nested_value
            elif field_name in model_data:
                projected_data[field_name] = model_data[field_name]
        projected_results.append(projected_data)
    return projected_results


def _extract_nested_value(self, data: Dict[str, Any], field_path: str) -> Any:
    """Follow a ``a__b__c`` path through nested dicts; return None when any
    segment is missing."""
    current = data
    for part in field_path.split("__"):
        if not isinstance(current, dict) or part not in current:
            return None
        current = current[part]
    return current


async def _parse_full_document_projection_as_models(
    self, res: Any
) -> "List[PartialModel]":
    """Project complex fields and wrap each result in a partial model."""
    projected_dicts = await self._parse_full_document_projection_as_dict(res)
    return [
        PartialModel(
            model_class=self.model,
            data=self._construct_nested_partial_data(data),
            loaded_fields=set(self.projected_fields),
        )
        for data in projected_dicts
    ]


def _construct_nested_partial_data(
    self, flat_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Turn flat ``{"a__b": v}`` results into nested ``{"a": {"b": v}}`` data."""
    nested_data: Dict[str, Any] = {}
    for field_name, value in flat_data.items():
        if "__" in field_name:
            self._set_nested_value(nested_data, field_name, value)
        else:
            nested_data[field_name] = value
    return nested_data


def _set_nested_value(self, data: Dict[str, Any], field_path: str, value: Any):
    """Set *value* at the ``a__b__c`` path inside *data*, creating
    intermediate dicts as needed.

    An intermediate key that holds a non-dict value is replaced by a dict, so
    a deep path can always be stored.
    """
    *parents, leaf = field_path.split("__")
    current = data
    for part in parents:
        child = current.get(part)
        if not isinstance(child, dict):
            child = current[part] = {}
        current = child
    current[leaf] = value
@@ -799,9 +1355,11 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: if isinstance(expression.left, Expression) or isinstance( expression.left, NegatedExpression ): - result += f"({cls.resolve_redisearch_query(expression.left)})" + result += f"({self._resolve_redisearch_query(expression.left)})" elif isinstance(expression.left, FieldInfo): - field_type = cls.resolve_field_type(expression.left, expression.op) + field_type = self.__class__.resolve_field_type( + expression.left, expression.op + ) field_name = expression.left.name field_info = expression.left if not field_info or not getattr(field_info, "index", None): @@ -832,7 +1390,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: result += "-" right = right.expression - result += f"({cls.resolve_redisearch_query(right)})" + result += f"({self._resolve_redisearch_query(right)})" else: if not field_name: raise QuerySyntaxError("Could not resolve field name. See docs: TODO") @@ -841,13 +1399,14 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: elif not field_info: raise QuerySyntaxError("Could not resolve field info. 
See docs: TODO") else: - result += cls.resolve_value( + result += self.__class__.resolve_value( field_name, field_type, field_info, expression.op, right, expression.parents, + self.model, ) if encompassing_expression_is_negated: @@ -882,6 +1441,18 @@ async def execute( if self.nocontent: args.append("NOCONTENT") + # Check if we have complex fields that RediSearch RETURN clause can't handle + use_full_document_fallback = False + if self.projected_fields: + use_full_document_fallback = self._has_complex_projected_fields() + + # Add RETURN clause to the args list, not to the query string + # Skip RETURN clause if we need full documents for complex field projection + if self.projected_fields and not use_full_document_fallback: + args.extend( + ["RETURN", str(len(self.projected_fields))] + self.projected_fields + ) + if return_query_args: return self.model.Meta.index_name, args @@ -895,7 +1466,29 @@ async def execute( if return_raw_result: return raw_result count = raw_result[0] - results = self.model.from_redis(raw_result, self.knn) + + # Handle different result processing based on what was requested + if self.projected_fields and use_full_document_fallback: + # Complex field projection - use full document fallback + if self.return_as_dict: + results = await self._parse_full_document_projection_as_dict(raw_result) + else: + results = await self._parse_full_document_projection_as_models( + raw_result + ) + elif self.projected_fields and self.return_as_dict: + # .values('field1', 'field2') - specific fields as dicts + results = self._parse_projected_results(raw_result) + elif self.projected_fields and not self.return_as_dict: + # .only('field1', 'field2') - partial model instances + results = self._parse_projected_models(raw_result) + elif self.return_as_dict and not self.projected_fields: + # .values() - all fields as dicts + model_results = self.model.from_redis(raw_result, self.knn) + results = [model.model_dump() for model in model_results] + else: + # Normal query - 
def values(self, *fields: str):
    """
    Switch the query to dictionary output.

    With no arguments every field is returned; with field names only those
    fields are returned.

    Usage:
        await Model.find().values()  # All fields as dicts
        await Model.find().values('name', 'email')  # Only specified fields
    """
    overrides = {"return_as_dict": True}
    if fields:
        # Narrow the projection only when names were actually given.
        overrides["projected_fields"] = list(fields)
    return self.copy(**overrides)


def only(self, *fields: str):
    """
    Restrict the query to partial model instances holding just *fields*.

    Reading a field that was not loaded raises an AttributeError.

    Usage:
        await Model.find().only('name', 'email').all()  # Partial model instances

    Raises:
        ValueError: if no field name is given.
    """
    if fields:
        return self.copy(projected_fields=list(fields))
    raise ValueError("only() requires at least one field name")
await db.hset(self.key(), mapping=document) return self diff --git a/docs/models.md b/docs/models.md index 79f174c8..f44a4c03 100644 --- a/docs/models.md +++ b/docs/models.md @@ -272,3 +272,246 @@ from redis_om import ( redis = get_redis_connection() Migrator().run() ``` + +## Field Projection + +Redis OM supports field projection, which allows you to retrieve only specific fields from your models rather than loading all fields. This can improve performance and reduce memory usage, especially for models with many fields. + +There are two main methods for field projection: + +### `.values()` - Dictionary Results + +The `.values()` method returns query results as dictionaries instead of model instances: + +```python +from redis_om import HashModel, Field + +class Customer(HashModel): + first_name: str = Field(index=True) + last_name: str = Field(index=True) + email: str = Field(index=True) + age: int = Field(index=True) + bio: str + +# Get all fields as dictionaries +customers = Customer.find().values() +# Returns: [{"first_name": "John", "last_name": "Doe", "email": "john@example.com", "age": 30, "bio": "..."}] + +# Get only specific fields as dictionaries +customers = Customer.find().values("first_name", "email") +# Returns: [{"first_name": "John", "email": "john@example.com"}] +``` + +### `.only()` - Partial Model Instances + +The `.only()` method returns partial model instances that contain only the specified fields. 
Accessing fields that weren't loaded will raise an `AttributeError`: + +```python +# Get partial model instances with only specific fields +customers = Customer.find().only("first_name", "email").all() + +for customer in customers: + print(customer.first_name) # ✓ Works - field was loaded + print(customer.email) # ✓ Works - field was loaded + print(customer.age) # ✗ Raises AttributeError - field not loaded +``` + +### Performance Benefits + +When only simple (non-nested) fields are requested, both methods use the RediSearch `RETURN` clause for efficient field projection at the database level, which means: +- Only requested fields are transferred over the network +- Less memory usage on the client side +- Faster query execution for large models +- Automatic type conversion for returned fields + +### Type Conversion + +Redis OM automatically converts field values to their proper Python types based on your model field definitions: + +```python +class Product(HashModel): + name: str = Field(index=True) + price: float = Field(index=True) + in_stock: bool = Field(index=True) + created_at: datetime.datetime = Field(index=True) + +# Values are automatically converted to correct types +products = Product.find().values("name", "price", "in_stock").all() +# Returns: [{"name": "Widget", "price": 19.99, "in_stock": True}] +# Note: price is float, in_stock is bool (not strings) +``` + +### Combining with Other Query Methods + +Field projection works seamlessly with other query methods: + +```python +# Combine with filtering and sorting +expensive_products = Product.find( + Product.price > 100 +).sort_by("price").only("name", "price").all() + +# Combine with pagination +first_page = Product.find().values("name", "price").page(0, 10) + +# Use with async queries (for async models) +products = await AsyncProduct.find().values("name", "price").all() +``` + +### Deep Field Projection + +Redis OM supports Django-like deep field projection using double underscore (`__`) syntax to access nested fields in embedded models and dictionaries.
This allows you to extract specific values from complex nested structures without loading the entire object. + +#### Embedded Model Fields + +Extract fields from embedded models using the `field__subfield` syntax: + +```python +from redis_om import JsonModel, Field + +class Address(JsonModel): + street: str + city: str + zipcode: str = Field(index=True) + country: str = "USA" + + class Meta: + embedded = True + +class Customer(JsonModel, index=True): + name: str = Field(index=True) + age: int = Field(index=True) + address: Address + metadata: dict = Field(default_factory=dict) + +# Extract nested fields from embedded models +customers = Customer.find().values("name", "address__city", "address__zipcode") +# Returns: [{"name": "John Doe", "address__city": "Anytown", "address__zipcode": "12345"}] + +# Works with .only() method too +customer = Customer.find().only("name", "address__street").first() +print(customer.name) # ✓ Works +print(getattr(customer, "address__street")) # ✓ Works - returns "123 Main St" +print(customer.age) # ✗ Raises AttributeError - not loaded +``` + +#### Dictionary Field Access + +Access nested dictionary values using the same syntax: + +```python +# Sample data with nested dictionary +customer_data = { + "name": "John Doe", + "metadata": { + "role": "admin", + "preferences": { + "theme": "dark", + "notifications": True, + "settings": { + "language": "en" + } + } + } +} + +# Extract values at any nesting level +result = Customer.find().values( + "name", + "metadata__role", + "metadata__preferences__theme", + "metadata__preferences__settings__language" +) +# Returns: [{ +# "name": "John Doe", +# "metadata__role": "admin", +# "metadata__preferences__theme": "dark", +# "metadata__preferences__settings__language": "en" +# }] +``` + +#### Mixed Deep Fields + +Combine regular fields, embedded model fields, and dictionary fields in a single query: + +```python +# Mix all types of field projection +customers = Customer.find().values( + "name", # 
Regular field + "age", # Regular field + "address__city", # Embedded model field + "address__country", # Embedded model field + "metadata__role", # Dictionary field + "metadata__preferences__theme" # Nested dictionary field +) +``` + +#### Validation and Error Handling + +Deep field paths are fully validated to ensure they exist in your model hierarchy: + +```python +# ✓ Valid - address is an embedded model with a city field +Customer.find().values("name", "address__city") + +# ✗ Invalid - nonexistent root field +Customer.find().values("name", "nonexistent__field") +# Raises: QueryNotSupportedError + +# ✗ Invalid - city is not a complex field +Customer.find().values("name", "address__city__invalid") +# Raises: QueryNotSupportedError + +# ✗ Invalid - address exists but zipcode_invalid doesn't +Customer.find().values("name", "address__zipcode_invalid") +# Raises: QueryNotSupportedError +``` + +#### Performance Considerations + +Deep field projection automatically uses the full document fallback strategy for optimal data access: + +- **Simple fields only**: Uses efficient Redis `RETURN` clause +- **Deep fields present**: Queries full documents and extracts requested fields +- **Automatic detection**: No manual configuration needed +- **Type preservation**: All nested values maintain their proper Python types + +```python +# This query uses RETURN clause (efficient) +Customer.find().values("name", "age") + +# This query uses fallback (still efficient, but queries full documents) +Customer.find().values("name", "address__city") +``` + +### Limitations + +Field projection has some limitations to be aware of: + +#### Complex Field Types (JsonModel only) + +For `JsonModel`, complex field types (embedded models, dictionaries, lists) cannot be projected using Redis's `RETURN` clause. 
Redis OM automatically falls back to querying full documents and manually extracting the requested fields. In practice: + +- **HashModel**: All simple field types work with efficient projection +- **JsonModel**: Simple fields use efficient projection, complex fields use fallback +- **Performance**: Fallback is still fast but transfers more data + +#### Supported vs Unsupported Field Types + +```python +# ✓ Supported for efficient projection (all model types) +class Product(HashModel): # or JsonModel + name: str = Field(index=True) # ✓ String fields + price: float = Field(index=True) # ✓ Numeric fields + active: bool = Field(index=True) # ✓ Boolean fields + created: datetime = Field(index=True) # ✓ DateTime fields + +# JsonModel: These use fallback strategy (still supported) +class Customer(JsonModel): + profile: UserProfile # Uses fallback (embedded model) + settings: dict # Uses fallback (dictionary) + tags: List[str] # Uses fallback (list) + +# Deep field access works for all complex types +result = Customer.find().values("name", "profile__email", "settings__theme") +``` diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 928986c4..187f3e32 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -934,19 +934,23 @@ class TestUpdate(HashModel, index=True): @py_test_mark_asyncio -async def test_literals(): +async def test_literals(key_prefix, redis): from typing import Literal class TestLiterals(HashModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + class Meta: + global_key_prefix = key_prefix + database = redis + schema = TestLiterals.redisearch_schema() - key_prefix = TestLiterals.make_key( + expected_key_prefix = TestLiterals.make_key( TestLiterals._meta.primary_key_pattern.format(pk="") ) assert schema == ( - f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |" + f"ON HASH PREFIX 1 {expected_key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR
|" ) await Migrator().run() item = TestLiterals(flavor="pumpkin") @@ -1133,6 +1137,192 @@ class Meta: ).first() +@py_test_mark_asyncio +async def test_values_method_with_specific_fields(members, m): + member1, member2, member3 = members + actual = await ( + m.Member.find( + (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") + | (m.Member.last_name == "Smith") + ) + .sort_by("last_name") + .values("first_name", "last_name") + .all() + ) + assert actual == [ + {"first_name": "Andrew", "last_name": "Brookins"}, + {"first_name": "Andrew", "last_name": "Smith"}, + ] + + +@py_test_mark_asyncio +async def test_values_method_all_fields(members, m): + member1, member2, member3 = members + actual = await m.Member.find(m.Member.first_name == "Andrew").values().all() + + # Check that it returns all fields as dicts + assert len(actual) == 2 # Should find Andrew Brookins and Andrew Smith + # Verify it contains all fields as dictionaries + for result in actual: + assert "first_name" in result + assert "last_name" in result + assert "email" in result + assert "age" in result + assert "pk" in result # Should include primary key + assert result["first_name"] == "Andrew" + + +@py_test_mark_asyncio +async def test_only_method_basic(members, m): + """Test basic .only() method functionality - partial model instances.""" + member1, member2, member3 = members + + # Test .only() with specific fields + actual = await ( + m.Member.find(m.Member.first_name == "Andrew") + .sort_by("last_name") + .only("first_name", "last_name") + .all() + ) + + # Should return PartialModel instances, not dicts + assert len(actual) == 2 + + result = actual[0] # Test with first result + + # Should be able to access loaded fields + assert result.first_name == "Andrew" + assert result.last_name in ["Brookins", "Smith"] + + # Should NOT be a dict + assert not isinstance(result, dict) + + # Should be a PartialModel instance + from aredis_om.model.model import PartialModel + + assert 
isinstance(result, PartialModel) + + # Accessing unloaded fields should raise AttributeError + with pytest.raises(AttributeError, match="Field 'email' was not loaded"): + _ = result.email + + with pytest.raises(AttributeError, match="Field 'age' was not loaded"): + _ = result.age + + +@py_test_mark_asyncio +async def test_only_method_error_messages(members, m): + """Test that .only() provides helpful error messages.""" + member1, member2, member3 = members + + # Test .only() with one field + results = await m.Member.find().only("first_name").all() + result = results[0] + + # Accessing unloaded field should suggest the correct .only() usage + with pytest.raises(AttributeError) as exc_info: + _ = result.email + + error_message = str(exc_info.value) + assert "Field 'email' was not loaded" in error_message + assert ".only('email')" in error_message or ".only(" in error_message + + +@py_test_mark_asyncio +async def test_values_type_conversion(members, m): + """Test that .values() returns properly typed values, not just strings.""" + member1, member2, member3 = members + + # Test .values() with specific fields - should return proper types + results = await m.Member.find().values("first_name", "age").all() + result = results[0] + + # Should be dictionary with proper types + assert isinstance(result, dict) + assert isinstance(result["first_name"], str) + assert isinstance( + result["age"], int + ), f"Expected int, got {type(result['age'])} with value {result['age']}" + + +@py_test_mark_asyncio +async def test_only_type_conversion(members, m): + """Test that .only() returns properly typed values, not just strings.""" + member1, member2, member3 = members + + # Test .only() with specific fields - should return proper types in PartialModel + results = await m.Member.find().only("first_name", "age").all() + result = results[0] + + # Should be PartialModel with proper types + from aredis_om.model.model import PartialModel + + assert isinstance(result, PartialModel) + assert 
isinstance(result.first_name, str) + assert isinstance( + result.age, int + ), f"Expected int, got {type(result.age)} with value {result.age}" + + +@py_test_mark_asyncio +async def test_boolean_fields_work_with_hash_model(key_prefix, redis): + """Test that boolean fields work correctly with HashModel.""" + + class BoolTestModel(HashModel, index=True): + name: str = Field(index=True) + is_active: bool = Field(index=True) + is_admin: bool = Field(index=True, default=False) + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + # Test saving and retrieving boolean fields + model = BoolTestModel(name="test_user", is_active=True, is_admin=False) + await model.save() + + # Test get() method + retrieved = await BoolTestModel.get(model.pk) + assert retrieved.name == "test_user" + assert retrieved.is_active is True + assert retrieved.is_admin is False + assert isinstance(retrieved.is_active, bool) + assert isinstance(retrieved.is_admin, bool) + + # Test querying by boolean fields + active_users = await BoolTestModel.find(BoolTestModel.is_active == True).all() + assert len(active_users) == 1 + assert active_users[0].name == "test_user" + assert active_users[0].is_active is True + + non_admin_users = await BoolTestModel.find(BoolTestModel.is_admin == False).all() + assert len(non_admin_users) == 1 + assert non_admin_users[0].name == "test_user" + assert non_admin_users[0].is_admin is False + + # Test .values() with boolean fields + values_result = ( + await BoolTestModel.find().values("name", "is_active", "is_admin").all() + ) + assert len(values_result) == 1 + result = values_result[0] + assert result["name"] == "test_user" + assert result["is_active"] is True + assert result["is_admin"] is False + assert isinstance(result["is_active"], bool) + assert isinstance(result["is_admin"], bool) + + # Test .only() with boolean fields + only_result = await BoolTestModel.find().only("name", "is_active").all() + assert len(only_result) == 1 + 
result = only_result[0] + assert result.name == "test_user" + assert result.is_active is True + assert isinstance(result.is_active, bool) + + @py_test_mark_asyncio async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis): class Location(HashModel, index=True): diff --git a/tests/test_json_model.py b/tests/test_json_model.py index c8bd4031..5474eb7a 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -956,6 +956,41 @@ class TypeWithUuid(JsonModel, index=True): await item.save() +@py_test_mark_asyncio +async def test_values_method_with_specific_fields(members, m): + member1, member2, member3 = members + actual = await ( + m.Member.find( + (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") + | (m.Member.last_name == "Smith") + ) + .sort_by("last_name") + .values("first_name", "last_name") + .all() + ) + assert actual == [ + {"first_name": "Andrew", "last_name": "Brookins"}, + {"first_name": "Andrew", "last_name": "Smith"}, + ] + + +@py_test_mark_asyncio +async def test_values_method_all_fields(members, m): + member1, member2, member3 = members + actual = await m.Member.find(m.Member.first_name == "Andrew").values().all() + + # Check that it returns all fields as dicts + assert len(actual) == 2 # Should find Andrew Brookins and Andrew Smith + # Verify it contains all fields as dictionaries + for result in actual: + assert "first_name" in result + assert "last_name" in result + assert "email" in result + assert "age" in result + assert "pk" in result # Should include primary key + assert result["first_name"] == "Andrew" + + @py_test_mark_asyncio async def test_type_with_enum(): class TestEnum(Enum): @@ -1181,19 +1216,23 @@ async def get_page(cls, offset, limit): @py_test_mark_asyncio -async def test_literals(): +async def test_literals(key_prefix, redis): from typing import Literal class TestLiterals(JsonModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + class 
Meta: + global_key_prefix = key_prefix + database = redis + schema = TestLiterals.redisearch_schema() - key_prefix = TestLiterals.make_key( + expected_key_prefix = TestLiterals.make_key( TestLiterals._meta.primary_key_pattern.format(pk="") ) assert schema == ( - f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | " + f"ON JSON PREFIX 1 {expected_key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | " "$.flavor AS flavor TAG SEPARATOR |" ) await Migrator().run()