Skip to content

Commit bf71f56

Browse files
committed
Validate embedded model fields.
1 parent 6f43abc commit bf71f56

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

django_mongodb/fields/embedded_model.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ def __init__(self, embedded_model, *args, **kwargs):
1515
stored. Like other relational fields, it may also be passed as a
1616
string.
1717
"""
18+
if not isinstance(embedded_model, str):
19+
self._validate_embedded_field(self, embedded_model)
20+
1821
self.embedded_model = embedded_model
1922
super().__init__(*args, **kwargs)
2023

@@ -28,6 +31,14 @@ def deconstruct(self):
2831
def get_internal_type(self):
2932
return "EmbeddedModelField"
3033

34+
@staticmethod
35+
def _validate_embedded_field(_, model):
36+
for field in model._meta.local_fields:
37+
if isinstance(field, models.ForeignKey | models.OneToOneField):
38+
raise TypeError(
39+
f"Field of type {type(field)!r} is not supported within an EmbeddedModelField."
40+
)
41+
3142
def _set_model(self, model):
3243
"""
3344
Resolve embedded model class once the field knows the model it belongs
@@ -48,6 +59,7 @@ def _resolve_lookup(_, resolved_model):
4859
self.embedded_model = resolved_model
4960

5061
lazy_related_operation(_resolve_lookup, model, self.embedded_model)
62+
lazy_related_operation(self._validate_embedded_field, model, self.embedded_model)
5163

5264
model = property(lambda self: self._model, _set_model)
5365

@@ -139,7 +151,7 @@ def formfield(self, **kwargs):
139151

140152

141153
class KeyTransform(Transform):
142-
def __init__(self, key_name, ref_field=None, *args, **kwargs):
154+
def __init__(self, key_name, ref_field, *args, **kwargs):
143155
super().__init__(*args, **kwargs)
144156
self.key_name = str(key_name)
145157
self.ref_field = ref_field
@@ -195,7 +207,7 @@ def key_transform(self, compiler, connection):
195207

196208

197209
class KeyTransformFactory:
198-
def __init__(self, key_name, ref_field=None):
210+
def __init__(self, key_name, ref_field):
199211
self.key_name = key_name
200212
self.ref_field = ref_field
201213

tests/model_fields_/test_embedded_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
F,
99
IntegerField,
1010
Max,
11+
Model,
1112
OuterRef,
1213
Subquery,
1314
Sum,
@@ -21,6 +22,7 @@
2122
Author,
2223
Book,
2324
DecimalKey,
25+
DecimalParent,
2426
EmbeddedModel,
2527
EmbeddedModelFieldModel,
2628
Library,
@@ -97,6 +99,16 @@ def test_embedded_field_with_foreign_conversion(self):
9799
# decimal_parent = DecimalParent.objects.create(child=decimal)
98100
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal)
99101

102+
def test_foreign_key_in_embedded_object(self):
103+
msg = (
104+
"Field of type <class 'django.db.models.fields.related.ForeignKey'> "
105+
"is not supported within an EmbeddedModelField."
106+
)
107+
with self.assertRaisesMessage(TypeError, msg):
108+
109+
class EmbeddedModelTest(Model):
110+
decimal = EmbeddedModelField(DecimalParent, null=True, blank=True)
111+
100112

101113
class QueryingTests(TestCase):
102114
@classmethod

0 commit comments

Comments
 (0)