Skip to content

Commit 260cd09

Browse files
committed
INTPYTHON-624 Add PolymorphicEmbeddedModelField
1 parent b8efc93 commit 260cd09

File tree

6 files changed

+455
-1
lines changed

6 files changed

+455
-1
lines changed

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ def execute_sql(self, result_type):
743743
elif hasattr(value, "prepare_database_save"):
744744
if field.remote_field:
745745
value = value.prepare_database_save(field)
746-
elif not hasattr(field, "embedded_model"):
746+
elif not (hasattr(field, "embedded_model") or hasattr(field, "embedded_models")):
747747
raise TypeError(
748748
f"Tried to update field {field} with a model "
749749
f"instance, {value!r}. Use a value compatible with "

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .embedded_model_array import EmbeddedModelArrayField
66
from .json import register_json_field
77
from .objectid import ObjectIdField
8+
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
89

910
__all__ = [
1011
"register_fields",
@@ -13,6 +14,7 @@
1314
"EmbeddedModelField",
1415
"ObjectIdAutoField",
1516
"ObjectIdField",
17+
"PolymorphicEmbeddedModelField",
1618
]
1719

1820

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from django.core import checks
2+
from django.db import models
3+
4+
# from django.db.models.fields.related import lazy_related_operation
5+
# from django.db.models.lookups import Transform
6+
7+
8+
class PolymorphicEmbeddedModelField(models.Field):
9+
"""Field that stores a model instance."""
10+
11+
def __init__(self, embedded_models, *args, **kwargs):
12+
"""
13+
`embedded_models` is a list of possible model classes to be stored.
14+
Like other relational fields, each model may also be passed as a
15+
string.
16+
"""
17+
self.embedded_models = embedded_models
18+
super().__init__(*args, **kwargs)
19+
20+
def db_type(self, connection):
21+
return "embeddedDocuments"
22+
23+
def check(self, **kwargs):
24+
from ..models import EmbeddedModel
25+
26+
errors = super().check(**kwargs)
27+
for model in self.embedded_models:
28+
if not issubclass(model, EmbeddedModel):
29+
return [
30+
checks.Error(
31+
"Embedded models must be a subclass of "
32+
"django_mongodb_backend.models.EmbeddedModel.",
33+
obj=self,
34+
hint="{model} doesn't subclass EmbeddedModel.",
35+
id="django_mongodb_backend.embedded_model.E002",
36+
)
37+
]
38+
for field in model._meta.fields:
39+
if field.remote_field:
40+
errors.append(
41+
checks.Error(
42+
"Embedded models cannot have relational fields "
43+
f"({model().__class__.__name__}.{field.name} "
44+
f"is a {field.__class__.__name__}).",
45+
obj=self,
46+
id="django_mongodb_backend.embedded_model.E001",
47+
)
48+
)
49+
return errors
50+
51+
def deconstruct(self):
52+
name, path, args, kwargs = super().deconstruct()
53+
if path.startswith("django_mongodb_backend.fields.polymorphic_embedded_model"):
54+
path = path.replace(
55+
"django_mongodb_backend.fields.polymorphic_embedded_model",
56+
"django_mongodb_backend.fields",
57+
)
58+
kwargs["embedded_models"] = self.embedded_models
59+
return name, path, args, kwargs
60+
61+
def get_internal_type(self):
62+
return "PolymorphicEmbeddedModelField"
63+
64+
# def _set_model(self, model):
65+
# """
66+
# Resolve embedded model class once the field knows the model it belongs
67+
# to. If __init__()'s embedded_model argument is a string, resolve it to
68+
# the actual model class, similar to relation fields.
69+
# """
70+
# self._model = model
71+
# if model is not None and isinstance(self.embedded_model, str):
72+
73+
# def _resolve_lookup(_, resolved_model):
74+
# self.embedded_model = resolved_model
75+
76+
# lazy_related_operation(_resolve_lookup, model, self.embedded_model)
77+
78+
# model = property(lambda self: self._model, _set_model)
79+
80+
def from_db_value(self, value, expression, connection):
81+
return self.to_python(value)
82+
83+
def to_python(self, value):
84+
"""
85+
Pass embedded model fields' values through each field's to_python() and
86+
reinstantiate the embedded instance.
87+
"""
88+
if value is None:
89+
return None
90+
if not isinstance(value, dict):
91+
return value
92+
model_class = self._get_model_from_label(value.pop("_label"))
93+
instance = model_class(
94+
**{
95+
field.attname: field.to_python(value[field.attname])
96+
for field in model_class._meta.fields
97+
if field.attname in value
98+
}
99+
)
100+
instance._state.adding = False
101+
return instance
102+
103+
def get_db_prep_save(self, embedded_instance, connection):
104+
"""
105+
Apply pre_save() and get_db_prep_save() of embedded instance fields and
106+
create the {field: value} dict to be saved.
107+
"""
108+
if embedded_instance is None:
109+
return None
110+
if not isinstance(embedded_instance, self.embedded_models):
111+
raise TypeError(
112+
f"Expected instance of type {self.embedded_models!r}, not "
113+
f"{type(embedded_instance)!r}."
114+
)
115+
field_values = {}
116+
add = embedded_instance._state.adding
117+
for field in embedded_instance._meta.fields:
118+
value = field.get_db_prep_save(
119+
field.pre_save(embedded_instance, add), connection=connection
120+
)
121+
# Exclude unset primary keys (e.g. {'id': None}).
122+
if field.primary_key and value is None:
123+
continue
124+
field_values[field.attname] = value
125+
field_values["_label"] = embedded_instance._meta.label
126+
# This instance will exist in the database soon.
127+
embedded_instance._state.adding = False
128+
return field_values
129+
130+
# def get_transform(self, name):
131+
# transform = super().get_transform(name)
132+
# if transform:
133+
# return transform
134+
# field = self.embedded_model._meta.get_field(name)
135+
# return KeyTransformFactory(name, field)
136+
137+
# def validate(self, value, model_instance):
138+
# super().validate(value, model_instance)
139+
# if self.embedded_model is None:
140+
# return
141+
# for field in self.embedded_model._meta.fields:
142+
# attname = field.attname
143+
# field.validate(getattr(value, attname), model_instance)
144+
145+
def formfield(self, **kwargs):
146+
raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.")
147+
148+
def _get_model_from_label(self, label):
149+
return {model._meta.label: model for model in self.embedded_models}[label]
150+
151+
152+
# class KeyTransform(Transform):
153+
# def __init__(self, key_name, ref_field, *args, **kwargs):
154+
# super().__init__(*args, **kwargs)
155+
# self.key_name = str(key_name)
156+
# self.ref_field = ref_field
157+
158+
# def get_lookup(self, name):
159+
# return self.ref_field.get_lookup(name)
160+
161+
# def get_transform(self, name):
162+
# """
163+
# Validate that `name` is either a field of an embedded model or a
164+
# lookup on an embedded model's field.
165+
# """
166+
# if transform := self.ref_field.get_transform(name):
167+
# return transform
168+
# suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
169+
# if suggested_lookups:
170+
# suggested_lookups = " or ".join(suggested_lookups)
171+
# suggestion = f", perhaps you meant {suggested_lookups}?"
172+
# else:
173+
# suggestion = "."
174+
# raise FieldDoesNotExist(
175+
# f"Unsupported lookup '{name}' for "
176+
# f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
177+
# f"{suggestion}"
178+
# )
179+
180+
# def as_mql(self, compiler, connection):
181+
# previous = self
182+
# key_transforms = []
183+
# while isinstance(previous, KeyTransform):
184+
# key_transforms.insert(0, previous.key_name)
185+
# previous = previous.lhs
186+
# mql = previous.as_mql(compiler, connection)
187+
# for key in key_transforms:
188+
# mql = {"$getField": {"input": mql, "field": key}}
189+
# return mql
190+
191+
# @property
192+
# def output_field(self):
193+
# return self.ref_field
194+
195+
196+
# class KeyTransformFactory:
197+
# def __init__(self, key_name, ref_field):
198+
# self.key_name = key_name
199+
# self.ref_field = ref_field
200+
201+
# def __call__(self, *args, **kwargs):
202+
# return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

django_mongodb_backend/operations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def get_db_converters(self, expression):
122122
)
123123
elif internal_type == "JSONField":
124124
converters.append(self.convert_jsonfield_value)
125+
elif internal_type == "PolymorphicEmbeddedModelField":
126+
converters.append(self.convert_polymorphicembeddedmodelfield_value)
125127
elif internal_type == "TimeField":
126128
# Trunc(... output_field="TimeField") values must remain datetime
127129
# until Trunc.convert_value() so they can be converted from UTC
@@ -182,6 +184,19 @@ def convert_jsonfield_value(self, value, expression, connection):
182184
"""
183185
return json.dumps(value)
184186

187+
def convert_polymorphicembeddedmodelfield_value(self, value, expression, connection):
188+
if value is not None:
189+
model_class = expression.output_field._get_model_from_label(value["_label"])
190+
# Apply database converters to each field of the embedded model.
191+
for field in model_class._meta.fields:
192+
field_expr = Expression(output_field=field)
193+
converters = connection.ops.get_db_converters(
194+
field_expr
195+
) + field_expr.get_db_converters(connection)
196+
for converter in converters:
197+
value[field.attname] = converter(value[field.attname], field_expr, connection)
198+
return value
199+
185200
def convert_timefield_value(self, value, expression, connection):
186201
if value is not None:
187202
value = value.time()

tests/model_fields_/models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
EmbeddedModelArrayField,
88
EmbeddedModelField,
99
ObjectIdField,
10+
PolymorphicEmbeddedModelField,
1011
)
1112
from django_mongodb_backend.models import EmbeddedModel
1213

@@ -222,3 +223,30 @@ class Tour(models.Model):
222223

223224
def __str__(self):
224225
return f"Tour by {self.guide}"
226+
227+
228+
# PolymorphicEmbeddedModelField
229+
class Dog(EmbeddedModel):
230+
name = models.CharField(max_length=100)
231+
barks = models.BooleanField(default=True)
232+
data = models.JSONField(default=dict)
233+
234+
def __str__(self):
235+
return self.name
236+
237+
238+
class Cat(EmbeddedModel):
239+
name = models.CharField(max_length=100)
240+
purs = models.BooleanField(default=True)
241+
weight = models.DecimalField(max_digits=4, decimal_places=2, blank=True, null=True)
242+
243+
def __str__(self):
244+
return self.name
245+
246+
247+
class Person(models.Model):
248+
name = models.CharField(max_length=100)
249+
pet = PolymorphicEmbeddedModelField((Dog, Cat), blank=True, null=True)
250+
251+
def __str__(self):
252+
return self.name

0 commit comments

Comments
 (0)