Skip to content

Commit 7ec8343

Browse files
committed
Add AggregateFilter, StringgAgg.as_mql() as per
django/django@4b977a5
1 parent 3f0e6c4 commit 7ec8343

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
2-
from django.db.models.expressions import Case, Value, When
1+
from django.db import NotSupportedError
2+
from django.db.models.aggregates import (
3+
Aggregate,
4+
AggregateFilter,
5+
Count,
6+
StdDev,
7+
StringAgg,
8+
Variance,
9+
)
10+
from django.db.models.expressions import Case, Col, Value, When
311
from django.db.models.lookups import IsNull
412

513
from .query_utils import process_lhs
@@ -16,7 +24,11 @@ def aggregate(
1624
resolve_inner_expression=False,
1725
**extra_context, # noqa: ARG001
1826
):
19-
if self.filter:
27+
# TODO: isinstance(self.filter, Col) works around failure of
28+
# aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
29+
# correct?
30+
if self.filter is not None and not isinstance(self.filter, Col):
31+
# Generate a CASE statement for this aggregate.
2032
node = self.copy()
2133
node.filter = None
2234
source_expressions = node.get_source_expressions()
@@ -31,6 +43,10 @@ def aggregate(
3143
return {f"${operator}": lhs_mql}
3244

3345

46+
def aggregate_filter(self, compiler, connection, **extra_context):
47+
return self.condition.as_mql(compiler, connection, **extra_context)
48+
49+
3450
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
3551
"""
3652
When resolve_inner_expression=True, return the MQL that resolves as a
@@ -72,8 +88,14 @@ def stddev_variance(self, compiler, connection, **extra_context):
7288
return aggregate(self, compiler, connection, operator=operator, **extra_context)
7389

7490

91+
def string_agg(self, compiler, connection, **extra_context): # noqa: ARG001
92+
raise NotSupportedError("StringAgg is not supported.")
93+
94+
7595
def register_aggregates():
7696
Aggregate.as_mql = aggregate
97+
AggregateFilter.as_mql = aggregate_filter
7798
Count.as_mql = count
7899
StdDev.as_mql = stddev_variance
100+
StringAgg.as_mql = string_agg
79101
Variance.as_mql = stddev_variance

django_mongodb_backend/features.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
9999
# This backend overrides DatabaseCreation.create_test_db() so the
100100
# deprecation warnings stacklevel points to the wrong file.
101101
"backends.base.test_creation.TestDbCreationTests.test_serialize_deprecation",
102+
# StringAgg is not supported.
103+
"aggregation.tests.AggregateTestCase.test_distinct_on_stringagg",
104+
"aggregation.tests.AggregateTestCase.test_string_agg_escapes_delimiter",
105+
"aggregation.tests.AggregateTestCase.test_string_agg_filter",
106+
"aggregation.tests.AggregateTestCase.test_string_agg_filter_in_subquery",
107+
"aggregation.tests.AggregateTestCase.test_stringagg_default_value",
102108
}
103109
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
104110
_django_test_expected_failures_bitwise = {

0 commit comments

Comments
 (0)