7
7
from copy import copy
8
8
from enum import Enum
9
9
from functools import reduce
10
- from typing_extensions import Unpack
11
10
from typing import (
12
11
Any ,
13
12
Callable ,
30
29
from pydantic import BaseModel , ConfigDict , TypeAdapter , field_validator
31
30
from pydantic ._internal ._model_construction import ModelMetaclass
32
31
from pydantic ._internal ._repr import Representation
33
- from pydantic .fields import FieldInfo as PydanticFieldInfo , _FieldInfoInputs
32
+ from pydantic .fields import FieldInfo as PydanticFieldInfo
33
+ from pydantic .fields import _FieldInfoInputs
34
34
from pydantic_core import PydanticUndefined as Undefined
35
35
from pydantic_core import PydanticUndefinedType as UndefinedType
36
36
from redis .commands .json .path import Path
37
37
from redis .exceptions import ResponseError
38
- from typing_extensions import Protocol , get_args , get_origin
38
+ from typing_extensions import Protocol , Unpack , get_args , get_origin
39
39
from ulid import ULID
40
40
41
41
from .. import redis
@@ -280,6 +280,7 @@ def tree(self):
280
280
class KNNExpression :
281
281
k : int
282
282
vector_field_name : str
283
+ score_field_name : str
283
284
reference_vector : bytes
284
285
285
286
def __str__ (self ):
@@ -291,7 +292,7 @@ def query_params(self) -> Dict[str, Union[str, bytes]]:
291
292
292
293
@property
293
294
def score_field (self ) -> str :
294
- return f"__ { self .vector_field_name } _score"
295
+ return self . score_field_name or f"_ { self .vector_field_name } _score"
295
296
296
297
297
298
ExpressionOrNegated = Union [Expression , NegatedExpression ]
@@ -1176,10 +1177,10 @@ def Field(
1176
1177
index : Union [bool , UndefinedType ] = Undefined ,
1177
1178
full_text_search : Union [bool , UndefinedType ] = Undefined ,
1178
1179
vector_options : Optional [VectorFieldOptions ] = None ,
1179
- ** kwargs : Unpack [_FieldInfoInputs ],
1180
+ ** kwargs : Unpack [_FieldInfoInputs ],
1180
1181
) -> Any :
1181
1182
field_info = FieldInfo (
1182
- ** kwargs ,
1183
+ ** kwargs ,
1183
1184
primary_key = primary_key ,
1184
1185
sortable = sortable ,
1185
1186
case_sensitive = case_sensitive ,
@@ -1196,6 +1197,10 @@ class PrimaryKey:
1196
1197
field : PydanticFieldInfo
1197
1198
1198
1199
1200
+ class RedisOmConfig (ConfigDict ):
1201
+ index : bool | None
1202
+
1203
+
1199
1204
class BaseMeta (Protocol ):
1200
1205
global_key_prefix : str
1201
1206
model_key_prefix : str
@@ -1230,9 +1235,30 @@ class DefaultMeta:
1230
1235
class ModelMeta (ModelMetaclass ):
1231
1236
_meta : BaseMeta
1232
1237
1238
+ model_config : RedisOmConfig
1239
+ model_fields : Dict [str , FieldInfo ] # type: ignore[assignment]
1240
+
1233
1241
def __new__ (cls , name , bases , attrs , ** kwargs ): # noqa C901
1234
1242
meta = attrs .pop ("Meta" , None )
1235
- new_class : RedisModel = super ().__new__ (cls , name , bases , attrs , ** kwargs )
1243
+
1244
+ # Duplicate logic from Pydantic to filter config kwargs because if they are
1245
+ # passed directly including the registry Pydantic will pass them over to the
1246
+ # superclass causing an error
1247
+ allowed_config_kwargs : Set [str ] = {
1248
+ key
1249
+ for key in dir (ConfigDict )
1250
+ if not (
1251
+ key .startswith ("__" ) and key .endswith ("__" )
1252
+ ) # skip dunder methods and attributes
1253
+ }
1254
+
1255
+ config_kwargs = {
1256
+ key : kwargs [key ] for key in kwargs .keys () & allowed_config_kwargs
1257
+ }
1258
+
1259
+ new_class : RedisModel = super ().__new__ (
1260
+ cls , name , bases , attrs , ** config_kwargs
1261
+ )
1236
1262
1237
1263
# The fact that there is a Meta field and _meta field is important: a
1238
1264
# user may have given us a Meta object with their configuration, while
@@ -1241,13 +1267,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
1241
1267
meta = meta or getattr (new_class , "Meta" , None )
1242
1268
base_meta = getattr (new_class , "_meta" , None )
1243
1269
1244
- if len (bases ) >= 1 :
1245
- for base_index in range (len (bases )):
1246
- model_fields = getattr (bases [base_index ], "model_fields" , [])
1247
- for f_name in model_fields :
1248
- field = model_fields [f_name ]
1249
- new_class .model_fields [f_name ] = field
1250
-
1251
1270
if meta and meta != DefaultMeta and meta != base_meta :
1252
1271
new_class .Meta = meta
1253
1272
new_class ._meta = meta
@@ -1266,49 +1285,35 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
1266
1285
)
1267
1286
new_class .Meta = new_class ._meta
1268
1287
1288
+ if new_class .model_config .get ("index" , None ) is True :
1289
+ raise RedisModelError (
1290
+ f"{ new_class .__name__ } cannot be indexed, only one model can be indexed in an inheritance tree"
1291
+ )
1292
+
1269
1293
# Create proxies for each model field so that we can use the field
1270
1294
# in queries, like Model.get(Model.field_name == 1)
1271
- for field_name , field in new_class .model_fields .items ():
1272
- if not isinstance (field , FieldInfo ):
1273
- for base_candidate in bases :
1274
- if hasattr (base_candidate , field_name ):
1275
- inner_field = getattr (base_candidate , field_name )
1276
- if hasattr (inner_field , "field" ) and isinstance (
1277
- getattr (inner_field , "field" ), FieldInfo
1278
- ):
1279
- field .metadata .append (getattr (inner_field , "field" ))
1280
- field = getattr (inner_field , "field" )
1281
-
1282
- if not field .alias :
1283
- field .alias = field_name
1284
- setattr (new_class , field_name , ExpressionProxy (field , []))
1285
- annotation = new_class .get_annotations ().get (field_name )
1286
- if annotation :
1287
- new_class .__annotations__ [field_name ] = Union [
1288
- annotation , ExpressionProxy
1289
- ]
1290
- else :
1291
- new_class .__annotations__ [field_name ] = ExpressionProxy
1292
- # Check if this is our FieldInfo version with extended ORM metadata.
1293
- field_info = None
1294
- if hasattr (field , "field_info" ) and isinstance (field .field_info , FieldInfo ):
1295
- field_info = field .field_info
1296
- elif field_name in attrs and isinstance (
1297
- attrs .__getitem__ (field_name ), FieldInfo
1298
- ):
1299
- field_info = attrs .__getitem__ (field_name )
1300
- field .field_info = field_info
1301
-
1302
- if field_info is not None :
1303
- if field_info .primary_key :
1295
+ # Only set if the model is has index=True
1296
+ if kwargs .get ("index" , None ) == True :
1297
+ new_class .model_config ["index" ] = True
1298
+ for field_name , field in new_class .model_fields .items ():
1299
+ setattr (new_class , field_name , ExpressionProxy (field , []))
1300
+
1301
+ # We need to set alias equal the field name here to allow downstream processes to have access to it.
1302
+ # Processes like the query builder use it.
1303
+ if not field .alias :
1304
+ field .alias = field_name
1305
+
1306
+ if getattr (field , "primary_key" , None ) is True :
1304
1307
new_class ._meta .primary_key = PrimaryKey (
1305
1308
name = field_name , field = field
1306
1309
)
1307
- if field_info . vector_options :
1310
+ if getattr ( field , " vector_options" , None ) is not None :
1308
1311
score_attr = f"_{ field_name } _score"
1309
1312
setattr (new_class , score_attr , None )
1310
1313
new_class .__annotations__ [score_attr ] = Union [float , None ]
1311
1314
1315
+ new_class .model_config ["from_attributes" ] = True
1316
+
1312
1317
if not getattr (new_class ._meta , "global_key_prefix" , None ):
1313
1318
new_class ._meta .global_key_prefix = getattr (
1314
1319
base_meta , "global_key_prefix" , ""
@@ -1339,9 +1344,13 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
1339
1344
f"{ new_class ._meta .model_key_prefix } :index"
1340
1345
)
1341
1346
1342
- # Not an abstract model class or embedded model, so we should let the
1347
+ # Model is indexed and not an abstract model class or embedded model, so we should let the
1343
1348
# Migrator create indexes for it.
1344
- if abc .ABC not in bases and not getattr (new_class ._meta , "embedded" , False ):
1349
+ if (
1350
+ abc .ABC not in bases
1351
+ and not getattr (new_class ._meta , "embedded" , False )
1352
+ and new_class .model_config .get ("index" ) is True
1353
+ ):
1345
1354
key = f"{ new_class .__module__ } .{ new_class .__qualname__ } "
1346
1355
model_registry [key ] = new_class
1347
1356
@@ -1366,28 +1375,28 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
1366
1375
Meta = DefaultMeta
1367
1376
1368
1377
model_config = ConfigDict (
1369
- from_attributes = True , arbitrary_types_allowed = True , extra = "allow" , validate_default = True
1378
+ from_attributes = True ,
1379
+ arbitrary_types_allowed = True ,
1380
+ extra = "allow" ,
1381
+ validate_default = True ,
1370
1382
)
1371
1383
1372
1384
def __init__ (__pydantic_self__ , ** data : Any ) -> None :
1373
1385
__pydantic_self__ .validate_primary_key ()
1374
- missing_fields = __pydantic_self__ .model_fields .keys () - data .keys () - {"pk" }
1375
-
1376
- kwargs = data .copy ()
1377
-
1378
- # This is a hack, we need to manually make sure we are setting up defaults correctly when we encounter them
1379
- # because inheritance apparently won't cover that in pydantic 2.0.
1380
- for field in missing_fields :
1381
- default_value = __pydantic_self__ .model_fields .get (field ).default # type: ignore
1382
- kwargs [field ] = default_value
1383
- super ().__init__ (** kwargs )
1386
+ super ().__init__ (** data )
1384
1387
1385
1388
def __lt__ (self , other ):
1386
1389
"""Default sort: compare primary key of models."""
1387
1390
return self .key () < other .key ()
1388
1391
1389
1392
def key (self ):
1390
1393
"""Return the Redis key for this model."""
1394
+ if self .model_config .get ("index" , False ) is not True :
1395
+ raise RedisModelError (
1396
+ "You cannot create a key on a model that is not indexed. "
1397
+ f"Update your class with index=True: class { self .__class__ .__name__ } (RedisModel, index=True):"
1398
+ )
1399
+
1391
1400
if hasattr (self ._meta .primary_key .field , "name" ):
1392
1401
pk = getattr (self , self ._meta .primary_key .field .name )
1393
1402
else :
@@ -1932,7 +1941,7 @@ def schema_for_type(
1932
1941
json_path : str ,
1933
1942
name : str ,
1934
1943
name_prefix : str ,
1935
- typ : Any ,
1944
+ typ : type [ RedisModel ] | Any ,
1936
1945
field_info : PydanticFieldInfo ,
1937
1946
parent_type : Optional [Any ] = None ,
1938
1947
) -> str :
@@ -2010,7 +2019,6 @@ def schema_for_type(
2010
2019
parent_type = field_type ,
2011
2020
)
2012
2021
elif field_is_model :
2013
- typ : type [RedisModel ] = typ
2014
2022
name_prefix = f"{ name_prefix } _{ name } " if name_prefix else name
2015
2023
sub_fields = []
2016
2024
for embedded_name , field in typ .model_fields .items ():
0 commit comments