Skip to content

Commit 3676634

Browse files
committed
update models to specify when a model is final
1 parent 9ffa5ed commit 3676634

File tree

7 files changed

+226
-112
lines changed

7 files changed

+226
-112
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"python.testing.unittestEnabled": false,
3+
"python.testing.pytestEnabled": true,
4+
}

aredis_om/model/model.py

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from copy import copy
88
from enum import Enum
99
from functools import reduce
10-
from typing_extensions import Unpack
1110
from typing import (
1211
Any,
1312
Callable,
@@ -30,12 +29,13 @@
3029
from pydantic import BaseModel, ConfigDict, TypeAdapter, field_validator
3130
from pydantic._internal._model_construction import ModelMetaclass
3231
from pydantic._internal._repr import Representation
33-
from pydantic.fields import FieldInfo as PydanticFieldInfo, _FieldInfoInputs
32+
from pydantic.fields import FieldInfo as PydanticFieldInfo
33+
from pydantic.fields import _FieldInfoInputs
3434
from pydantic_core import PydanticUndefined as Undefined
3535
from pydantic_core import PydanticUndefinedType as UndefinedType
3636
from redis.commands.json.path import Path
3737
from redis.exceptions import ResponseError
38-
from typing_extensions import Protocol, get_args, get_origin
38+
from typing_extensions import Protocol, Unpack, get_args, get_origin
3939
from ulid import ULID
4040

4141
from .. import redis
@@ -280,6 +280,7 @@ def tree(self):
280280
class KNNExpression:
281281
k: int
282282
vector_field_name: str
283+
score_field_name: str
283284
reference_vector: bytes
284285

285286
def __str__(self):
@@ -291,7 +292,7 @@ def query_params(self) -> Dict[str, Union[str, bytes]]:
291292

292293
@property
293294
def score_field(self) -> str:
294-
return f"__{self.vector_field_name}_score"
295+
return self.score_field_name or f"_{self.vector_field_name}_score"
295296

296297

297298
ExpressionOrNegated = Union[Expression, NegatedExpression]
@@ -1176,10 +1177,10 @@ def Field(
11761177
index: Union[bool, UndefinedType] = Undefined,
11771178
full_text_search: Union[bool, UndefinedType] = Undefined,
11781179
vector_options: Optional[VectorFieldOptions] = None,
1179-
**kwargs: Unpack[_FieldInfoInputs],
1180+
**kwargs: Unpack[_FieldInfoInputs],
11801181
) -> Any:
11811182
field_info = FieldInfo(
1182-
**kwargs,
1183+
**kwargs,
11831184
primary_key=primary_key,
11841185
sortable=sortable,
11851186
case_sensitive=case_sensitive,
@@ -1196,6 +1197,10 @@ class PrimaryKey:
11961197
field: PydanticFieldInfo
11971198

11981199

1200+
class RedisOmConfig(ConfigDict):
1201+
index: bool | None
1202+
1203+
11991204
class BaseMeta(Protocol):
12001205
global_key_prefix: str
12011206
model_key_prefix: str
@@ -1230,9 +1235,30 @@ class DefaultMeta:
12301235
class ModelMeta(ModelMetaclass):
12311236
_meta: BaseMeta
12321237

1238+
model_config: RedisOmConfig
1239+
model_fields: Dict[str, FieldInfo] # type: ignore[assignment]
1240+
12331241
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
12341242
meta = attrs.pop("Meta", None)
1235-
new_class: RedisModel = super().__new__(cls, name, bases, attrs, **kwargs)
1243+
1244+
# Duplicate logic from Pydantic to filter config kwargs because if they are
1245+
# passed directly including the registry Pydantic will pass them over to the
1246+
# superclass causing an error
1247+
allowed_config_kwargs: Set[str] = {
1248+
key
1249+
for key in dir(ConfigDict)
1250+
if not (
1251+
key.startswith("__") and key.endswith("__")
1252+
) # skip dunder methods and attributes
1253+
}
1254+
1255+
config_kwargs = {
1256+
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
1257+
}
1258+
1259+
new_class: RedisModel = super().__new__(
1260+
cls, name, bases, attrs, **config_kwargs
1261+
)
12361262

12371263
# The fact that there is a Meta field and _meta field is important: a
12381264
# user may have given us a Meta object with their configuration, while
@@ -1241,13 +1267,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
12411267
meta = meta or getattr(new_class, "Meta", None)
12421268
base_meta = getattr(new_class, "_meta", None)
12431269

1244-
if len(bases) >= 1:
1245-
for base_index in range(len(bases)):
1246-
model_fields = getattr(bases[base_index], "model_fields", [])
1247-
for f_name in model_fields:
1248-
field = model_fields[f_name]
1249-
new_class.model_fields[f_name] = field
1250-
12511270
if meta and meta != DefaultMeta and meta != base_meta:
12521271
new_class.Meta = meta
12531272
new_class._meta = meta
@@ -1266,49 +1285,35 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
12661285
)
12671286
new_class.Meta = new_class._meta
12681287

1288+
if new_class.model_config.get("index", None) is True:
1289+
raise RedisModelError(
1290+
f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree"
1291+
)
1292+
12691293
# Create proxies for each model field so that we can use the field
12701294
# in queries, like Model.get(Model.field_name == 1)
1271-
for field_name, field in new_class.model_fields.items():
1272-
if not isinstance(field, FieldInfo):
1273-
for base_candidate in bases:
1274-
if hasattr(base_candidate, field_name):
1275-
inner_field = getattr(base_candidate, field_name)
1276-
if hasattr(inner_field, "field") and isinstance(
1277-
getattr(inner_field, "field"), FieldInfo
1278-
):
1279-
field.metadata.append(getattr(inner_field, "field"))
1280-
field = getattr(inner_field, "field")
1281-
1282-
if not field.alias:
1283-
field.alias = field_name
1284-
setattr(new_class, field_name, ExpressionProxy(field, []))
1285-
annotation = new_class.get_annotations().get(field_name)
1286-
if annotation:
1287-
new_class.__annotations__[field_name] = Union[
1288-
annotation, ExpressionProxy
1289-
]
1290-
else:
1291-
new_class.__annotations__[field_name] = ExpressionProxy
1292-
# Check if this is our FieldInfo version with extended ORM metadata.
1293-
field_info = None
1294-
if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo):
1295-
field_info = field.field_info
1296-
elif field_name in attrs and isinstance(
1297-
attrs.__getitem__(field_name), FieldInfo
1298-
):
1299-
field_info = attrs.__getitem__(field_name)
1300-
field.field_info = field_info
1301-
1302-
if field_info is not None:
1303-
if field_info.primary_key:
1295+
# Only set if the model is has index=True
1296+
if kwargs.get("index", None) == True:
1297+
new_class.model_config["index"] = True
1298+
for field_name, field in new_class.model_fields.items():
1299+
setattr(new_class, field_name, ExpressionProxy(field, []))
1300+
1301+
# We need to set alias equal the field name here to allow downstream processes to have access to it.
1302+
# Processes like the query builder use it.
1303+
if not field.alias:
1304+
field.alias = field_name
1305+
1306+
if getattr(field, "primary_key", None) is True:
13041307
new_class._meta.primary_key = PrimaryKey(
13051308
name=field_name, field=field
13061309
)
1307-
if field_info.vector_options:
1310+
if getattr(field, "vector_options", None) is not None:
13081311
score_attr = f"_{field_name}_score"
13091312
setattr(new_class, score_attr, None)
13101313
new_class.__annotations__[score_attr] = Union[float, None]
13111314

1315+
new_class.model_config["from_attributes"] = True
1316+
13121317
if not getattr(new_class._meta, "global_key_prefix", None):
13131318
new_class._meta.global_key_prefix = getattr(
13141319
base_meta, "global_key_prefix", ""
@@ -1339,9 +1344,13 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
13391344
f"{new_class._meta.model_key_prefix}:index"
13401345
)
13411346

1342-
# Not an abstract model class or embedded model, so we should let the
1347+
# Model is indexed and not an abstract model class or embedded model, so we should let the
13431348
# Migrator create indexes for it.
1344-
if abc.ABC not in bases and not getattr(new_class._meta, "embedded", False):
1349+
if (
1350+
abc.ABC not in bases
1351+
and not getattr(new_class._meta, "embedded", False)
1352+
and new_class.model_config.get("index") is True
1353+
):
13451354
key = f"{new_class.__module__}.{new_class.__qualname__}"
13461355
model_registry[key] = new_class
13471356

@@ -1366,28 +1375,28 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
13661375
Meta = DefaultMeta
13671376

13681377
model_config = ConfigDict(
1369-
from_attributes=True, arbitrary_types_allowed=True, extra="allow", validate_default=True
1378+
from_attributes=True,
1379+
arbitrary_types_allowed=True,
1380+
extra="allow",
1381+
validate_default=True,
13701382
)
13711383

13721384
def __init__(__pydantic_self__, **data: Any) -> None:
13731385
__pydantic_self__.validate_primary_key()
1374-
missing_fields = __pydantic_self__.model_fields.keys() - data.keys() - {"pk"}
1375-
1376-
kwargs = data.copy()
1377-
1378-
# This is a hack, we need to manually make sure we are setting up defaults correctly when we encounter them
1379-
# because inheritance apparently won't cover that in pydantic 2.0.
1380-
for field in missing_fields:
1381-
default_value = __pydantic_self__.model_fields.get(field).default # type: ignore
1382-
kwargs[field] = default_value
1383-
super().__init__(**kwargs)
1386+
super().__init__(**data)
13841387

13851388
def __lt__(self, other):
13861389
"""Default sort: compare primary key of models."""
13871390
return self.key() < other.key()
13881391

13891392
def key(self):
13901393
"""Return the Redis key for this model."""
1394+
if self.model_config.get("index", False) is not True:
1395+
raise RedisModelError(
1396+
"You cannot create a key on a model that is not indexed. "
1397+
f"Update your class with index=True: class {self.__class__.__name__}(RedisModel, index=True):"
1398+
)
1399+
13911400
if hasattr(self._meta.primary_key.field, "name"):
13921401
pk = getattr(self, self._meta.primary_key.field.name)
13931402
else:
@@ -1932,7 +1941,7 @@ def schema_for_type(
19321941
json_path: str,
19331942
name: str,
19341943
name_prefix: str,
1935-
typ: Any,
1944+
typ: type[RedisModel] | Any,
19361945
field_info: PydanticFieldInfo,
19371946
parent_type: Optional[Any] = None,
19381947
) -> str:
@@ -2010,7 +2019,6 @@ def schema_for_type(
20102019
parent_type=field_type,
20112020
)
20122021
elif field_is_model:
2013-
typ: type[RedisModel] = typ
20142022
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
20152023
sub_fields = []
20162024
for embedded_name, field in typ.model_fields.items():

tests/test_find_query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class Note(EmbeddedJsonModel):
5151
description: str = Field(index=True)
5252
created_on: datetime.datetime
5353

54-
class Address(EmbeddedJsonModel):
54+
class Address(EmbeddedJsonModel, index=True):
5555
address_line_1: str
5656
address_line_2: Optional[str] = None
5757
city: str = Field(index=True)
@@ -60,15 +60,15 @@ class Address(EmbeddedJsonModel):
6060
postal_code: str = Field(index=True)
6161
note: Optional[Note] = None
6262

63-
class Item(EmbeddedJsonModel):
63+
class Item(EmbeddedJsonModel, index=True):
6464
price: decimal.Decimal
6565
name: str = Field(index=True)
6666

67-
class Order(EmbeddedJsonModel):
67+
class Order(EmbeddedJsonModel, index=True):
6868
items: List[Item]
6969
created_on: datetime.datetime
7070

71-
class Member(BaseJsonModel):
71+
class Member(BaseJsonModel, index=True):
7272
first_name: str = Field(index=True, case_sensitive=True)
7373
last_name: str = Field(index=True)
7474
email: Optional[EmailStr] = Field(index=True, default=None)

0 commit comments

Comments
 (0)