Skip to content

Commit 680585f

Browse files
committed
POC: Manage sub array queries with a different transform path.
1 parent 121a15d commit 680585f

File tree

1 file changed

+51
-20
lines changed

1 file changed

+51
-20
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -70,31 +70,53 @@ def as_mql(self, compiler, connection):
7070
if isinstance(value, models.Model):
7171
value, emf_data = self.model_to_dict(value)
7272
# Get conditions for any nested EmbeddedModelFields.
73-
conditions = self.get_conditions({"$$item": (value, emf_data)})
73+
conditions = self.get_conditions({lhs_mql[1]: (value, emf_data)})
7474
return {
7575
"$anyElementTrue": {
76-
"$map": {"input": lhs_mql, "as": "item", "in": {"$and": conditions}}
76+
"$ifNull": [
77+
{
78+
"$map": {
79+
"input": lhs_mql[0],
80+
"as": "item",
81+
"in": {"$and": conditions},
82+
}
83+
},
84+
[],
85+
]
7786
}
7887
}
79-
lhs_mql = process_lhs(self.lhs, compiler, connection)
8088
return {
8189
"$anyElementTrue": {
82-
"$map": {
83-
"input": lhs_mql,
84-
"as": "item",
85-
"in": {"$eq": [f"$$item.{self.lhs.key_name}", value]},
86-
}
90+
"$ifNull": [
91+
{
92+
"$map": {
93+
"input": lhs_mql[0],
94+
"as": "item",
95+
"in": {"$eq": [lhs_mql[1], value]},
96+
}
97+
},
98+
[],
99+
]
87100
}
88101
}
89102
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
90103

91104

92105
class KeyTransform(Transform):
93106
# it should be different class than EMF keytransform even most of the methods are equal.
94-
def __init__(self, key_name, ref_field, *args, **kwargs):
107+
def __init__(self, key_name, base_field, *args, **kwargs):
95108
super().__init__(*args, **kwargs)
96-
self.key_name = str(key_name)
97-
self.ref_field = ref_field
109+
self.base_field = base_field
110+
# TODO: Need to create a column, will refactor this thing.
111+
column_target = base_field.clone()
112+
column_target.db_column = f"$item.{key_name}"
113+
column_target.set_attributes_from_name(f"$item.{key_name}")
114+
self._lhs = Col(None, column_target)
115+
self._sub_transform = None
116+
117+
def __call__(self, this, *args, **kwargs):
118+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
119+
return self
98120

99121
def get_lookup(self, name):
100122
return self.output_field.get_lookup(name)
@@ -104,33 +126,42 @@ def get_transform(self, name):
104126
Validate that `name` is either a field of an embedded model or a
105127
lookup on an embedded model's field.
106128
"""
107-
if transform := self.ref_field.get_transform(name):
108-
return transform
109-
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
129+
if isinstance(self._lhs, Transform):
130+
transform = self._lhs.get_transform(name)
131+
else:
132+
transform = self.base_field.get_transform(name)
133+
if transform:
134+
self._sub_transform = transform
135+
return self
136+
suggested_lookups = difflib.get_close_matches(name, self.base_field.get_lookups())
110137
if suggested_lookups:
111138
suggested_lookups = " or ".join(suggested_lookups)
112139
suggestion = f", perhaps you meant {suggested_lookups}?"
113140
else:
114141
suggestion = "."
115142
raise FieldDoesNotExist(
116143
f"Unsupported lookup '{name}' for "
117-
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
144+
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
118145
f"{suggestion}"
119146
)
120147

121148
def as_mql(self, compiler, connection):
149+
if isinstance(self._lhs, Transform):
150+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
151+
else:
152+
inner_lhs_mql = None
122153
lhs_mql = process_lhs(self, compiler, connection)
123-
return f"{lhs_mql}.{self.key_name}"
154+
return lhs_mql, inner_lhs_mql
124155

125156
@property
126157
def output_field(self):
127-
return EmbeddedModelArrayField(self.ref_field)
158+
return EmbeddedModelArrayField(self.base_field)
128159

129160

130161
class KeyTransformFactory:
131-
def __init__(self, key_name, ref_field):
162+
def __init__(self, key_name, base_field):
132163
self.key_name = key_name
133-
self.ref_field = ref_field
164+
self.base_field = base_field
134165

135166
def __call__(self, *args, **kwargs):
136-
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
167+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

0 commit comments

Comments
 (0)