Skip to content

Commit f77e03c

Browse files
committed
add EmbeddedModelField
1 parent eae01a4 commit f77e03c

23 files changed

+1492
-42
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ repos:
4444
hooks:
4545
- id: rstcheck
4646
additional_dependencies: [sphinx]
47-
args: ["--ignore-directives=fieldlookup,setting", "--ignore-roles=lookup,setting"]
47+
args: ["--ignore-directives=django-admin,fieldlookup,setting", "--ignore-roles=djadmin,lookup,setting"]
4848

4949
# We use the Python version instead of the original version which seems to require Docker
5050
# https://github.com/koalaman/shellcheck-precommit

THIRD-PARTY-NOTICES

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ be distributed under licenses different than this software.
33

44
The attached notices are provided for information only.
55

6-
django-mongodb-backend began by borrowing code from Django non-rel's
6+
django-mongodb-backend and EmbeddedModelField began by borrowing code from
77
django-mongodb-engine (https://github.com/django-nonrel/mongodb-engine),
88
abandoned since 2015 and Django 1.6.
99

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def execute_sql(self, result_type):
741741
elif hasattr(value, "prepare_database_save"):
742742
if field.remote_field:
743743
value = value.prepare_database_save(field)
744-
else:
744+
elif not hasattr(field, "embedded_model"):
745745
raise TypeError(
746746
f"Tried to update field {field} with a model "
747747
f"instance, {value!r}. Use a value compatible with "

django_mongodb_backend/fields/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from .array import ArrayField
22
from .auto import ObjectIdAutoField
33
from .duration import register_duration_field
4+
from .embedded_model import EmbeddedModelField
45
from .json import register_json_field
56
from .objectid import ObjectIdField
67

7-
__all__ = ["register_fields", "ArrayField", "ObjectIdAutoField", "ObjectIdField"]
8+
__all__ = [
9+
"register_fields",
10+
"ArrayField",
11+
"EmbeddedModelField",
12+
"ObjectIdAutoField",
13+
"ObjectIdField",
14+
]
815

916

1017
def register_fields():
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from django.core import checks
2+
from django.db import models
3+
from django.db.models.fields.related import lazy_related_operation
4+
from django.db.models.lookups import Transform
5+
6+
from .. import forms
7+
8+
9+
class EmbeddedModelField(models.Field):
10+
"""Field that stores a model instance."""
11+
12+
def __init__(self, embedded_model, *args, **kwargs):
13+
"""
14+
`embedded_model` is the model class of the instance to be stored.
15+
Like other relational fields, it may also be passed as a string.
16+
"""
17+
self.embedded_model = embedded_model
18+
super().__init__(*args, **kwargs)
19+
20+
def check(self, **kwargs):
21+
errors = super().check(**kwargs)
22+
for field in self.embedded_model._meta.fields:
23+
if field.remote_field:
24+
errors.append(
25+
checks.Error(
26+
"Embedded models cannot have relational fields "
27+
f"({self.embedded_model().__class__.__name__}.{field.name} "
28+
f"is a {field.__class__.__name__}).",
29+
obj=self,
30+
id="django_mongodb.embedded_model.E001",
31+
)
32+
)
33+
return errors
34+
35+
def deconstruct(self):
36+
name, path, args, kwargs = super().deconstruct()
37+
if path.startswith("django_mongodb_backend.fields.embedded_model"):
38+
path = path.replace(
39+
"django_mongodb_backend.fields.embedded_model", "django_mongodb_backend.fields"
40+
)
41+
kwargs["embedded_model"] = self.embedded_model
42+
return name, path, args, kwargs
43+
44+
def get_internal_type(self):
45+
return "EmbeddedModelField"
46+
47+
def _set_model(self, model):
48+
"""
49+
Resolve embedded model class once the field knows the model it belongs
50+
to. If __init__()'s embedded_model argument is a string, resolve it to
51+
the actual model class, similar to relation fields.
52+
"""
53+
self._model = model
54+
if model is not None and isinstance(self.embedded_model, str):
55+
56+
def _resolve_lookup(_, resolved_model):
57+
self.embedded_model = resolved_model
58+
59+
lazy_related_operation(_resolve_lookup, model, self.embedded_model)
60+
61+
model = property(lambda self: self._model, _set_model)
62+
63+
def from_db_value(self, value, expression, connection):
64+
return self.to_python(value)
65+
66+
def to_python(self, value):
67+
"""
68+
Pass embedded model fields' values through each field's to_python() and
69+
reinstantiate the embedded instance.
70+
"""
71+
if value is None:
72+
return None
73+
if not isinstance(value, dict):
74+
return value
75+
instance = self.embedded_model(
76+
**{
77+
field.attname: field.to_python(value[field.attname])
78+
for field in self.embedded_model._meta.fields
79+
if field.attname in value
80+
}
81+
)
82+
instance._state.adding = False
83+
return instance
84+
85+
def get_db_prep_save(self, embedded_instance, connection):
86+
"""
87+
Apply pre_save() and get_db_prep_save() of embedded instance fields and
88+
create the {field: value} dict to be saved.
89+
"""
90+
if embedded_instance is None:
91+
return None
92+
if not isinstance(embedded_instance, self.embedded_model):
93+
raise TypeError(
94+
f"Expected instance of type {self.embedded_model!r}, not "
95+
f"{type(embedded_instance)!r}."
96+
)
97+
field_values = {}
98+
add = embedded_instance._state.adding
99+
for field in embedded_instance._meta.fields:
100+
value = field.get_db_prep_save(
101+
field.pre_save(embedded_instance, add), connection=connection
102+
)
103+
# Exclude unset primary keys (e.g. {'id': None}).
104+
if field.primary_key and value is None:
105+
continue
106+
field_values[field.attname] = value
107+
# This instance will exist in the database soon.
108+
embedded_instance._state.adding = False
109+
return field_values
110+
111+
def get_transform(self, name):
112+
transform = super().get_transform(name)
113+
if transform:
114+
return transform
115+
return KeyTransformFactory(name)
116+
117+
def validate(self, value, model_instance):
118+
super().validate(value, model_instance)
119+
if self.embedded_model is None:
120+
return
121+
for field in self.embedded_model._meta.fields:
122+
attname = field.attname
123+
field.validate(getattr(value, attname), model_instance)
124+
125+
def formfield(self, **kwargs):
126+
return super().formfield(
127+
**{
128+
"form_class": forms.EmbeddedModelField,
129+
"model": self.embedded_model,
130+
"prefix": self.name,
131+
**kwargs,
132+
}
133+
)
134+
135+
136+
class KeyTransform(Transform):
137+
def __init__(self, key_name, *args, **kwargs):
138+
super().__init__(*args, **kwargs)
139+
self.key_name = str(key_name)
140+
141+
def preprocess_lhs(self, compiler, connection):
142+
key_transforms = [self.key_name]
143+
previous = self.lhs
144+
while isinstance(previous, KeyTransform):
145+
key_transforms.insert(0, previous.key_name)
146+
previous = previous.lhs
147+
mql = previous.as_mql(compiler, connection)
148+
return mql, key_transforms
149+
150+
def as_mql(self, compiler, connection):
151+
mql, key_transforms = self.preprocess_lhs(compiler, connection)
152+
transforms = ".".join(key_transforms)
153+
return f"{mql}.{transforms}"
154+
155+
156+
class KeyTransformFactory:
157+
def __init__(self, key_name):
158+
self.key_name = key_name
159+
160+
def __call__(self, *args, **kwargs):
161+
return KeyTransform(self.key_name, *args, **kwargs)

django_mongodb_backend/forms/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget
1+
from .fields import (
2+
EmbeddedModelField,
3+
ObjectIdField,
4+
SimpleArrayField,
5+
SplitArrayField,
6+
SplitArrayWidget,
7+
)
28

39
__all__ = [
10+
"EmbeddedModelField",
411
"SimpleArrayField",
512
"SplitArrayField",
613
"SplitArrayWidget",

django_mongodb_backend/forms/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .array import SimpleArrayField, SplitArrayField, SplitArrayWidget
2+
from .embedded_model import EmbeddedModelField
23
from .objectid import ObjectIdField
34

45
__all__ = [
6+
"EmbeddedModelField",
57
"SimpleArrayField",
68
"SplitArrayField",
79
"SplitArrayWidget",
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from django import forms
2+
from django.forms.models import modelform_factory
3+
from django.utils.safestring import mark_safe
4+
from django.utils.translation import gettext_lazy as _
5+
6+
7+
class EmbeddedModelWidget(forms.MultiWidget):
8+
def __init__(self, field_names, *args, **kwargs):
9+
self.field_names = field_names
10+
super().__init__(*args, **kwargs)
11+
# The default widget names are "_0", "_1", etc. Use the field names
12+
# instead since that's how they'll be rendered by the model form.
13+
self.widgets_names = ["-" + name for name in field_names]
14+
15+
def decompress(self, value):
16+
if value is None:
17+
return []
18+
# Get the data from `value` (a model) for each field.
19+
return [getattr(value, name) for name in self.field_names]
20+
21+
22+
class EmbeddedModelBoundField(forms.BoundField):
23+
def __str__(self):
24+
"""Render the model form as the representation for this field."""
25+
form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs)
26+
return mark_safe(f"{form.as_div()}") # noqa: S308
27+
28+
29+
class EmbeddedModelField(forms.MultiValueField):
30+
default_error_messages = {
31+
"invalid": _("Enter a list of values."),
32+
"incomplete": _("Enter all required values."),
33+
}
34+
35+
def __init__(self, model, prefix, *args, **kwargs):
36+
form_kwargs = {}
37+
# To avoid collisions with other fields on the form, each subfield must
38+
# be prefixed with the name of the field.
39+
form_kwargs["prefix"] = prefix
40+
self.form_kwargs = form_kwargs
41+
self.model_form_cls = modelform_factory(model, fields="__all__")
42+
self.model_form = self.model_form_cls(**form_kwargs)
43+
self.field_names = list(self.model_form.fields.keys())
44+
fields = self.model_form.fields.values()
45+
widgets = [field.widget for field in fields]
46+
widget = EmbeddedModelWidget(self.field_names, widgets)
47+
super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs)
48+
49+
def compress(self, data_dict):
50+
if not data_dict:
51+
return None
52+
values = dict(zip(self.field_names, data_dict, strict=False))
53+
return self.model_form._meta.model(**values)
54+
55+
def get_bound_field(self, form, field_name):
56+
return EmbeddedModelBoundField(form, self, field_name)
57+
58+
def bound_data(self, data, initial):
59+
if self.disabled:
60+
return initial
61+
# Transform the bound data into a model instance.
62+
return self.compress(data)

0 commit comments

Comments
 (0)