Skip to content

Commit aa18613

Browse files
committed
INTPYTHON-658 Add PolymorphicEmbeddedModelArrayField
1 parent 817d03a commit aa18613

File tree

6 files changed

+671
-0
lines changed

6 files changed

+671
-0
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .json import register_json_field
77
from .objectid import ObjectIdField
88
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
9+
from .polymorphic_embedded_model_array import PolymorphicEmbeddedModelArrayField
910

1011
__all__ = [
1112
"register_fields",
@@ -15,6 +16,7 @@
1516
"ObjectIdAutoField",
1617
"ObjectIdField",
1718
"PolymorphicEmbeddedModelField",
19+
"PolymorphicEmbeddedModelArrayField",
1820
]
1921

2022

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
import contextlib
2+
import difflib
3+
4+
from django.core.exceptions import FieldDoesNotExist
5+
from django.db.models import lookups
6+
from django.db.models.expressions import Col
7+
from django.db.models.fields.related import lazy_related_operation
8+
from django.db.models.lookups import Lookup, Transform
9+
10+
from ..query_utils import process_lhs, process_rhs
11+
from . import PolymorphicEmbeddedModelField
12+
from .array import ArrayField, ArrayLenTransform
13+
14+
15+
class PolymorphicEmbeddedModelArrayField(ArrayField):
16+
value_is_model_instance = True
17+
18+
def __init__(self, embedded_models, **kwargs):
19+
if "size" in kwargs:
20+
raise ValueError("PolymorphicEmbeddedModelArrayField does not support size.")
21+
kwargs["editable"] = False
22+
super().__init__(PolymorphicEmbeddedModelField(embedded_models), **kwargs)
23+
self.embedded_models = embedded_models
24+
25+
def contribute_to_class(self, cls, name, private_only=False, **kwargs):
26+
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)
27+
28+
if not cls._meta.abstract:
29+
# If the embedded_model argument is a string, resolve it to the
30+
# actual model class.
31+
def _resolve_lookup(_, *resolved_models):
32+
self.embedded_models = resolved_models
33+
34+
lazy_related_operation(_resolve_lookup, cls, *self.embedded_models)
35+
36+
def deconstruct(self):
37+
name, path, args, kwargs = super().deconstruct()
38+
if path == (
39+
"django_mongodb_backend.fields.polymorphic_embedded_model_array."
40+
"PolymorphicEmbeddedModelArrayField"
41+
):
42+
path = "django_mongodb_backend.fields.PolymorphicEmbeddedModelArrayField"
43+
kwargs["embedded_models"] = self.embedded_models
44+
del kwargs["base_field"]
45+
del kwargs["editable"]
46+
return name, path, args, kwargs
47+
48+
def get_db_prep_value(self, value, connection, prepared=False):
49+
if isinstance(value, list | tuple):
50+
# Must call get_db_prep_save() rather than get_db_prep_value()
51+
# to transform model instances to dicts.
52+
return [self.base_field.get_db_prep_save(i, connection) for i in value]
53+
if value is not None:
54+
raise TypeError(
55+
f"Expected list of {self.embedded_models!r} instances, not {type(value)!r}."
56+
)
57+
return value
58+
59+
def formfield(self, **kwargs):
60+
raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.")
61+
62+
_get_model_from_label = PolymorphicEmbeddedModelField._get_model_from_label
63+
64+
def get_transform(self, name):
65+
transform = super().get_transform(name)
66+
if transform:
67+
return transform
68+
return KeyTransformFactory(name, self)
69+
70+
def _get_lookup(self, lookup_name):
71+
lookup = super()._get_lookup(lookup_name)
72+
if lookup is None or lookup is ArrayLenTransform:
73+
return lookup
74+
75+
class EmbeddedModelArrayFieldLookups(Lookup):
76+
def as_mql(self, compiler, connection):
77+
raise ValueError(
78+
"Lookups aren't supported on PolymorphicEmbeddedModelArrayField. "
79+
"Try querying one of its embedded fields instead."
80+
)
81+
82+
return EmbeddedModelArrayFieldLookups
83+
84+
85+
class _EmbeddedModelArrayOutputField(ArrayField):
86+
"""
87+
Represent the output of an EmbeddedModelArrayField when traversed in a
88+
query path.
89+
90+
This field is not meant to be used in model definitions. It exists solely
91+
to support query output resolution. When an EmbeddedModelArrayField is
92+
accessed in a query, the result should behave like an array of the embedded
93+
model's target type.
94+
95+
While it mimics ArrayField's lookup behavior, the way those lookups are
96+
resolved follows the semantics of EmbeddedModelArrayField rather than
97+
ArrayField.
98+
"""
99+
100+
ALLOWED_LOOKUPS = {
101+
"in",
102+
"exact",
103+
"iexact",
104+
"gt",
105+
"gte",
106+
"lt",
107+
"lte",
108+
}
109+
110+
def get_lookup(self, name):
111+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
112+
113+
114+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
115+
def process_rhs(self, compiler, connection):
116+
value = self.rhs
117+
if not self.get_db_prep_lookup_value_is_iterable:
118+
value = [value]
119+
# Value must be serialized based on the query target. If querying a
120+
# subfield inside the array (i.e., a nested KeyTransform), use the
121+
# output field of the subfield. Otherwise, use the base field of the
122+
# array itself.
123+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
124+
return None, [
125+
v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True)
126+
for v in value
127+
]
128+
129+
def as_mql(self, compiler, connection):
130+
# Querying a subfield within the array elements (via nested
131+
# KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over
132+
# the array and applying $in on the subfield.
133+
lhs_mql = process_lhs(self, compiler, connection)
134+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
135+
values = process_rhs(self, compiler, connection)
136+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
137+
inner_lhs_mql, values
138+
)
139+
return {"$anyElementTrue": lhs_mql}
140+
141+
142+
@_EmbeddedModelArrayOutputField.register_lookup
143+
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
144+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
145+
# This pipeline is adapted from that of ArrayField, because the
146+
# structure of EmbeddedModelArrayField on the RHS behaves similar to
147+
# ArrayField.
148+
return [
149+
{
150+
"$facet": {
151+
"gathered_data": [
152+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
153+
# To concatenate all the values from the RHS subquery,
154+
# use an $unwind followed by a $group.
155+
{
156+
"$unwind": "$tmp_name",
157+
},
158+
# The $group stage collects values into an array using
159+
# $addToSet. The use of {_id: null} results in a
160+
# single grouped array. However, because arrays from
161+
# multiple documents are aggregated, the result is a
162+
# list of lists.
163+
{
164+
"$group": {
165+
"_id": None,
166+
"tmp_name": {"$addToSet": "$tmp_name"},
167+
}
168+
},
169+
]
170+
}
171+
},
172+
{
173+
"$project": {
174+
field_name: {
175+
"$ifNull": [
176+
{
177+
"$getField": {
178+
"input": {"$arrayElemAt": ["$gathered_data", 0]},
179+
"field": "tmp_name",
180+
}
181+
},
182+
[],
183+
]
184+
}
185+
}
186+
},
187+
]
188+
189+
190+
@_EmbeddedModelArrayOutputField.register_lookup
191+
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
192+
pass
193+
194+
195+
@_EmbeddedModelArrayOutputField.register_lookup
196+
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
197+
get_db_prep_lookup_value_is_iterable = False
198+
199+
200+
@_EmbeddedModelArrayOutputField.register_lookup
201+
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
202+
pass
203+
204+
205+
@_EmbeddedModelArrayOutputField.register_lookup
206+
class EmbeddedModelArrayFieldGreaterThanOrEqual(
207+
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
208+
):
209+
pass
210+
211+
212+
@_EmbeddedModelArrayOutputField.register_lookup
213+
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
214+
pass
215+
216+
217+
@_EmbeddedModelArrayOutputField.register_lookup
218+
class EmbeddedModelArrayFieldLessThanOrEqual(
219+
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
220+
):
221+
pass
222+
223+
224+
class KeyTransform(Transform):
225+
def __init__(self, key_name, array_field, *args, **kwargs):
226+
super().__init__(*args, **kwargs)
227+
self.array_field = array_field
228+
self.key_name = key_name
229+
# Lookups iterate over the array of embedded models. A virtual column
230+
# of the queried field's type represents each element.
231+
for model in array_field.base_field.embedded_models:
232+
with contextlib.suppress(FieldDoesNotExist):
233+
field = model._meta.get_field(key_name)
234+
break
235+
else:
236+
raise FieldDoesNotExist(
237+
f"The models of field '{array_field.name}' have no field named '{key_name}'."
238+
)
239+
column_target = field.clone()
240+
column_name = f"$item.{key_name}"
241+
column_target.db_column = column_name
242+
column_target.set_attributes_from_name(column_name)
243+
self._lhs = Col(None, column_target)
244+
self._sub_transform = None
245+
246+
def __call__(self, this, *args, **kwargs):
247+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
248+
return self
249+
250+
def get_lookup(self, name):
251+
return self.output_field.get_lookup(name)
252+
253+
def get_transform(self, name):
254+
"""
255+
Validate that `name` is either a field of an embedded model or am
256+
allowed lookup on an embedded model's field.
257+
"""
258+
# Once the sub-lhs is a transform, all the filters are applied over it.
259+
# Otherwise get the transform from the nested embedded model field.
260+
if transform := self._lhs.get_transform(name):
261+
if isinstance(transform, KeyTransformFactory):
262+
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
263+
self._sub_transform = transform
264+
return self
265+
output_field = self._lhs.output_field
266+
# The lookup must be allowed AND a valid lookup for the field.
267+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
268+
set(output_field.get_lookups())
269+
)
270+
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
271+
if suggested_lookups:
272+
suggested_lookups = " or ".join(suggested_lookups)
273+
suggestion = f", perhaps you meant {suggested_lookups}?"
274+
else:
275+
suggestion = ""
276+
raise FieldDoesNotExist(
277+
f"Unsupported lookup '{name}' for "
278+
f"PolymorphicEmbeddedModelArrayField of '{output_field.__class__.__name__}'"
279+
f"{suggestion}"
280+
)
281+
282+
def as_mql(self, compiler, connection):
283+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
284+
lhs_mql = process_lhs(self, compiler, connection)
285+
return {
286+
"$ifNull": [
287+
{
288+
"$map": {
289+
"input": lhs_mql,
290+
"as": "item",
291+
"in": inner_lhs_mql,
292+
}
293+
},
294+
[],
295+
]
296+
}
297+
298+
@property
299+
def output_field(self):
300+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
301+
302+
303+
class KeyTransformFactory:
304+
def __init__(self, key_name, base_field):
305+
self.key_name = key_name
306+
self.base_field = base_field
307+
308+
def __call__(self, *args, **kwargs):
309+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

django_mongodb_backend/operations.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def get_db_converters(self, expression):
124124
converters.append(self.convert_jsonfield_value)
125125
elif internal_type == "PolymorphicEmbeddedModelField":
126126
converters.append(self.convert_polymorphicembeddedmodelfield_value)
127+
elif internal_type == "PolymorphicEmbeddedModelArrayField":
128+
converters.extend(
129+
[
130+
self._get_arrayfield_converter(converter)
131+
for converter in self.get_db_converters(
132+
Expression(output_field=expression.output_field.base_field)
133+
)
134+
]
135+
)
127136
elif internal_type == "TimeField":
128137
# Trunc(... output_field="TimeField") values must remain datetime
129138
# until Trunc.convert_value() so they can be converted from UTC

docs/source/ref/models/fields.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,39 @@ These indexes use 0-based indexing.
343343
.. admonition:: Forms are not supported
344344

345345
``PolymorphicEmbeddedModelField``\s don't appear in model forms.
346+
347+
``PolymorphicEmbeddedModelArrayField``
348+
--------------------------------------
349+
350+
.. class:: PolymorphicEmbeddedModelArrayField(embedded_models, **kwargs)
351+
352+
.. versionadded:: 5.2.0b2
353+
354+
Similar to :class:`PolymorphicEmbeddedModelField`, but stores a **list** of
355+
models of type ``embedded_models`` rather than a single instance.
356+
357+
.. attribute:: embedded_models
358+
359+
This is a required argument that works just like
360+
:attr:`PolymorphicEmbeddedModelField.embedded_models`.
361+
362+
.. attribute:: max_size
363+
364+
This is an optional argument.
365+
366+
If passed, the list will have a maximum size as specified, validated
367+
by forms and model validation, but not enforced by the database.
368+
369+
See :ref:`the embedded model topic guide
370+
<polymorphic-embedded-model-array-field-example>` for more details and
371+
examples.
372+
373+
.. admonition:: Migrations support is limited
374+
375+
:djadmin:`makemigrations` does not yet detect changes to embedded models,
376+
nor does it create indexes or constraints for embedded models referenced
377+
by ``PolymorphicEmbeddedModelArrayField``.
378+
379+
.. admonition:: Forms are not supported
380+
381+
``PolymorphicEmbeddedModelArrayField``\s don't appear in model forms.

0 commit comments

Comments
 (0)