Skip to content

Commit e548777

Browse files
committed
Fix unit test
1 parent 680585f commit e548777

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,11 @@ def as_mql(self, compiler, connection):
186186
key_transforms.insert(0, previous.key_name)
187187
previous = previous.lhs
188188
mql = previous.as_mql(compiler, connection)
189+
# transform = ".".join(key_transforms)
189190
for key in key_transforms:
190191
mql = {"$getField": {"input": mql, "field": key}}
191192
return mql
193+
# return f"{mql}.{transform}"
192194

193195
@property
194196
def output_field(self):

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,20 @@ def as_mql(self, compiler, connection):
6767
lhs_mql = process_lhs(self, compiler, connection)
6868
value = process_rhs(self, compiler, connection)
6969
if isinstance(self.lhs, Col | KeyTransform):
70+
if isinstance(self.lhs, Col):
71+
inner_lhs_mql = "$$item"
72+
else:
73+
lhs_mql, inner_lhs_mql = lhs_mql
7074
if isinstance(value, models.Model):
7175
value, emf_data = self.model_to_dict(value)
7276
# Get conditions for any nested EmbeddedModelFields.
73-
conditions = self.get_conditions({lhs_mql[1]: (value, emf_data)})
77+
conditions = self.get_conditions({inner_lhs_mql: (value, emf_data)})
7478
return {
7579
"$anyElementTrue": {
7680
"$ifNull": [
7781
{
7882
"$map": {
79-
"input": lhs_mql[0],
83+
"input": lhs_mql,
8084
"as": "item",
8185
"in": {"$and": conditions},
8286
}
@@ -90,9 +94,9 @@ def as_mql(self, compiler, connection):
9094
"$ifNull": [
9195
{
9296
"$map": {
93-
"input": lhs_mql[0],
97+
"input": lhs_mql,
9498
"as": "item",
95-
"in": {"$eq": [lhs_mql[1], value]},
99+
"in": {"$eq": [inner_lhs_mql, value]},
96100
}
97101
},
98102
[],
@@ -146,10 +150,7 @@ def get_transform(self, name):
146150
)
147151

148152
def as_mql(self, compiler, connection):
149-
if isinstance(self._lhs, Transform):
150-
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
151-
else:
152-
inner_lhs_mql = None
153+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
153154
lhs_mql = process_lhs(self, compiler, connection)
154155
return lhs_mql, inner_lhs_mql
155156

0 commit comments

Comments
 (0)