Skip to content

Commit ea530af

Browse files
committed
Fix emf flow and add subquery unit test
1 parent 2c5e5e2 commit ea530af

File tree

2 files changed

+33
-40
lines changed

2 files changed

+33
-40
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from django.db import models
55
from django.db.models import Field
66
from django.db.models.expressions import Col
7-
from django.db.models.lookups import Transform
7+
from django.db.models.lookups import Lookup, Transform
88

99
from .. import forms
1010
from ..query_utils import process_lhs, process_rhs
1111
from . import EmbeddedModelField
1212
from .array import ArrayField
13-
from .embedded_model import EMFExact
13+
from .embedded_model import EMFExact, EMFMixin
1414

1515

1616
class EmbeddedModelArrayField(ArrayField):
@@ -60,17 +60,8 @@ def get_transform(self, name):
6060
return KeyTransformFactory(name, self)
6161

6262

63-
class ProcessRHSMixin:
64-
def process_rhs(self, compiler, connection):
65-
if isinstance(self.lhs, KeyTransform):
66-
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
67-
else:
68-
get_db_prep_value = self.lhs.output_field.get_db_prep_value
69-
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
70-
71-
7263
@EmbeddedModelArrayField.register_lookup
73-
class EMFArrayExact(EMFExact, ProcessRHSMixin):
64+
class EMFArrayExact(EMFExact):
7465
def as_mql(self, compiler, connection):
7566
lhs_mql = process_lhs(self, compiler, connection)
7667
value = process_rhs(self, compiler, connection)
@@ -113,12 +104,29 @@ def as_mql(self, compiler, connection):
113104

114105

115106
@EmbeddedModelArrayField.register_lookup
116-
class ArrayOverlap(EMFExact, ProcessRHSMixin):
107+
class ArrayOverlap(EMFMixin, Lookup):
117108
lookup_name = "overlap"
109+
get_db_prep_lookup_value_is_iterable = True
110+
111+
def process_rhs(self, compiler, connection):
112+
values = self.rhs
113+
if self.get_db_prep_lookup_value_is_iterable:
114+
values = [values]
115+
# Compute how to serialize each value based on the query target.
116+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
117+
# field of the subfield. Otherwise, use the base field of the array itself.
118+
if isinstance(self.lhs, KeyTransform):
119+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
120+
else:
121+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
122+
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
118123

119124
def as_mql(self, compiler, connection):
120125
lhs_mql = process_lhs(self, compiler, connection)
121126
values = process_rhs(self, compiler, connection)
127+
# Querying a subfield within the array elements (via nested KeyTransform).
128+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
129+
# `$in` on the subfield.
122130
if isinstance(self.lhs, KeyTransform):
123131
lhs_mql, inner_lhs_mql = lhs_mql
124132
return {
@@ -137,11 +145,12 @@ def as_mql(self, compiler, connection):
137145
}
138146
conditions = []
139147
inner_lhs_mql = "$$item"
148+
# Querying full embedded documents in the array.
149+
# Builds `$or` conditions and maps them over the array to match any full document.
140150
for value in values:
141-
if isinstance(value, models.Model):
142-
value, emf_data = self.model_to_dict(value)
143-
# Get conditions for any nested EmbeddedModelFields.
144-
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
151+
value, emf_data = self.model_to_dict(value)
152+
# Get conditions for any nested EmbeddedModelFields.
153+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
145154
return {
146155
"$anyElementTrue": {
147156
"$ifNull": [

tests/model_fields_/test_embedded_model.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -300,36 +300,20 @@ def test_overlap_emf(self):
300300
[self.clouds],
301301
)
302302

303-
"""
304-
def test_overlap_charfield_including_expression(self):
305-
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306-
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307-
CharArrayModel.objects.create(field=["lower text", "text"])
308-
self.assertSequenceEqual(
309-
CharArrayModel.objects.filter(
310-
field__overlap=[
311-
Upper(Value("text")),
312-
"other",
313-
]
314-
),
315-
[obj_1, obj_2],
316-
)
317-
318303
def test_overlap_values(self):
319-
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
304+
qs = Movie.objects.filter(title__in=["Clouds", "Frozen"])
320305
self.assertCountEqual(
321-
NullableIntegerArrayModel.objects.filter(
322-
field__overlap=qs.values_list("field"),
306+
Movie.objects.filter(
307+
reviews__overlap=qs.values_list("reviews"),
323308
),
324-
self.objs[:3],
309+
[self.clouds, self.frozen],
325310
)
326311
self.assertCountEqual(
327-
NullableIntegerArrayModel.objects.filter(
328-
field__overlap=qs.values("field"),
312+
Movie.objects.filter(
313+
reviews__overlap=qs.values("reviews"),
329314
),
330-
self.objs[:3],
315+
[self.clouds, self.frozen],
331316
)
332-
"""
333317

334318

335319
class QueryingTests(TestCase):

0 commit comments

Comments
 (0)