Skip to content

Commit 5b91067

Browse files
committed
fix: correctly implement field projection with return_fields
- Move RETURN clause from query string to args list (fixes RediSearch syntax) - Rename internal parameter from return_fields to projected_fields to avoid conflict with method name - Return dictionaries instead of model instances when using field projection - Add proper parsing for projected results that returns flat key-value pairs - Add tests for both HashModel and JsonModel field projection - Update validation to use model_fields instead of deprecated __fields__ This addresses all review comments from PR #633 and implements field projection correctly.
1 parent 5af2e8e commit 5b91067

File tree

5 files changed

+200
-20
lines changed

5 files changed

+200
-20
lines changed

.claude/settings.local.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(gh pr checkout:*)",
5+
"Bash(git rebase:*)",
6+
"Bash(git stash:*)",
7+
"Bash(make:*)",
8+
"Bash(poetry lock:*)",
9+
"Bash(poetry run pytest:*)",
10+
"Bash(poetry run:*)",
11+
"Bash(rm:*)",
12+
"Bash(gh pr view:*)",
13+
"Bash(gh api:*)",
14+
"Bash(grep:*)",
15+
"Bash(git add:*)",
16+
"Bash(git push:*)",
17+
"WebFetch(domain:github.com)",
18+
"Bash(npm install:*)",
19+
"Bash(cspell:*)",
20+
"Bash(git commit:*)"
21+
],
22+
"deny": [],
23+
"ask": []
24+
}
25+
}

CLAUDE.md

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Repository Overview
6+
7+
Redis OM Python is an object mapping library that provides declarative models for Redis data, built on Pydantic for validation and serialization. It supports both Redis Hash and JSON storage with automatic indexing via RediSearch.
8+
9+
## Development Commands
10+
11+
### Essential Commands
12+
```bash
13+
# Setup and dependencies
14+
make install # Install all dependencies via Poetry
15+
make redis # Start Redis Stack (6380) and Redis OSS (6381) containers
16+
17+
# Code generation (CRITICAL - always run after modifying async code)
18+
make sync # Generate sync code from async using unasync
19+
20+
# Code quality
21+
make format # Format with isort and black
22+
make lint # Run flake8, mypy, and bandit
23+
make test # Run full test suite against Redis Stack
24+
make test_oss # Run tests against OSS Redis
25+
26+
# Individual test commands
27+
poetry run pytest tests/test_hash_model.py::test_saves_model # Run specific test
28+
poetry run pytest -n auto -vv ./tests/ ./tests_sync/ # Run all tests with parallelization
29+
```
30+
31+
### Database Migrations
32+
```bash
33+
poetry run migrate # Run schema migrations for index changes
34+
```
35+
36+
## Architecture and Code Organization
37+
38+
### Dual Async/Sync Architecture
39+
**CRITICAL**: This codebase uses a unique dual-implementation approach:
40+
- **Primary source**: `/aredis_om/` - All features implemented here first (async)
41+
- **Generated code**: `/redis_om/` - Automatically generated sync version
42+
- **Generation**: `make_sync.py` uses `unasync` to transform async → sync
43+
44+
**Development Rule**: NEVER directly edit files in `/redis_om/`. Always modify `/aredis_om/` and run `make sync`.
45+
46+
### Core Abstractions
47+
48+
#### Model Hierarchy
49+
- `RedisModel` - Abstract base class
50+
- `HashModel` - Stores data as Redis Hashes (simple key-value)
51+
- `JsonModel` - Stores data as RedisJSON documents (nested structures)
52+
- `EmbeddedJsonModel` - For nested JSON models without separate storage
53+
54+
#### Query System
55+
Expression-based queries using Django ORM-like syntax:
56+
```python
57+
results = await Customer.find(
58+
(Customer.age > 18) & (Customer.country == "US")
59+
).all()
60+
```
61+
62+
#### Field System
63+
Built on Pydantic fields with Redis-specific extensions:
64+
- `Field(index=True)` - Create secondary index
65+
- `Field(full_text_search=True)` - Enable text search
66+
- `VectorField` - For similarity search with embeddings
67+
68+
### Key Implementation Patterns
69+
70+
1. **Connection Management**: Uses `get_redis_connection()` from `connections.py`
71+
2. **Index Management**: Automatic RediSearch index creation via migration system
72+
3. **Query Resolution**: `QueryResolver` class handles expression tree → RediSearch query conversion
73+
4. **Type Validation**: All models use Pydantic v2 for validation
74+
75+
## Important Constraints
76+
77+
1. **Database 0 Only**: RediSearch indexes only work in Redis database 0
78+
2. **Module Requirements**: Advanced features require Redis Stack (RediSearch + RedisJSON)
79+
3. **Python Versions**: Supports Python 3.8-3.13, test across versions with tox
80+
4. **Async Testing**: Tests use `pytest-asyncio` with `asyncio_mode = strict`
81+
82+
## Testing Approach
83+
84+
- Separate test directories: `/tests/` (async) and `/tests_sync/` (sync)
85+
- Docker Compose provides test Redis instances
86+
- Use fixtures from `conftest.py` for Redis connections
87+
- Mark tests requiring Redis modules with appropriate decorators
88+
89+
## Common Development Tasks
90+
91+
When implementing new features:
92+
1. Implement in `/aredis_om/` first
93+
2. Run `make sync` to generate sync version
94+
3. Write tests in `/tests/` (async)
95+
4. Run `make format` and `make lint`
96+
5. Run `make test` to verify
97+
6. Update type hints and ensure mypy passes

aredis_om/model/model.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def __init__(
418418
limit: Optional[int] = None,
419419
page_size: int = DEFAULT_PAGE_SIZE,
420420
sort_fields: Optional[List[str]] = None,
421-
return_fields: Optional[List[str]] = None,
421+
projected_fields: Optional[List[str]] = None,
422422
nocontent: bool = False,
423423
):
424424
if not has_redisearch(model.db()):
@@ -443,10 +443,10 @@ def __init__(
443443
else:
444444
self.sort_fields = []
445445

446-
if return_fields:
447-
self.return_fields = self.validate_return_fields(return_fields)
446+
if projected_fields:
447+
self.projected_fields = self.validate_projected_fields(projected_fields)
448448
else:
449-
self.return_fields = []
449+
self.projected_fields = []
450450

451451
self._expression = None
452452
self._query: Optional[str] = None
@@ -505,18 +505,45 @@ def query(self):
505505
if self._query.startswith("(") or self._query == "*"
506506
else f"({self._query})"
507507
) + f"=>[{self.knn}]"
508-
if self.return_fields:
509-
self._query += f" RETURN {','.join(self.return_fields)}"
508+
# RETURN clause should be added to args, not to the query string
510509
return self._query
511510

512-
def validate_return_fields(self, return_fields: List[str]):
513-
for field in return_fields:
514-
if field not in self.model.__fields__: # type: ignore
511+
def validate_projected_fields(self, projected_fields: List[str]):
512+
for field in projected_fields:
513+
if field not in self.model.model_fields: # type: ignore
515514
raise QueryNotSupportedError(
516515
f"You tried to return the field {field}, but that field "
517516
f"does not exist on the model {self.model}"
518517
)
519-
return return_fields
518+
return projected_fields
519+
520+
def _parse_projected_results(self, res: Any) -> List[Dict[str, Any]]:
521+
"""Parse results when using RETURN clause with specific fields."""
522+
523+
def to_string(s):
524+
if isinstance(s, (str,)):
525+
return s
526+
elif isinstance(s, bytes):
527+
return s.decode(errors="ignore")
528+
else:
529+
return s
530+
531+
docs = []
532+
step = 2 # Because the result has content
533+
offset = 1 # The first item is the count of total matches.
534+
535+
for i in range(1, len(res), step):
536+
if res[i + offset] is None:
537+
continue
538+
# When using RETURN, we get flat key-value pairs
539+
fields: Dict[str, str] = dict(
540+
zip(
541+
map(to_string, res[i + offset][::2]),
542+
map(to_string, res[i + offset][1::2]),
543+
)
544+
)
545+
docs.append(fields)
546+
return docs
520547

521548
@property
522549
def query_params(self):
@@ -899,6 +926,12 @@ async def execute(
899926
if self.nocontent:
900927
args.append("NOCONTENT")
901928

929+
# Add RETURN clause to the args list, not to the query string
930+
if self.projected_fields:
931+
args.extend(
932+
["RETURN", str(len(self.projected_fields))] + self.projected_fields
933+
)
934+
902935
if return_query_args:
903936
return self.model.Meta.index_name, args
904937

@@ -912,7 +945,12 @@ async def execute(
912945
if return_raw_result:
913946
return raw_result
914947
count = raw_result[0]
915-
results = self.model.from_redis(raw_result, self.knn)
948+
949+
# If we're using field projection, return dictionaries instead of model instances
950+
if self.projected_fields:
951+
results = self._parse_projected_results(raw_result)
952+
else:
953+
results = self.model.from_redis(raw_result, self.knn)
916954
self._model_cache += results
917955

918956
if not exhaust_results:
@@ -966,11 +1004,11 @@ def sort_by(self, *fields: str):
9661004
if not fields:
9671005
return self
9681006
return self.copy(sort_fields=list(fields))
969-
1007+
9701008
def return_fields(self, *fields: str):
9711009
if not fields:
9721010
return self
973-
return self.copy(return_fields=list(fields))
1011+
return self.copy(projected_fields=list(fields))
9741012

9751013
async def update(self, use_transaction=True, **field_values):
9761014
"""
@@ -1546,9 +1584,7 @@ def find(
15461584
*expressions: Union[Any, Expression],
15471585
knn: Optional[KNNExpression] = None,
15481586
) -> FindQuery:
1549-
return FindQuery(
1550-
expressions=expressions, knn=knn, model=cls
1551-
)
1587+
return FindQuery(expressions=expressions, knn=knn, model=cls)
15521588

15531589
@classmethod
15541590
def from_redis(cls, res: Any, knn: Optional[KNNExpression] = None):

tests/test_hash_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,23 @@ class Meta:
11331133
).first()
11341134

11351135

1136+
@py_test_mark_asyncio
1137+
async def test_return_specified_fields(members, m):
1138+
member1, member2, member3 = members
1139+
actual = (
1140+
await m.Member.find(
1141+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
1142+
| (m.Member.last_name == "Smith")
1143+
)
1144+
.return_fields("first_name", "last_name")
1145+
.all()
1146+
)
1147+
assert actual == [
1148+
{"first_name": "Andrew", "last_name": "Brookins"},
1149+
{"first_name": "Andrew", "last_name": "Smith"},
1150+
]
1151+
1152+
11361153
@py_test_mark_asyncio
11371154
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
11381155
class Location(HashModel, index=True):

tests/test_json_model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -955,13 +955,18 @@ class TypeWithUuid(JsonModel, index=True):
955955

956956
await item.save()
957957

958+
958959
@py_test_mark_asyncio
959960
async def test_return_specified_fields(members, m):
960961
member1, member2, member3 = members
961-
actual = await m.Member.find(
962-
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
963-
| (m.Member.last_name == "Smith")
964-
).all()
962+
actual = (
963+
await m.Member.find(
964+
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
965+
| (m.Member.last_name == "Smith")
966+
)
967+
.return_fields("first_name", "last_name")
968+
.all()
969+
)
965970
assert actual == [
966971
{"first_name": "Andrew", "last_name": "Brookins"},
967972
{"first_name": "Andrew", "last_name": "Smith"},

0 commit comments

Comments
 (0)