Skip to content

Commit 8bc61a1

Browse files
committed
remove use of alias for queries
1 parent 721d91e commit 8bc61a1

File tree

5 files changed

+69
-37
lines changed

5 files changed

+69
-37
lines changed

.vscode/settings.json

Lines changed: 0 additions & 4 deletions
This file was deleted.

aredis_om/model/model.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,18 @@ def query_params(self) -> Dict[str, Union[str, bytes]]:
292292

293293
@property
294294
def score_field_name(self) -> str:
295-
return self.score_field.field.alias
295+
return self.score_field.field.name
296296

297297
@property
298298
def vector_field_name(self) -> str:
299-
return self.vector_field.field.alias
299+
return self.vector_field.field.name
300300

301301

302302
ExpressionOrNegated = Union[Expression, NegatedExpression]
303303

304304

305305
class ExpressionProxy:
306-
def __init__(
307-
self, field: PydanticFieldInfo, parents: List[Tuple[str, "RedisModel"]]
308-
):
306+
def __init__(self, field: "FieldInfo", parents: List[Tuple[str, "RedisModel"]]):
309307
self.field = field
310308
self.parents = parents.copy() # Ensure a copy is stored
311309

@@ -389,7 +387,7 @@ def __getattr__(self, item):
389387
if isinstance(attr, self.__class__):
390388
# Clone the parents to ensure isolation
391389
new_parents = self.parents.copy()
392-
new_parent = (self.field.alias, outer_type)
390+
new_parent = (self.field.name, outer_type)
393391
if new_parent not in new_parents:
394392
new_parents.append(new_parent)
395393
attr.parents = new_parents
@@ -524,17 +522,18 @@ def validate_sort_fields(self, sort_fields: List[str]):
524522
)
525523
field_proxy: ExpressionProxy = getattr(self.model, field_name)
526524

527-
if not getattr(field_proxy.field, "sortable", False):
525+
if (
526+
not field_proxy.field.sortable is True
527+
and not field_proxy.field.index is True
528+
):
528529
raise QueryNotSupportedError(
529530
f"You tried sort by {field_name}, but {self.model} does "
530-
f"not define that field as sortable. Docs: {ERRORS_URL}#E2"
531+
f"not define that field as sortable or indexed. Docs: {ERRORS_URL}#E2"
531532
)
532533
return sort_fields
533534

534535
@staticmethod
535-
def resolve_field_type(
536-
field: PydanticFieldInfo, op: Operators
537-
) -> RediSearchFieldTypes:
536+
def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes:
538537
field_info: Union[FieldInfo, PydanticFieldInfo] = field
539538

540539
if getattr(field_info, "primary_key", None) is True:
@@ -543,7 +542,7 @@ def resolve_field_type(
543542
fts = getattr(field_info, "full_text_search", None)
544543
if fts is not True: # Could be PydanticUndefined
545544
raise QuerySyntaxError(
546-
f"You tried to do a full-text search on the field '{field.alias}', "
545+
f"You tried to do a full-text search on the field '{field.name}', "
547546
f"but the field is not indexed for full-text search. Use the "
548547
f"full_text_search=True option. Docs: {ERRORS_URL}#E3"
549548
)
@@ -793,7 +792,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
793792
result += f"({cls.resolve_redisearch_query(expression.left)})"
794793
elif isinstance(expression.left, FieldInfo):
795794
field_type = cls.resolve_field_type(expression.left, expression.op)
796-
field_name = expression.left.alias
795+
field_name = expression.left.name
797796
field_info = expression.left
798797
if not field_info or not getattr(field_info, "index", None):
799798
raise QueryNotSupportedError(
@@ -1059,6 +1058,8 @@ def __dataclass_transform__(
10591058

10601059

10611060
class FieldInfo(PydanticFieldInfo):
1061+
name: str
1062+
10621063
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
10631064
primary_key = kwargs.pop("primary_key", False)
10641065
sortable = kwargs.pop("sortable", Undefined)
@@ -1297,20 +1298,22 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
12971298
# Create proxies for each model field so that we can use the field
12981299
# in queries, like Model.get(Model.field_name == 1)
12991300
# Only set if the model is has index=True
1300-
if kwargs.get("index", None) == True:
1301-
new_class.model_config["index"] = True
1302-
for field_name, field in new_class.model_fields.items():
1301+
is_indexed = kwargs.get("index", None) is True
1302+
new_class.model_config["index"] = is_indexed
1303+
1304+
for field_name, field in new_class.model_fields.items():
1305+
if field.__class__ is PydanticFieldInfo:
1306+
field = FieldInfo(**field._attributes_set)
1307+
setattr(new_class, field_name, field)
1308+
1309+
if is_indexed:
13031310
setattr(new_class, field_name, ExpressionProxy(field, []))
13041311

1305-
# We need to set alias equal the field name here to allow downstream processes to have access to it.
1306-
# Processes like the query builder use it.
1307-
if not field.alias:
1308-
field.alias = field_name
1312+
# we need to set the field name for use in queries
1313+
field.name = field_name
13091314

1310-
if getattr(field, "primary_key", None) is True:
1311-
new_class._meta.primary_key = PrimaryKey(
1312-
name=field_name, field=field
1313-
)
1315+
if field.primary_key is True:
1316+
new_class._meta.primary_key = PrimaryKey(name=field_name, field=field)
13141317

13151318
if not getattr(new_class._meta, "global_key_prefix", None):
13161319
new_class._meta.global_key_prefix = getattr(

tests/test_hash_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,3 +1035,22 @@ class Child(Model):
10351035
assert child_validate.age == 18
10361036
assert child_validate.bio is None
10371037
assert child_validate.other_name == "Maria"
1038+
1039+
1040+
@py_test_mark_asyncio
1041+
async def test_model_with_alias_can_be_searched(key_prefix, redis):
1042+
class Model(HashModel, index=True):
1043+
first_name: str = Field(alias="firstName", index=True)
1044+
last_name: str = Field(alias="lastName")
1045+
1046+
class Meta:
1047+
global_key_prefix = key_prefix
1048+
database = redis
1049+
1050+
await Migrator().run()
1051+
1052+
model = Model(first_name="Steve", last_name="Lorello")
1053+
await model.save()
1054+
1055+
rematerialized = await Model.find(Model.first_name == "Steve").first()
1056+
assert rematerialized.pk == model.pk

tests/test_json_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,3 +1288,22 @@ class Model(JsonModel):
12881288
with pytest.raises(RedisModelError):
12891289
model = Model()
12901290
await model.save()
1291+
1292+
1293+
@py_test_mark_asyncio
1294+
async def test_model_with_alias_can_be_searched(key_prefix, redis):
1295+
class Model(JsonModel, index=True):
1296+
first_name: str = Field(alias="firstName", index=True)
1297+
last_name: str = Field(alias="lastName")
1298+
1299+
class Meta:
1300+
global_key_prefix = key_prefix
1301+
database = redis
1302+
1303+
await Migrator().run()
1304+
1305+
model = Model(first_name="Steve", last_name="Lorello")
1306+
await model.save()
1307+
1308+
rematerialized = await Model.find(Model.first_name == "Steve").first()
1309+
assert rematerialized.pk == model.pk

tests/test_knn_expression.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# type: ignore
22
import abc
3-
import time
43
import struct
4+
import time
55

66
import pytest_asyncio
77

88
from aredis_om import Field, JsonModel, KNNExpression, Migrator, VectorFieldOptions
99

1010
from .conftest import py_test_mark_asyncio
1111

12+
1213
DIMENSIONS = 768
1314

1415

@@ -28,9 +29,7 @@ class Meta:
2829

2930
class Member(BaseJsonModel, index=True):
3031
name: str
31-
embeddings: list[list[float]] | bytes = Field(
32-
[], vector_options=vector_field_options
33-
)
32+
embeddings: list[list[float]] = Field([], vector_options=vector_field_options)
3433
embeddings_score: float | None = None
3534

3635
await Migrator().run()
@@ -49,11 +48,7 @@ async def test_vector_field(m: type[JsonModel]):
4948
member = m(name="seth", embeddings=[vectors])
5049

5150
# Save the member to Redis
52-
mt = await member.save()
53-
54-
assert m.get(mt.pk)
55-
56-
time.sleep(1)
51+
await member.save()
5752

5853
knn = KNNExpression(
5954
k=1,

0 commit comments

Comments
 (0)