Skip to content

Commit e066095

Browse files
committed
schema checks and unit tests.
1 parent dd65bb7 commit e066095

File tree

4 files changed

+107
-26
lines changed

4 files changed

+107
-26
lines changed

django_mongodb/compiler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20+
from django_mongodb.fields import EmbeddedModelField
21+
2022
from .base import Cursor
2123
from .query import MongoQuery, wrap_database_errors
2224

@@ -547,7 +549,11 @@ def get_combinator_queries(self):
547549
def get_lookup_pipeline(self):
548550
result = []
549551
for alias in tuple(self.query.alias_map):
550-
if not self.query.alias_refcount[alias] or self.collection_name == alias:
552+
if (
553+
not self.query.alias_refcount[alias]
554+
or self.collection_name == alias
555+
or isinstance(self.query.alias_map[alias].join_field, EmbeddedModelField)
556+
):
551557
continue
552558
result += self.query.alias_map[alias].as_mql(self, self.connection)
553559
return result

django_mongodb/fields/embedded_model.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.core.exceptions import FieldDoesNotExist
12
from django.db import models
23
from django.db.models.fields.related import lazy_related_operation
34
from django.db.models.lookups import Transform
@@ -115,7 +116,8 @@ def get_transform(self, name):
115116
transform = super().get_transform(name)
116117
if transform:
117118
return transform
118-
return KeyTransformFactory(name)
119+
field = self.embedded_model._meta.get_field(name)
120+
return KeyTransformFactory(name, field)
119121

120122
def validate(self, value, model_instance):
121123
super().validate(value, model_instance)
@@ -137,32 +139,67 @@ def formfield(self, **kwargs):
137139

138140

139141
class KeyTransform(Transform):
140-
def __init__(self, key_name, *args, **kwargs):
142+
def __init__(self, key_name, ref_field=None, *args, **kwargs):
141143
super().__init__(*args, **kwargs)
142144
self.key_name = str(key_name)
145+
self.ref_field = ref_field
146+
147+
def get_transform(self, name):
148+
result = None
149+
if isinstance(self.ref_field, EmbeddedModelField):
150+
opts = self.ref_field.embedded_model._meta
151+
new_field = opts.get_field(name)
152+
result = KeyTransformFactory(name, new_field)
153+
else:
154+
if self.ref_field is not None and self.ref_field.get_transform(name) is None:
155+
raise FieldDoesNotExist(
156+
f"{self.ref_field.model._meta.object_name} has no field named '{name}'"
157+
)
158+
result = KeyTransformFactory(name, None)
159+
return result
143160

144161
def preprocess_lhs(self, compiler, connection):
145-
key_transforms = [self.key_name]
146-
previous = self.lhs
162+
previous = self
163+
embedded_key_transforms = []
164+
json_key_transforms = []
147165
while isinstance(previous, KeyTransform):
148-
key_transforms.insert(0, previous.key_name)
166+
if previous.ref_field is not None:
167+
embedded_key_transforms.insert(0, previous.key_name)
168+
else:
169+
json_key_transforms.insert(0, previous.key_name)
149170
previous = previous.lhs
150171
mql = previous.as_mql(compiler, connection)
151-
return mql, key_transforms
172+
return mql, embedded_key_transforms, json_key_transforms
152173

153174

154175
def key_transform(self, compiler, connection):
155-
mql, key_transforms = self.preprocess_lhs(compiler, connection)
176+
mql, key_transforms, json_key_transforms = self.preprocess_lhs(compiler, connection)
156177
transforms = ".".join(key_transforms)
157-
return f"{mql}.{transforms}"
178+
result = f"{mql}.{transforms}"
179+
for key in json_key_transforms:
180+
get_field = {"$getField": {"input": result, "field": key}}
181+
# Handle array indexing if the key is a digit. If key is something
182+
# like '001', it's not an array index despite isdigit() returning True.
183+
if key.isdigit() and str(int(key)) == key:
184+
result = {
185+
"$cond": {
186+
"if": {"$isArray": result},
187+
"then": {"$arrayElemAt": [result, int(key)]},
188+
"else": get_field,
189+
}
190+
}
191+
else:
192+
result = get_field
193+
return result
158194

159195

160196
class KeyTransformFactory:
161-
def __init__(self, key_name):
197+
def __init__(self, key_name, ref_field=None):
162198
self.key_name = key_name
199+
self.ref_field = ref_field
163200

164201
def __call__(self, *args, **kwargs):
165-
return KeyTransform(self.key_name, *args, **kwargs)
202+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
166203

167204

168205
def register_embedded_model_field():

tests/model_fields_/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ class DecimalParent(models.Model):
3535

3636
class EmbeddedModelFieldModel(models.Model):
3737
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True)
38-
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True)
38+
decimal_parent = EmbeddedModelField(DecimalKey, null=True, blank=True)
3939

4040

4141
class EmbeddedModel(models.Model):
42-
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True)
42+
json_value = models.JSONField()
43+
decimal = EmbeddedModelField(DecimalModel, null=True, blank=True)
4344
someint = models.IntegerField(db_column="custom_column")
4445
auto_now = models.DateTimeField(auto_now=True)
4546
auto_now_add = models.DateTimeField(auto_now_add=True)

tests/model_fields_/test_embedded_model.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from decimal import Decimal
22

3-
from django.core.exceptions import ValidationError
3+
from django.core.exceptions import FieldDoesNotExist, ValidationError
44
from django.test import SimpleTestCase, TestCase
55

66
from django_mongodb.fields import EmbeddedModelField
@@ -10,10 +10,8 @@
1010
Author,
1111
Book,
1212
DecimalKey,
13-
DecimalParent,
1413
EmbeddedModel,
1514
EmbeddedModelFieldModel,
16-
Target,
1715
)
1816

1917

@@ -82,18 +80,10 @@ def test_pre_save(self):
8280
self.assertEqual(obj.simple.auto_now_add, auto_now_add)
8381
self.assertGreater(obj.simple.auto_now, auto_now_two)
8482

85-
def test_foreign_key_in_embedded_object(self):
86-
simple = EmbeddedModel(some_relation=Target.objects.create(index=1))
87-
obj = EmbeddedModelFieldModel.objects.create(simple=simple)
88-
simple = EmbeddedModelFieldModel.objects.get().simple
89-
self.assertNotIn("some_relation", simple.__dict__)
90-
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id))
91-
self.assertIsInstance(simple.some_relation, Target)
92-
9383
def test_embedded_field_with_foreign_conversion(self):
9484
decimal = DecimalKey.objects.create(decimal=Decimal("1.5"))
95-
decimal_parent = DecimalParent.objects.create(child=decimal)
96-
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent)
85+
# decimal_parent = DecimalParent.objects.create(child=decimal)
86+
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal)
9787

9888

9989
class QueryingTests(TestCase):
@@ -134,3 +124,50 @@ def test_nested(self):
134124
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
135125
)
136126
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])
127+
128+
def test_nested_not_exists(self):
129+
msg = "Address has no field named 'president'"
130+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
131+
Book.objects.filter(author__address__city__president="NYC")
132+
133+
def test_not_exists_in_embedded(self):
134+
msg = "Address has no field named 'floor'"
135+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
136+
Book.objects.filter(author__address__floor="NYC")
137+
138+
def test_embedded_with_json_field(self):
139+
models = []
140+
for i in range(4):
141+
m = EmbeddedModelFieldModel.objects.create(
142+
simple=EmbeddedModel(
143+
json_value={"field1": i * 5, "field2": {"0": {"value": list(range(i))}}}
144+
)
145+
)
146+
models.append(m)
147+
148+
all_models = EmbeddedModelFieldModel.objects.all()
149+
150+
self.assertCountEqual(
151+
EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__0=0),
152+
models[1:],
153+
)
154+
self.assertCountEqual(
155+
EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__1=1),
156+
models[2:],
157+
)
158+
self.assertCountEqual(
159+
EmbeddedModelFieldModel.objects.filter(simple__json_value__field2__0__value__1=5), []
160+
)
161+
162+
self.assertCountEqual(
163+
EmbeddedModelFieldModel.objects.filter(simple__json_value__field1__lt=100), all_models
164+
)
165+
self.assertCountEqual(
166+
EmbeddedModelFieldModel.objects.filter(simple__json_value__field1__gt=100), []
167+
)
168+
self.assertCountEqual(
169+
EmbeddedModelFieldModel.objects.filter(
170+
simple__json_value__field1__gte=5, simple__json_value__field1__lte=10
171+
),
172+
models[1:3],
173+
)

0 commit comments

Comments
 (0)