Skip to content

Commit e2fda19

Browse files
committed
Refactor subquery wrapping pipeline
1 parent 4ea54df commit e2fda19

File tree

4 files changed

+85
-77
lines changed

4 files changed

+85
-77
lines changed

django_mongodb_backend/expressions.py

Lines changed: 9 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def order_by(self, compiler, connection):
9595
return self.expression.as_mql(compiler, connection)
9696

9797

98-
def query(self, compiler, connection, lookup_name=None):
98+
def query(self, compiler, connection, get_wrapping_pipeline=None):
9999
subquery_compiler = self.get_compiler(connection=connection)
100100
subquery_compiler.pre_sql_setup(with_col_aliases=False)
101101
field_name, expr = subquery_compiler.columns[0]
@@ -119,76 +119,10 @@ def query(self, compiler, connection, lookup_name=None):
119119
for col, i in subquery_compiler.column_indices.items()
120120
},
121121
}
122-
wrapping_result_pipeline = None
123-
# The result must be a list of values. The output is compressed with an
124-
# aggregation pipeline.
125-
if lookup_name in ("in", "range"):
126-
wrapping_result_pipeline = [
127-
{
128-
"$facet": {
129-
"group": [
130-
{
131-
"$group": {
132-
"_id": None,
133-
"tmp_name": {
134-
"$addToSet": expr.as_mql(subquery_compiler, connection)
135-
},
136-
}
137-
}
138-
]
139-
}
140-
},
141-
{
142-
"$project": {
143-
field_name: {
144-
"$ifNull": [
145-
{
146-
"$getField": {
147-
"input": {"$arrayElemAt": ["$group", 0]},
148-
"field": "tmp_name",
149-
}
150-
},
151-
[],
152-
]
153-
}
154-
}
155-
},
156-
]
157-
if lookup_name == "overlap":
158-
wrapping_result_pipeline = [
159-
{
160-
"$facet": {
161-
"group": [
162-
{"$project": {"tmp_name": expr.as_mql(subquery_compiler, connection)}},
163-
{
164-
"$unwind": "$tmp_name",
165-
},
166-
{
167-
"$group": {
168-
"_id": None,
169-
"tmp_name": {"$addToSet": "$tmp_name"},
170-
}
171-
},
172-
]
173-
}
174-
},
175-
{
176-
"$project": {
177-
field_name: {
178-
"$ifNull": [
179-
{
180-
"$getField": {
181-
"input": {"$arrayElemAt": ["$group", 0]},
182-
"field": "tmp_name",
183-
}
184-
},
185-
[],
186-
]
187-
}
188-
}
189-
},
190-
]
191-
if wrapping_result_pipeline:
122+
if get_wrapping_pipeline:
123+
wrapping_result_pipeline = get_wrapping_pipeline(
124+
subquery_compiler, connection, field_name, expr
125+
)
192126
# If the subquery is a combinator, wrap the result at the end of the
193127
# combinator pipeline...
194128
if subquery.query.combinator:
@@ -221,13 +155,13 @@ def star(self, compiler, connection): # noqa: ARG001
221155
return {"$literal": True}
222156

223157

224-
def subquery(self, compiler, connection, lookup_name=None):
225-
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)
158+
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
159+
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
226160

227161

228-
def exists(self, compiler, connection, lookup_name=None):
162+
def exists(self, compiler, connection, get_wrapping_pipeline=None):
229163
try:
230-
lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
164+
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
231165
except EmptyResultSet:
232166
return Value(False).as_mql(compiler, connection)
233167
return connection.mongo_operators["isnull"](lhs_mql, False)

django_mongodb_backend/fields/array.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,43 @@ class ArrayExact(ArrayRHSMixin, Exact):
266266
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
267267
lookup_name = "overlap"
268268

269+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr):
270+
# The result must be a list of values.
271+
# The output is compressed with an aggregation pipeline.
272+
return [
273+
{
274+
"$facet": {
275+
"group": [
276+
{"$project": {"tmp_name": expr.as_mql(compiler, connection)}},
277+
{
278+
"$unwind": "$tmp_name",
279+
},
280+
{
281+
"$group": {
282+
"_id": None,
283+
"tmp_name": {"$addToSet": "$tmp_name"},
284+
}
285+
},
286+
]
287+
}
288+
},
289+
{
290+
"$project": {
291+
field_name: {
292+
"$ifNull": [
293+
{
294+
"$getField": {
295+
"input": {"$arrayElemAt": ["$group", 0]},
296+
"field": "tmp_name",
297+
}
298+
},
299+
[],
300+
]
301+
}
302+
}
303+
},
304+
]
305+
269306
def as_mql(self, compiler, connection):
270307
lhs_mql = process_lhs(self, compiler, connection)
271308
value = process_rhs(self, compiler, connection)

django_mongodb_backend/lookups.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,40 @@ def in_(self, compiler, connection):
4545
return builtin_lookup(self, compiler, connection)
4646

4747

48+
def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001
49+
# The result must be a list of values.
50+
# The output is compressed with an aggregation pipeline.
51+
return [
52+
{
53+
"$facet": {
54+
"group": [
55+
{
56+
"$group": {
57+
"_id": None,
58+
"tmp_name": {"$addToSet": expr.as_mql(compiler, connection)},
59+
}
60+
}
61+
]
62+
}
63+
},
64+
{
65+
"$project": {
66+
field_name: {
67+
"$ifNull": [
68+
{
69+
"$getField": {
70+
"input": {"$arrayElemAt": ["$group", 0]},
71+
"field": "tmp_name",
72+
}
73+
},
74+
[],
75+
]
76+
}
77+
}
78+
},
79+
]
80+
81+
4882
def is_null(self, compiler, connection):
4983
if not isinstance(self.rhs, bool):
5084
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
@@ -97,6 +131,7 @@ def register_lookups():
97131
field_resolve_expression_parameter
98132
)
99133
In.as_mql = RelatedIn.as_mql = in_
134+
In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline
100135
IsNull.as_mql = is_null
101136
PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value
102137
UUIDTextMixin.as_mql = uuid_text_mixin

django_mongodb_backend/query_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ def process_lhs(node, compiler, connection):
2828
def process_rhs(node, compiler, connection):
2929
rhs = node.rhs
3030
if hasattr(rhs, "as_mql"):
31-
if getattr(rhs, "subquery", False):
32-
value = rhs.as_mql(compiler, connection, lookup_name=node.lookup_name)
31+
if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"):
32+
value = rhs.as_mql(
33+
compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline
34+
)
3335
else:
3436
value = rhs.as_mql(compiler, connection)
3537
else:

0 commit comments

Comments
 (0)