1
- from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
2
- from django .db .models .expressions import Case , Value , When
1
+ from django .db import NotSupportedError
2
+ from django .db .models .aggregates import (
3
+ Aggregate ,
4
+ AggregateFilter ,
5
+ Count ,
6
+ StdDev ,
7
+ StringAgg ,
8
+ Variance ,
9
+ )
10
+ from django .db .models .expressions import Case , Col , Value , When
3
11
from django .db .models .lookups import IsNull
4
12
5
13
from .query_utils import process_lhs
@@ -16,7 +24,11 @@ def aggregate(
16
24
resolve_inner_expression = False ,
17
25
** extra_context , # noqa: ARG001
18
26
):
19
- if self .filter :
27
+ # TODO: isinstance(self.filter, Col) works around failure of
28
+ # aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
29
+ # correct?
30
+ if self .filter is not None and not isinstance (self .filter , Col ):
31
+ # Generate a CASE statement for this aggregate.
20
32
node = self .copy ()
21
33
node .filter = None
22
34
source_expressions = node .get_source_expressions ()
@@ -31,6 +43,10 @@ def aggregate(
31
43
return {f"${ operator } " : lhs_mql }
32
44
33
45
46
+ def aggregate_filter (self , compiler , connection , ** extra_context ):
47
+ return self .condition .as_mql (compiler , connection , ** extra_context )
48
+
49
+
34
50
def count (self , compiler , connection , resolve_inner_expression = False , ** extra_context ): # noqa: ARG001
35
51
"""
36
52
When resolve_inner_expression=True, return the MQL that resolves as a
@@ -72,8 +88,14 @@ def stddev_variance(self, compiler, connection, **extra_context):
72
88
return aggregate (self , compiler , connection , operator = operator , ** extra_context )
73
89
74
90
91
+ def string_agg (self , compiler , connection , ** extra_context ): # noqa: ARG001
92
+ raise NotSupportedError ("StringAgg is not supported." )
93
+
94
+
75
95
def register_aggregates ():
76
96
Aggregate .as_mql = aggregate
97
+ AggregateFilter .as_mql = aggregate_filter
77
98
Count .as_mql = count
78
99
StdDev .as_mql = stddev_variance
100
+ StringAgg .as_mql = string_agg
79
101
Variance .as_mql = stddev_variance
0 commit comments