1- from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
1+ from django .core .exceptions import EmptyResultSet , FullResultSet
2+ from django .db import NotSupportedError
3+ from django .db .models .aggregates import (
4+ Aggregate ,
5+ Count ,
6+ StdDev ,
7+ StringAgg ,
8+ Variance ,
9+ )
210from django .db .models .expressions import Case , Value , When
311from django .db .models .lookups import IsNull
412from django .db .models .sql .where import WhereNode
1119
1220def aggregate (self , compiler , connection , operator = None , resolve_inner_expression = False ):
1321 agg_expression , * _ = self .get_source_expressions ()
14- if self .filter :
15- agg_expression = Case (
16- When (self .filter , then = agg_expression ),
17- # Skip rows that don't meet the criteria.
18- default = Remove (),
19- )
20- lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
22+ if self .filter is not None :
23+ try :
24+ lhs_mql = self .filter .as_mql (compiler , connection , as_expr = True )
25+ except NotSupportedError :
26+ # Generate a CASE statement for this aggregate.
27+ agg_expression = Case (
28+ When (self .filter .condition , then = agg_expression ),
29+ # Skip rows that don't meet the criteria.
30+ default = Remove (),
31+ )
32+ lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
33+ except FullResultSet :
34+ lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
35+ except EmptyResultSet :
36+ lhs_mql = Value (None ).as_mql (compiler , connection , as_expr = True )
37+ else :
38+ lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
2139 if resolve_inner_expression :
2240 return lhs_mql
2341 operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
@@ -34,16 +52,29 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3452 if not self .distinct or resolve_inner_expression :
3553 conditions = [IsNull (agg_expression , False )]
3654 if self .filter :
37- conditions .append (self .filter )
38- inner_expression = Case (
39- When (WhereNode (conditions ), then = agg_expression if self .distinct else Value (1 )),
40- # Skip rows that don't meet the criteria.
41- default = Remove (),
42- )
43- inner_expression = inner_expression .as_mql (compiler , connection , as_expr = True )
55+ try :
56+ inner_expression = self .filter .as_mql (compiler , connection , as_expr = True )
57+ except NotSupportedError :
58+ conditions .append (self .filter .condition )
59+ condition = When (
60+ WhereNode (conditions ),
61+ then = agg_expression if self .distinct else Value (1 ),
62+ )
63+ inner_expression = Case (condition )
64+ except FullResultSet :
65+ inner_expression = agg_expression if self .distinct else Value (1 )
66+ except EmptyResultSet :
67+ inner_expression = Remove () if self .distinct else Value (0 )
68+ else :
69+ inner_expression = Case (
70+ When (WhereNode (conditions ), then = agg_expression if self .distinct else Value (1 )),
71+ # Skip rows that don't meet the criteria.
72+ default = Remove (),
73+ )
74+ lhs_mql = inner_expression .as_mql (compiler , connection , as_expr = True )
4475 if resolve_inner_expression :
45- return inner_expression
46- return {"$sum" : inner_expression }
76+ return lhs_mql
77+ return {"$sum" : lhs_mql }
4778 # If distinct=True or resolve_inner_expression=False, sum the size of the
4879 # set.
4980 return {"$size" : agg_expression .as_mql (compiler , connection , as_expr = True )}
@@ -57,8 +88,13 @@ def stddev_variance(self, compiler, connection):
5788 return aggregate (self , compiler , connection , operator = operator )
5889
5990
91+ def string_agg (self , compiler , connection ): # noqa: ARG001
92+ raise NotSupportedError ("StringAgg is not supported." )
93+
94+
6095def register_aggregates ():
6196 Aggregate .as_mql_expr = aggregate
6297 Count .as_mql_expr = count
6398 StdDev .as_mql_expr = stddev_variance
99+ StringAgg .as_mql_expr = string_agg
64100 Variance .as_mql_expr = stddev_variance
0 commit comments