Skip to content

Commit 93776f7

Browse files
committed
POC: Manage sub array queries with a different transform path and EMF.
1 parent 89ea845 commit 93776f7

File tree

6 files changed

+388
-20
lines changed

6 files changed

+388
-20
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
326326
def as_mql(self, compiler, connection):
327327
lhs_mql = process_lhs(self, compiler, connection)
328328
value = process_rhs(self, compiler, connection)
329-
return {
330-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
331-
}
329+
return {"$and": [{"$isArray": lhs_mql}, {"$size": {"$setIntersection": [value, lhs_mql]}}]}
332330

333331

334332
@ArrayField.register_lookup
@@ -338,7 +336,7 @@ class ArrayLenTransform(Transform):
338336

339337
def as_mql(self, compiler, connection):
340338
lhs_mql = process_lhs(self, compiler, connection)
341-
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}}
339+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
342340

343341

344342
@ArrayField.register_lookup

django_mongodb_backend/fields/embedded_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,11 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189-
transforms = ".".join(key_transforms)
190-
return f"{mql}.{transforms}"
189+
# transform = ".".join(key_transforms)
190+
for key in key_transforms:
191+
mql = {"$getField": {"input": mql, "field": key}}
192+
return mql
193+
# return f"{mql}.{transform}"
191194

192195
@property
193196
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 189 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db import models
15
from django.db.models import Field
6+
from django.db.models.expressions import Col
7+
from django.db.models.lookups import Lookup, Transform
28

39
from .. import forms
10+
from ..query_utils import process_lhs, process_rhs
411
from . import EmbeddedModelField
512
from .array import ArrayField
6-
from .embedded_model import EMFExact
13+
from .embedded_model import EMFExact, EMFMixin
714

815

916
class EmbeddedModelArrayField(ArrayField):
@@ -47,17 +54,188 @@ def formfield(self, **kwargs):
4754
)
4855

4956
def get_transform(self, name):
50-
# TODO: ...
51-
return self.base_field.get_transform(name)
52-
# Copied from EmbedddedModelField -- customize?
53-
# transform = super().get_transform(name)
54-
# if transform:
55-
# return transform
56-
# field = self.embedded_model._meta.get_field(name)
57-
# return KeyTransformFactory(name, field)
57+
transform = super().get_transform(name)
58+
if transform:
59+
return transform
60+
return KeyTransformFactory(name, self)
5861

5962

6063
@EmbeddedModelArrayField.register_lookup
class EMFArrayExact(EMFExact):
    """Exact lookup that is satisfied when any element of the array matches."""

    def as_mql(self, compiler, connection):
        """
        Build an `$anyElementTrue` expression that maps a per-item predicate
        over the array on the left-hand side.

        When the right-hand side is a model instance, the predicate is the
        `$and` of the conditions returned by `get_conditions()`; otherwise it
        is an `$eq` between the item expression and the value.
        """
        array_mql = process_lhs(self, compiler, connection)
        rhs_value = process_rhs(self, compiler, connection)
        if isinstance(self.lhs, KeyTransform):
            # For a KeyTransform the compiled left-hand side is unpacked as a
            # pair: the array expression and the per-item expression.
            array_mql, item_mql = array_mql
        else:
            item_mql = "$$item"
        if isinstance(rhs_value, models.Model):
            rhs_value, emf_data = self.model_to_dict(rhs_value)
            # Get conditions for any nested EmbeddedModelFields.
            predicate = {"$and": self.get_conditions({item_mql: (rhs_value, emf_data)})}
        else:
            predicate = {"$eq": [item_mql, rhs_value]}
        return {
            "$anyElementTrue": {
                # Fall back to an empty list when the mapped result is null.
                "$ifNull": [
                    {"$map": {"input": array_mql, "as": "item", "in": predicate}},
                    [],
                ]
            }
        }
104+
105+
106+
@EmbeddedModelArrayField.register_lookup
class ArrayOverlap(EMFMixin, Lookup):
    """`overlap` lookup: true when any array element matches any given value."""

    lookup_name = "overlap"
    get_db_prep_lookup_value_is_iterable = True

    def process_rhs(self, compiler, connection):
        """Return `(None, values)` with each right-hand side value prepared for the database."""
        candidates = [self.rhs] if self.get_db_prep_lookup_value_is_iterable else self.rhs
        # How each value is serialized depends on the query target. When
        # querying a subfield inside the array (i.e., a nested KeyTransform),
        # use the output field of the subfield. Otherwise, use the base field
        # of the array itself.
        if isinstance(self.lhs, KeyTransform):
            prepare = self.lhs._lhs.output_field.get_db_prep_value
        else:
            prepare = self.lhs.output_field.base_field.get_db_prep_value
        return None, [prepare(candidate, connection, prepared=True) for candidate in candidates]

    def as_mql(self, compiler, connection):
        """
        Build an `$anyElementTrue` expression that maps a per-item predicate
        over the array on the left-hand side.
        """
        array_mql = process_lhs(self, compiler, connection)
        values = process_rhs(self, compiler, connection)
        if isinstance(self.lhs, KeyTransform):
            # Querying a subfield within the array elements (via nested
            # KeyTransform): apply `$in` on the subfield of every item, which
            # replicates MongoDB's implicit ANY-match.
            array_mql, item_mql = array_mql
            predicate = {"$in": [item_mql, values]}
        else:
            # Querying full embedded documents in the array: build one `$and`
            # group per value and `$or` them so any full document may match.
            item_mql = "$$item"
            alternatives = []
            for value in values:
                document, emf_data = self.model_to_dict(value)
                # Get conditions for any nested EmbeddedModelFields.
                alternatives.append(
                    {"$and": self.get_conditions({item_mql: (document, emf_data)})}
                )
            predicate = {"$or": alternatives}
        return {
            "$anyElementTrue": {
                # Fall back to an empty list when the mapped result is null.
                "$ifNull": [
                    {"$map": {"input": array_mql, "as": "item", "in": predicate}},
                    [],
                ]
            }
        }
168+
169+
170+
class KeyTransform(Transform):
    """
    Transform that targets a field of the embedded models stored in an
    embedded model array field.

    This is deliberately a separate class from the EmbeddedModelField key
    transform, even though most of the methods are alike.
    """

    def __init__(self, key_name, array_field, *args, **kwargs):
        """
        Store the array field and key name, and create the inner left-hand
        side used for per-item conditions.
        """
        super().__init__(*args, **kwargs)
        self.array_field = array_field
        self.key_name = key_name
        # Iteration over the items starts from the base field, so a virtual
        # column is created by cloning the embedded model's field and naming
        # it relative to the item.
        column_target = array_field.base_field.embedded_model._meta.get_field(key_name).clone()
        column_name = f"$item.{key_name}"
        column_target.db_column = column_name
        column_target.set_attributes_from_name(column_name)
        self._lhs = Col(None, column_target)
        self._sub_transform = None

    def __call__(self, this, *args, **kwargs):
        """
        Wrap the inner left-hand side with the transform recorded by
        `get_transform()` and return this instance.

        `this` is accepted but not used.
        """
        self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
        return self

    def get_lookup(self, name):
        """Return the lookup registered under `name` on the array field."""
        return self.output_field.get_lookup(name)

    def _get_missing_field_or_lookup_exception(self, lhs, name):
        """
        Raise FieldDoesNotExist for an unsupported lookup `name`, suggesting
        close matches from `lhs.get_lookups()` when there are any.
        """
        suggested_lookups = difflib.get_close_matches(name, lhs.get_lookups())
        if suggested_lookups:
            suggested_lookups = " or ".join(suggested_lookups)
            suggestion = f", perhaps you meant {suggested_lookups}?"
        else:
            suggestion = "."
        raise FieldDoesNotExist(
            f"Unsupported lookup '{name}' for "
            f"{self.array_field.base_field.__class__.__name__} '{self.array_field.base_field.name}'"
            f"{suggestion}"
        )

    def get_transform(self, name):
        """
        Validate that `name` is either a field of an embedded model or a
        lookup on an embedded model's field.

        On success the transform is remembered for `__call__()` and this
        instance is returned; otherwise FieldDoesNotExist is raised.
        """
        # Once the inner lhs is a transform, all further filters are applied
        # over it.
        if isinstance(self._lhs, Transform):
            lookup_source = self._lhs
            transform = self._lhs.get_transform(name)
        else:
            # This class has no `base_field` attribute; the base field lives
            # on the array field.
            lookup_source = self.array_field.base_field
            transform = self.array_field.base_field.embedded_model._meta.get_field(
                self.key_name
            ).get_transform(name)
        if transform:
            self._sub_transform = transform
            return self
        # The helper raises on its own; `raise` makes the control flow explicit.
        raise self._get_missing_field_or_lookup_exception(lookup_source, name)

    def as_mql(self, compiler, connection):
        """Return a pair: the array expression and the per-item expression."""
        inner_lhs_mql = self._lhs.as_mql(compiler, connection)
        lhs_mql = process_lhs(self, compiler, connection)
        return lhs_mql, inner_lhs_mql

    @property
    def output_field(self):
        """The array field, so that its lookups apply to this transform."""
        return self.array_field
233+
234+
235+
class KeyTransformFactory:
    """Callable that creates `KeyTransform` instances for a fixed key and field."""

    def __init__(self, key_name, base_field):
        """Remember the key name and the field to hand to each new transform."""
        self.key_name, self.base_field = key_name, base_field

    def __call__(self, *args, **kwargs):
        """Create a `KeyTransform`, forwarding any extra arguments to it."""
        key, field = self.key_name, self.base_field
        return KeyTransform(key, field, *args, **kwargs)

tests/model_fields_/models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,37 @@ class Movie(models.Model):
165165

166166
def __str__(self):
167167
return self.title
168+
169+
170+
class RestorationRecord(EmbeddedModel):
    """An embedded record of one restoration: its date, description and restorer."""

    date = models.DateField()
    description = models.TextField()
    restored_by = models.CharField(max_length=255)
174+
175+
176+
class ArtifactDetail(EmbeddedModel):
    """Details about a specific artifact."""

    name = models.CharField(max_length=255)
    description = models.CharField(max_length=255)
    metadata = models.JSONField()
    # Nullable array of embedded RestorationRecord documents.
    restorations = EmbeddedModelArrayField(RestorationRecord, null=True)
    # Nullable single embedded RestorationRecord.
    last_restoration = EmbeddedModelField(RestorationRecord, null=True)
184+
185+
186+
class ExhibitSection(EmbeddedModel):
    """A section within an exhibit, containing multiple artifacts."""

    section_number = models.IntegerField()
    # Nullable array of embedded ArtifactDetail documents.
    artifacts = EmbeddedModelArrayField(ArtifactDetail, null=True)
191+
192+
193+
class MuseumExhibit(models.Model):
    """An exhibit in the museum, composed of multiple sections."""

    exhibit_name = models.CharField(max_length=255)
    # Nullable array of embedded ExhibitSection documents.
    sections = EmbeddedModelArrayField(ExhibitSection, null=True)
    # Nullable single embedded ExhibitSection.
    main_section = EmbeddedModelField(ExhibitSection, null=True)

    def __str__(self):
        # The exhibit is represented by its name.
        return self.exhibit_name

0 commit comments

Comments
 (0)