Skip to content

added return_fields function, attempting to optionally limit fields r… #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def __init__(
limit: Optional[int] = None,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None,
projected_fields: Optional[List[str]] = None,
nocontent: bool = False,
):
if not has_redisearch(model.db()):
Expand All @@ -442,6 +443,11 @@ def __init__(
else:
self.sort_fields = []

if projected_fields:
self.projected_fields = self.validate_projected_fields(projected_fields)
else:
self.projected_fields = []

self._expression = None
self._query: Optional[str] = None
self._pagination: List[str] = []
Expand Down Expand Up @@ -499,8 +505,46 @@ def query(self):
if self._query.startswith("(") or self._query == "*"
else f"({self._query})"
) + f"=>[{self.knn}]"
# RETURN clause should be added to args, not to the query string
return self._query

def validate_projected_fields(self, projected_fields: List[str]):
for field in projected_fields:
if field not in self.model.model_fields: # type: ignore
raise QueryNotSupportedError(
f"You tried to return the field {field}, but that field "
f"does not exist on the model {self.model}"
)
return projected_fields

def _parse_projected_results(self, res: Any) -> List[Dict[str, Any]]:
"""Parse results when using RETURN clause with specific fields."""

def to_string(s):
if isinstance(s, (str,)):
return s
elif isinstance(s, bytes):
return s.decode(errors="ignore")
else:
return s

docs = []
step = 2 # Because the result has content
offset = 1 # The first item is the count of total matches.

for i in range(1, len(res), step):
if res[i + offset] is None:
continue
# When using RETURN, we get flat key-value pairs
fields: Dict[str, str] = dict(
zip(
map(to_string, res[i + offset][::2]),
map(to_string, res[i + offset][1::2]),
)
)
docs.append(fields)
return docs

@property
def query_params(self):
params: List[Union[str, bytes]] = []
Expand Down Expand Up @@ -882,6 +926,12 @@ async def execute(
if self.nocontent:
args.append("NOCONTENT")

# Add RETURN clause to the args list, not to the query string
if self.projected_fields:
args.extend(
["RETURN", str(len(self.projected_fields))] + self.projected_fields
)

if return_query_args:
return self.model.Meta.index_name, args

Expand All @@ -895,7 +945,12 @@ async def execute(
if return_raw_result:
return raw_result
count = raw_result[0]
results = self.model.from_redis(raw_result, self.knn)

# If we're using field projection, return dictionaries instead of model instances
if self.projected_fields:
results = self._parse_projected_results(raw_result)
else:
results = self.model.from_redis(raw_result, self.knn)
self._model_cache += results

if not exhaust_results:
Expand Down Expand Up @@ -950,6 +1005,11 @@ def sort_by(self, *fields: str):
return self
return self.copy(sort_fields=list(fields))

def return_fields(self, *fields: str):
if not fields:
return self
return self.copy(projected_fields=list(fields))

async def update(self, use_transaction=True, **field_values):
"""
Update models that match this query to the given field-value pairs.
Expand Down
17 changes: 17 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,23 @@ class Meta:
).first()


@py_test_mark_asyncio
async def test_return_specified_fields(members, m):
member1, member2, member3 = members
actual = (
await m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)
.return_fields("first_name", "last_name")
.all()
)
assert actual == [
{"first_name": "Andrew", "last_name": "Brookins"},
{"first_name": "Andrew", "last_name": "Smith"},
]


@py_test_mark_asyncio
async def test_can_search_on_multiple_fields_with_geo_filter(key_prefix, redis):
class Location(HashModel, index=True):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,23 @@ class TypeWithUuid(JsonModel, index=True):
await item.save()


@py_test_mark_asyncio
async def test_return_specified_fields(members, m):
member1, member2, member3 = members
actual = (
await m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)
.return_fields("first_name", "last_name")
.all()
)
assert actual == [
{"first_name": "Andrew", "last_name": "Brookins"},
{"first_name": "Andrew", "last_name": "Smith"},
]


@py_test_mark_asyncio
async def test_type_with_enum():
class TestEnum(Enum):
Expand Down
Loading