Skip to content

Commit 0f4699f

Browse files
committed
INTPYTHON-658 Add PolymorphicEmbeddedModelArrayField
1 parent 817d03a commit 0f4699f

File tree

5 files changed

+719
-0
lines changed

5 files changed

+719
-0
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .json import register_json_field
77
from .objectid import ObjectIdField
88
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
9+
from .polymorphic_embedded_model_array import PolymorphicEmbeddedModelArrayField
910

1011
__all__ = [
1112
"register_fields",
@@ -15,6 +16,7 @@
1516
"ObjectIdAutoField",
1617
"ObjectIdField",
1718
"PolymorphicEmbeddedModelField",
19+
"PolymorphicEmbeddedModelArrayField",
1820
]
1921

2022

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.fields.related import lazy_related_operation
7+
from django.db.models.lookups import Lookup, Transform
8+
9+
from ..query_utils import process_lhs, process_rhs
10+
from . import PolymorphicEmbeddedModelField
11+
from .array import ArrayField, ArrayLenTransform
12+
13+
14+
class PolymorphicEmbeddedModelArrayField(ArrayField):
    """
    An array field whose elements are embedded model instances that may be
    instances of any of the given `embedded_models`.

    Each element is handled by a PolymorphicEmbeddedModelField. The field is
    always non-editable and supports neither `size` nor forms.
    """

    value_is_model_instance = True

    def __init__(self, embedded_models, **kwargs):
        """
        `embedded_models` lists the models an element may be an instance of.

        Raise ValueError if `size` is passed.
        """
        if "size" in kwargs:
            raise ValueError("PolymorphicEmbeddedModelArrayField does not support size.")
        kwargs["editable"] = False
        super().__init__(PolymorphicEmbeddedModelField(embedded_models), **kwargs)
        self.embedded_models = embedded_models

    def contribute_to_class(self, cls, name, private_only=False, **kwargs):
        super().contribute_to_class(cls, name, private_only=private_only, **kwargs)

        if not cls._meta.abstract:
            # If any of the embedded_models arguments are strings, resolve
            # them to the actual model classes.
            def _resolve_lookup(_, *resolved_models):
                self.embedded_models = resolved_models

            lazy_related_operation(_resolve_lookup, cls, *self.embedded_models)

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        if path == (
            "django_mongodb_backend.fields.polymorphic_embedded_model_array."
            "PolymorphicEmbeddedModelArrayField"
        ):
            path = "django_mongodb_backend.fields.PolymorphicEmbeddedModelArrayField"
        kwargs["embedded_models"] = self.embedded_models
        # __init__() supplies base_field and editable itself, so they aren't
        # constructor arguments of this field.
        del kwargs["base_field"]
        del kwargs["editable"]
        return name, path, args, kwargs

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Convert a list or tuple of model instances to a list of database
        values. Return None unchanged and raise TypeError for anything else.
        """
        if isinstance(value, list | tuple):
            # Must call get_db_prep_save() rather than get_db_prep_value()
            # to transform model instances to dicts.
            return [self.base_field.get_db_prep_save(i, connection) for i in value]
        if value is not None:
            raise TypeError(
                f"Expected list of {self.embedded_models!r} instances, not {type(value)!r}."
            )
        return value

    def formfield(self, **kwargs):
        # The message names this field rather than the single-valued
        # PolymorphicEmbeddedModelField.
        raise NotImplementedError("PolymorphicEmbeddedModelArrayField does not support forms.")

    # Reuse the single-valued field's implementation.
    _get_model_from_label = PolymorphicEmbeddedModelField._get_model_from_label

    def get_transform(self, name):
        """
        Return the transform registered for `name` if there is one,
        otherwise a KeyTransformFactory for `name` on this field.
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name, self)

    def _get_lookup(self, lookup_name):
        """
        Return None or ArrayLenTransform unchanged. Replace any other lookup
        with one that raises ValueError when it's compiled to MQL.
        """
        lookup = super()._get_lookup(lookup_name)
        if lookup is None or lookup is ArrayLenTransform:
            return lookup

        class EmbeddedModelArrayFieldLookups(Lookup):
            def as_mql(self, compiler, connection):
                raise ValueError(
                    "Lookups aren't supported on PolymorphicEmbeddedModelArrayField. "
                    "Try querying one of its embedded fields instead."
                )

        return EmbeddedModelArrayFieldLookups
82+
83+
84+
class _EmbeddedModelArrayOutputField(ArrayField):
    """
    Output field for a query path that traverses into the elements of an
    embedded model array.

    It isn't meant to be declared on a model. It only exists so that the
    result of such a traversal behaves like an array of the targeted
    subfield's type, while restricting the available lookups to
    ALLOWED_LOOKUPS.
    """

    ALLOWED_LOOKUPS = {
        "in",
        "exact",
        "iexact",
        "gt",
        "gte",
        "lt",
        "lte",
    }

    def get_lookup(self, name):
        """Return the lookup for `name`, or None if `name` isn't allowed."""
        if name not in self.ALLOWED_LOOKUPS:
            return None
        return super().get_lookup(name)
111+
112+
113+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
    """
    Base class for lookups on a subfield of the elements of an embedded
    model array. The lookup matches when any element matches.
    """

    def process_rhs(self, compiler, connection):
        """Return (None, values), with each value prepared for the database."""
        if self.get_db_prep_lookup_value_is_iterable:
            candidates = self.rhs
        else:
            candidates = [self.rhs]
        # Serialize each value with the output field of the queried subfield
        # (the expression wrapped by the KeyTransform on the left-hand side).
        prepare = self.lhs._lhs.output_field.get_db_prep_value
        serialized = []
        for candidate in candidates:
            if hasattr(candidate, "as_mql"):
                # Expressions are passed through unchanged.
                serialized.append(candidate)
            else:
                serialized.append(prepare(candidate, connection, prepared=True))
        return None, serialized

    def as_mql(self, compiler, connection):
        """
        Replace the per-element expression of the left-hand side's $map with
        this lookup's comparison and wrap the result in $anyElementTrue.
        """
        array_mql = process_lhs(self, compiler, connection)
        map_stage = array_mql["$ifNull"][0]["$map"]
        element_mql = map_stage["in"]
        rhs_values = process_rhs(self, compiler, connection)
        operator = connection.mongo_operators[self.lookup_name]
        map_stage["in"] = operator(element_mql, rhs_values)
        return {"$anyElementTrue": array_mql}
139+
140+
141+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
    """`in` lookup on a subfield of the elements of an embedded model array."""

    def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
        """
        Return the pipeline stages that gather the values of a right-hand
        side subquery into a single array stored under `field_name`.
        """
        # Project the subquery expression, $unwind it so each value becomes
        # its own document, then $group everything (_id: null) into a single
        # array with $addToSet.
        gather_stages = [
            {"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
            {"$unwind": "$tmp_name"},
            {
                "$group": {
                    "_id": None,
                    "tmp_name": {"$addToSet": "$tmp_name"},
                }
            },
        ]
        # The gathered array, read from the first $facet result.
        gathered_values = {
            "$getField": {
                "input": {"$arrayElemAt": ["$gathered_data", 0]},
                "field": "tmp_name",
            }
        }
        return [
            {"$facet": {"gathered_data": gather_stages}},
            # Fall back to an empty array when nothing was gathered.
            {"$project": {field_name: {"$ifNull": [gathered_values, []]}}},
        ]
187+
188+
189+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
    """`exact` lookup on a subfield of the elements of an embedded model array."""
192+
193+
194+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
    """`iexact` lookup on a subfield of the elements of an embedded model array."""

    # The right-hand side is a single value, so process_rhs() wraps it in a
    # list before preparing it.
    get_db_prep_lookup_value_is_iterable = False
197+
198+
199+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
    """`gt` lookup on a subfield of the elements of an embedded model array."""
202+
203+
204+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldGreaterThanOrEqual(
    EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
):
    """`gte` lookup on a subfield of the elements of an embedded model array."""
209+
210+
211+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
    """`lt` lookup on a subfield of the elements of an embedded model array."""
214+
215+
216+
@_EmbeddedModelArrayOutputField.register_lookup
class EmbeddedModelArrayFieldLessThanOrEqual(
    EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
):
    """`lte` lookup on a subfield of the elements of an embedded model array."""
221+
222+
223+
class KeyTransform(Transform):
    """
    Transform that targets the field `key_name` of the models embedded in
    `array_field`.

    The targeted field is represented by a virtual column named
    `$item.<key_name>`, which as_mql() evaluates for every array element
    inside a `$map`.
    """

    def __init__(self, key_name, array_field, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.array_field = array_field
        self.key_name = key_name
        # Lookups iterate over the array of embedded models. A virtual column
        # of the queried field's type represents each element.
        column_target = self._find_embedded_field(array_field, key_name).clone()
        column_name = f"$item.{key_name}"
        column_target.db_column = column_name
        column_target.set_attributes_from_name(column_name)
        self._lhs = Col(None, column_target)
        self._sub_transform = None

    @staticmethod
    def _find_embedded_field(array_field, key_name):
        """
        Return the field named `key_name` from the first of
        `array_field.embedded_models` that defines it.

        A polymorphic array has several candidate models rather than a single
        embedded model, so each one is consulted in turn. Raise
        FieldDoesNotExist if none of them has the field.
        """
        # NOTE(review): assumes embedded_models has been resolved to model
        # classes by the time a query is built - confirm for abstract models.
        for model in array_field.embedded_models:
            try:
                return model._meta.get_field(key_name)
            except FieldDoesNotExist:
                continue
        raise FieldDoesNotExist(
            f"The models of field '{array_field.name}' have no field named '{key_name}'."
        )

    def __call__(self, this, *args, **kwargs):
        # Apply the transform stored by get_transform() to the virtual column
        # rather than to this transform.
        self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
        return self

    def get_lookup(self, name):
        return self.output_field.get_lookup(name)

    def get_transform(self, name):
        """
        Validate that `name` is either a field of an embedded model or an
        allowed lookup on an embedded model's field.
        """
        # Once the sub-lhs is a transform, all the filters are applied over it.
        # Otherwise get the transform from the nested embedded model field.
        if transform := self._lhs.get_transform(name):
            if isinstance(transform, KeyTransformFactory):
                raise ValueError("Cannot perform multiple levels of array traversal in a query.")
            self._sub_transform = transform
            return self
        output_field = self._lhs.output_field
        # The lookup must be allowed AND a valid lookup for the field.
        allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
            set(output_field.get_lookups())
        )
        suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
        if suggested_lookups:
            suggested_lookups = " or ".join(suggested_lookups)
            suggestion = f", perhaps you meant {suggested_lookups}?"
        else:
            suggestion = ""
        raise FieldDoesNotExist(
            f"Unsupported lookup '{name}' for "
            f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
            f"{suggestion}"
        )

    def as_mql(self, compiler, connection):
        """
        Return a `$map` over the array that evaluates the targeted subfield
        for each element, defaulting to an empty array via `$ifNull`.
        """
        inner_lhs_mql = self._lhs.as_mql(compiler, connection)
        lhs_mql = process_lhs(self, compiler, connection)
        return {
            "$ifNull": [
                {
                    "$map": {
                        "input": lhs_mql,
                        "as": "item",
                        "in": inner_lhs_mql,
                    }
                },
                [],
            ]
        }

    @property
    def output_field(self):
        # An array of the targeted subfield's type, limited to the lookups
        # that _EmbeddedModelArrayOutputField allows.
        return _EmbeddedModelArrayOutputField(self._lhs.output_field)
292+
293+
294+
class KeyTransformFactory:
    """Callable that builds a KeyTransform for a fixed key and field."""

    def __init__(self, key_name, base_field):
        self.key_name, self.base_field = key_name, base_field

    def __call__(self, *args, **kwargs):
        """Return a KeyTransform for the stored key and field."""
        transform = KeyTransform(self.key_name, self.base_field, *args, **kwargs)
        return transform

django_mongodb_backend/operations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def get_db_converters(self, expression):
124124
converters.append(self.convert_jsonfield_value)
125125
elif internal_type == "PolymorphicEmbeddedModelField":
126126
converters.append(self.convert_polymorphicembeddedmodelfield_value)
127+
elif internal_type == "PolymorphicEmbeddedModelArrayField":
128+
converters.extend(
129+
[
130+
self._get_arrayfield_converter(converter)
131+
for converter in self.get_db_converters(
132+
Expression(output_field=expression.output_field.base_field)
133+
)
134+
]
135+
)
127136
elif internal_type == "TimeField":
128137
# Trunc(... output_field="TimeField") values must remain datetime
129138
# until Trunc.convert_value() so they can be converted from UTC

tests/model_fields_/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
EmbeddedModelArrayField,
88
EmbeddedModelField,
99
ObjectIdField,
10+
PolymorphicEmbeddedModelArrayField,
1011
PolymorphicEmbeddedModelField,
1112
)
1213
from django_mongodb_backend.models import EmbeddedModel
@@ -251,3 +252,12 @@ class Cat(EmbeddedModel):
251252

252253
def __str__(self):
253254
return self.name
255+
256+
257+
# PolymorphicEmbeddedModelArrayField
258+
class Owner(models.Model):
    """Test model whose `pets` array may hold Dog and Cat instances."""

    name = models.CharField(max_length=100)
    pets = PolymorphicEmbeddedModelArrayField(
        embedded_models=("Dog", "Cat"),
        null=True,
        blank=True,
    )

    def __str__(self):
        return self.name

0 commit comments

Comments
 (0)