Skip to content

Commit 5cdaef4

Browse files
committed
Adding support for overlap
1 parent c2a1bcb commit 5cdaef4

File tree

3 files changed

+115
-11
lines changed

3 files changed

+115
-11
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,20 @@ def get_transform(self, name):
5757
transform = super().get_transform(name)
5858
if transform:
5959
return transform
60-
return KeyTransformFactory(name, self.base_field)
60+
return KeyTransformFactory(name, self)
61+
62+
63+
class ProcessRHSMixin:
64+
def process_rhs(self, compiler, connection):
65+
if isinstance(self.lhs, KeyTransform):
66+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
67+
else:
68+
get_db_prep_value = self.lhs.output_field.get_db_prep_value
69+
return None, [get_db_prep_value(v, connection, prepared=True) for v in self.rhs]
6170

6271

6372
@EmbeddedModelArrayField.register_lookup
64-
class EMFArrayExact(EMFExact):
73+
class EMFArrayExact(EMFExact, ProcessRHSMixin):
6574
def as_mql(self, compiler, connection):
6675
lhs_mql = process_lhs(self, compiler, connection)
6776
value = process_rhs(self, compiler, connection)
@@ -103,15 +112,61 @@ def as_mql(self, compiler, connection):
103112
}
104113

105114

115+
@EmbeddedModelArrayField.register_lookup
116+
class ArrayOverlap(EMFExact, ProcessRHSMixin):
117+
lookup_name = "overlap"
118+
119+
def as_mql(self, compiler, connection):
120+
lhs_mql = process_lhs(self, compiler, connection)
121+
values = process_rhs(self, compiler, connection)
122+
if isinstance(self.lhs, KeyTransform):
123+
lhs_mql, inner_lhs_mql = lhs_mql
124+
return {
125+
"$anyElementTrue": {
126+
"$ifNull": [
127+
{
128+
"$map": {
129+
"input": lhs_mql,
130+
"as": "item",
131+
"in": {"$in": [inner_lhs_mql, values]},
132+
}
133+
},
134+
[],
135+
]
136+
}
137+
}
138+
conditions = []
139+
inner_lhs_mql = "$$item"
140+
for value in values:
141+
if isinstance(value, models.Model):
142+
value, emf_data = self.model_to_dict(value)
143+
# Get conditions for any nested EmbeddedModelFields.
144+
conditions.append({"$and": self.get_conditions({inner_lhs_mql: (value, emf_data)})})
145+
return {
146+
"$anyElementTrue": {
147+
"$ifNull": [
148+
{
149+
"$map": {
150+
"input": lhs_mql,
151+
"as": "item",
152+
"in": {"$or": conditions},
153+
}
154+
},
155+
[],
156+
]
157+
}
158+
}
159+
160+
106161
class KeyTransform(Transform):
107162
# it should be different class than EMF keytransform even most of the methods are equal.
108-
def __init__(self, key_name, base_field, *args, **kwargs):
163+
def __init__(self, key_name, array_field, *args, **kwargs):
109164
super().__init__(*args, **kwargs)
110-
self.base_field = base_field
165+
self.array_field = array_field
111166
self.key_name = key_name
112167
# The iteration items begins from the base_field, a virtual column with
113168
# base field output type is created.
114-
column_target = base_field.clone()
169+
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
115170
column_name = f"$item.{key_name}"
116171
column_target.db_column = column_name
117172
column_target.set_attributes_from_name(column_name)
@@ -134,7 +189,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
134189
suggestion = "."
135190
raise FieldDoesNotExist(
136191
f"Unsupported lookup '{name}' for "
137-
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
192+
f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'"
138193
f"{suggestion}"
139194
)
140195

@@ -147,7 +202,9 @@ def get_transform(self, name):
147202
transform = (
148203
self._lhs.get_transform(name)
149204
if isinstance(self._lhs, Transform)
150-
else self.base_field.embedded_model._meta.get_field(self.key_name).get_transform(name)
205+
else self.array_field.base_field.embedded_model._meta.get_field(
206+
self.key_name
207+
).get_transform(name)
151208
)
152209
if transform:
153210
self._sub_transform = transform
@@ -163,7 +220,7 @@ def as_mql(self, compiler, connection):
163220

164221
@property
165222
def output_field(self):
166-
return EmbeddedModelArrayField(self.base_field)
223+
return self.array_field
167224

168225

169226
class KeyTransformFactory:

tests/model_fields_/test_embedded_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,55 @@ def test_len(self):
282282
MuseumExhibit.objects.filter(sections__1__artifacts__len=1), [self.wonders]
283283
)
284284

285+
def test_overlap_simplefield(self):
286+
self.assertSequenceEqual(
287+
MuseumExhibit.objects.filter(sections__section_number__overlap=[10]), []
288+
)
289+
self.assertSequenceEqual(
290+
MuseumExhibit.objects.filter(sections__section_number__overlap=[1]),
291+
[self.egypt, self.wonders, self.new_descoveries],
292+
)
293+
self.assertSequenceEqual(
294+
MuseumExhibit.objects.filter(sections__section_number__overlap=[2]), [self.wonders]
295+
)
296+
297+
def test_overlap_emf(self):
298+
self.assertSequenceEqual(
299+
Movie.objects.filter(reviews__overlap=[Review(title="The best", rating=10)]),
300+
[self.clouds],
301+
)
302+
303+
"""
304+
def test_overlap_charfield_including_expression(self):
305+
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
306+
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
307+
CharArrayModel.objects.create(field=["lower text", "text"])
308+
self.assertSequenceEqual(
309+
CharArrayModel.objects.filter(
310+
field__overlap=[
311+
Upper(Value("text")),
312+
"other",
313+
]
314+
),
315+
[obj_1, obj_2],
316+
)
317+
318+
def test_overlap_values(self):
319+
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
320+
self.assertCountEqual(
321+
NullableIntegerArrayModel.objects.filter(
322+
field__overlap=qs.values_list("field"),
323+
),
324+
self.objs[:3],
325+
)
326+
self.assertCountEqual(
327+
NullableIntegerArrayModel.objects.filter(
328+
field__overlap=qs.values("field"),
329+
),
330+
self.objs[:3],
331+
)
332+
"""
333+
285334

286335
class QueryingTests(TestCase):
287336
@classmethod

0 commit comments

Comments
 (0)