Skip to content

Commit 9b13d3c

Browse files
committed
INTPYTHON-658 Add PolymorphicEmbeddedModelArrayField
1 parent 817d03a commit 9b13d3c

File tree

4 files changed

+707
-0
lines changed

4 files changed

+707
-0
lines changed

django_mongodb_backend/fields/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .json import register_json_field
77
from .objectid import ObjectIdField
88
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
9+
from .polymorphic_embedded_model_array import PolymorphicEmbeddedModelArrayField
910

1011
__all__ = [
1112
"register_fields",
@@ -15,6 +16,7 @@
1516
"ObjectIdAutoField",
1617
"ObjectIdField",
1718
"PolymorphicEmbeddedModelField",
19+
"PolymorphicEmbeddedModelArrayField",
1820
]
1921

2022

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.fields.related import lazy_related_operation
7+
from django.db.models.lookups import Lookup, Transform
8+
9+
from ..query_utils import process_lhs, process_rhs
10+
from . import PolymorphicEmbeddedModelField
11+
from .array import ArrayField, ArrayLenTransform
12+
13+
14+
class PolymorphicEmbeddedModelArrayField(ArrayField):
15+
value_is_model_instance = True
16+
17+
def __init__(self, embedded_models, **kwargs):
18+
if "size" in kwargs:
19+
raise ValueError("PolymorphicEmbeddedModelArrayField does not support size.")
20+
kwargs["editable"] = False
21+
super().__init__(PolymorphicEmbeddedModelField(embedded_models), **kwargs)
22+
self.embedded_models = embedded_models
23+
24+
def contribute_to_class(self, cls, name, private_only=False, **kwargs):
25+
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)
26+
27+
if not cls._meta.abstract:
28+
# If the embedded_model argument is a string, resolve it to the
29+
# actual model class.
30+
def _resolve_lookup(_, *resolved_models):
31+
self.embedded_models = resolved_models
32+
33+
lazy_related_operation(_resolve_lookup, cls, *self.embedded_models)
34+
35+
def deconstruct(self):
36+
name, path, args, kwargs = super().deconstruct()
37+
if path == (
38+
"django_mongodb_backend.fields.polymorphic_embedded_model_array."
39+
"PolymorphicEmbeddedModelArrayField"
40+
):
41+
path = "django_mongodb_backend.fields.PolymorphicEmbeddedModelArrayField"
42+
kwargs["embedded_models"] = self.embedded_models
43+
del kwargs["base_field"]
44+
del kwargs["editable"]
45+
return name, path, args, kwargs
46+
47+
def get_db_prep_value(self, value, connection, prepared=False):
48+
if isinstance(value, list | tuple):
49+
# Must call get_db_prep_save() rather than get_db_prep_value()
50+
# to transform model instances to dicts.
51+
return [self.base_field.get_db_prep_save(i, connection) for i in value]
52+
if value is not None:
53+
raise TypeError(
54+
f"Expected list of {self.embedded_models!r} instances, not {type(value)!r}."
55+
)
56+
return value
57+
58+
def formfield(self, **kwargs):
59+
raise NotImplementedError("PolymorphicEmbeddedModelField does not support forms.")
60+
61+
def get_transform(self, name):
62+
transform = super().get_transform(name)
63+
if transform:
64+
return transform
65+
return KeyTransformFactory(name, self)
66+
67+
def _get_lookup(self, lookup_name):
68+
lookup = super()._get_lookup(lookup_name)
69+
if lookup is None or lookup is ArrayLenTransform:
70+
return lookup
71+
72+
class EmbeddedModelArrayFieldLookups(Lookup):
73+
def as_mql(self, compiler, connection):
74+
raise ValueError(
75+
"Lookups aren't supported on PolymorphicEmbeddedModelArrayField. "
76+
"Try querying one of its embedded fields instead."
77+
)
78+
79+
return EmbeddedModelArrayFieldLookups
80+
81+
82+
class _EmbeddedModelArrayOutputField(ArrayField):
83+
"""
84+
Represent the output of an EmbeddedModelArrayField when traversed in a
85+
query path.
86+
87+
This field is not meant to be used in model definitions. It exists solely
88+
to support query output resolution. When an EmbeddedModelArrayField is
89+
accessed in a query, the result should behave like an array of the embedded
90+
model's target type.
91+
92+
While it mimics ArrayField's lookup behavior, the way those lookups are
93+
resolved follows the semantics of EmbeddedModelArrayField rather than
94+
ArrayField.
95+
"""
96+
97+
ALLOWED_LOOKUPS = {
98+
"in",
99+
"exact",
100+
"iexact",
101+
"gt",
102+
"gte",
103+
"lt",
104+
"lte",
105+
}
106+
107+
def get_lookup(self, name):
108+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
109+
110+
111+
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
112+
def process_rhs(self, compiler, connection):
113+
value = self.rhs
114+
if not self.get_db_prep_lookup_value_is_iterable:
115+
value = [value]
116+
# Value must be serialized based on the query target. If querying a
117+
# subfield inside the array (i.e., a nested KeyTransform), use the
118+
# output field of the subfield. Otherwise, use the base field of the
119+
# array itself.
120+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
121+
return None, [
122+
v if hasattr(v, "as_mql") else get_db_prep_value(v, connection, prepared=True)
123+
for v in value
124+
]
125+
126+
def as_mql(self, compiler, connection):
127+
# Querying a subfield within the array elements (via nested
128+
# KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over
129+
# the array and applying $in on the subfield.
130+
lhs_mql = process_lhs(self, compiler, connection)
131+
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
132+
values = process_rhs(self, compiler, connection)
133+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
134+
inner_lhs_mql, values
135+
)
136+
return {"$anyElementTrue": lhs_mql}
137+
138+
139+
@_EmbeddedModelArrayOutputField.register_lookup
140+
class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In):
141+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
142+
# This pipeline is adapted from that of ArrayField, because the
143+
# structure of EmbeddedModelArrayField on the RHS behaves similar to
144+
# ArrayField.
145+
return [
146+
{
147+
"$facet": {
148+
"gathered_data": [
149+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
150+
# To concatenate all the values from the RHS subquery,
151+
# use an $unwind followed by a $group.
152+
{
153+
"$unwind": "$tmp_name",
154+
},
155+
# The $group stage collects values into an array using
156+
# $addToSet. The use of {_id: null} results in a
157+
# single grouped array. However, because arrays from
158+
# multiple documents are aggregated, the result is a
159+
# list of lists.
160+
{
161+
"$group": {
162+
"_id": None,
163+
"tmp_name": {"$addToSet": "$tmp_name"},
164+
}
165+
},
166+
]
167+
}
168+
},
169+
{
170+
"$project": {
171+
field_name: {
172+
"$ifNull": [
173+
{
174+
"$getField": {
175+
"input": {"$arrayElemAt": ["$gathered_data", 0]},
176+
"field": "tmp_name",
177+
}
178+
},
179+
[],
180+
]
181+
}
182+
}
183+
},
184+
]
185+
186+
187+
@_EmbeddedModelArrayOutputField.register_lookup
188+
class EmbeddedModelArrayFieldExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.Exact):
189+
pass
190+
191+
192+
@_EmbeddedModelArrayOutputField.register_lookup
193+
class EmbeddedModelArrayFieldIExact(EmbeddedModelArrayFieldBuiltinLookup, lookups.IExact):
194+
get_db_prep_lookup_value_is_iterable = False
195+
196+
197+
@_EmbeddedModelArrayOutputField.register_lookup
198+
class EmbeddedModelArrayFieldGreaterThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThan):
199+
pass
200+
201+
202+
@_EmbeddedModelArrayOutputField.register_lookup
203+
class EmbeddedModelArrayFieldGreaterThanOrEqual(
204+
EmbeddedModelArrayFieldBuiltinLookup, lookups.GreaterThanOrEqual
205+
):
206+
pass
207+
208+
209+
@_EmbeddedModelArrayOutputField.register_lookup
210+
class EmbeddedModelArrayFieldLessThan(EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThan):
211+
pass
212+
213+
214+
@_EmbeddedModelArrayOutputField.register_lookup
215+
class EmbeddedModelArrayFieldLessThanOrEqual(
216+
EmbeddedModelArrayFieldBuiltinLookup, lookups.LessThanOrEqual
217+
):
218+
pass
219+
220+
221+
class KeyTransform(Transform):
222+
def __init__(self, key_name, array_field, *args, **kwargs):
223+
super().__init__(*args, **kwargs)
224+
self.array_field = array_field
225+
self.key_name = key_name
226+
# Lookups iterate over the array of embedded models. A virtual column
227+
# of the queried field's type represents each element.
228+
column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
229+
column_name = f"$item.{key_name}"
230+
column_target.db_column = column_name
231+
column_target.set_attributes_from_name(column_name)
232+
self._lhs = Col(None, column_target)
233+
self._sub_transform = None
234+
235+
def __call__(self, this, *args, **kwargs):
236+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
237+
return self
238+
239+
def get_lookup(self, name):
240+
return self.output_field.get_lookup(name)
241+
242+
def get_transform(self, name):
243+
"""
244+
Validate that `name` is either a field of an embedded model or am
245+
allowed lookup on an embedded model's field.
246+
"""
247+
# Once the sub-lhs is a transform, all the filters are applied over it.
248+
# Otherwise get the transform from the nested embedded model field.
249+
if transform := self._lhs.get_transform(name):
250+
if isinstance(transform, KeyTransformFactory):
251+
raise ValueError("Cannot perform multiple levels of array traversal in a query.")
252+
self._sub_transform = transform
253+
return self
254+
output_field = self._lhs.output_field
255+
# The lookup must be allowed AND a valid lookup for the field.
256+
allowed_lookups = self.output_field.ALLOWED_LOOKUPS.intersection(
257+
set(output_field.get_lookups())
258+
)
259+
suggested_lookups = difflib.get_close_matches(name, allowed_lookups)
260+
if suggested_lookups:
261+
suggested_lookups = " or ".join(suggested_lookups)
262+
suggestion = f", perhaps you meant {suggested_lookups}?"
263+
else:
264+
suggestion = ""
265+
raise FieldDoesNotExist(
266+
f"Unsupported lookup '{name}' for "
267+
f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
268+
f"{suggestion}"
269+
)
270+
271+
def as_mql(self, compiler, connection):
272+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
273+
lhs_mql = process_lhs(self, compiler, connection)
274+
return {
275+
"$ifNull": [
276+
{
277+
"$map": {
278+
"input": lhs_mql,
279+
"as": "item",
280+
"in": inner_lhs_mql,
281+
}
282+
},
283+
[],
284+
]
285+
}
286+
287+
@property
288+
def output_field(self):
289+
return _EmbeddedModelArrayOutputField(self._lhs.output_field)
290+
291+
292+
class KeyTransformFactory:
293+
def __init__(self, key_name, base_field):
294+
self.key_name = key_name
295+
self.base_field = base_field
296+
297+
def __call__(self, *args, **kwargs):
298+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
EmbeddedModelArrayField,
88
EmbeddedModelField,
99
ObjectIdField,
10+
PolymorphicEmbeddedModelArrayField,
1011
PolymorphicEmbeddedModelField,
1112
)
1213
from django_mongodb_backend.models import EmbeddedModel
@@ -251,3 +252,12 @@ class Cat(EmbeddedModel):
251252

252253
def __str__(self):
253254
return self.name
255+
256+
257+
# PolymorphicEmbeddedModelArrayField
258+
class Owner(models.Model):
259+
name = models.CharField(max_length=100)
260+
pets = PolymorphicEmbeddedModelArrayField(("Dog", "Cat"), blank=True, null=True)
261+
262+
def __str__(self):
263+
return self.name

0 commit comments

Comments
 (0)