Skip to content

Commit 51f0374

Browse files
timgrahamWaVEV
authored andcommitted
Add AggregateFilter, StringAgg.as_mql()
django/django@4b977a5
1 parent a2d088c commit 51f0374

File tree

1 file changed

+53
-17
lines changed

1 file changed

+53
-17
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
1+
from django.core.exceptions import EmptyResultSet, FullResultSet
2+
from django.db import NotSupportedError
3+
from django.db.models.aggregates import (
4+
Aggregate,
5+
Count,
6+
StdDev,
7+
StringAgg,
8+
Variance,
9+
)
210
from django.db.models.expressions import Case, Value, When
311
from django.db.models.lookups import IsNull
412
from django.db.models.sql.where import WhereNode
@@ -11,13 +19,23 @@
1119

1220
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
1321
agg_expression, *_ = self.get_source_expressions()
14-
if self.filter:
15-
agg_expression = Case(
16-
When(self.filter, then=agg_expression),
17-
# Skip rows that don't meet the criteria.
18-
default=Remove(),
19-
)
20-
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
22+
if self.filter is not None:
23+
try:
24+
lhs_mql = self.filter.as_mql(compiler, connection, as_expr=True)
25+
except NotSupportedError:
26+
# Generate a CASE statement for this aggregate.
27+
agg_expression = Case(
28+
When(self.filter.condition, then=agg_expression),
29+
# Skip rows that don't meet the criteria.
30+
default=Remove(),
31+
)
32+
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
33+
except FullResultSet:
34+
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
35+
except EmptyResultSet:
36+
lhs_mql = Value(None).as_mql(compiler, connection, as_expr=True)
37+
else:
38+
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
2139
if resolve_inner_expression:
2240
return lhs_mql
2341
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -34,16 +52,29 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3452
if not self.distinct or resolve_inner_expression:
3553
conditions = [IsNull(agg_expression, False)]
3654
if self.filter:
37-
conditions.append(self.filter)
38-
inner_expression = Case(
39-
When(WhereNode(conditions), then=agg_expression if self.distinct else Value(1)),
40-
# Skip rows that don't meet the criteria.
41-
default=Remove(),
42-
)
43-
inner_expression = inner_expression.as_mql(compiler, connection, as_expr=True)
55+
try:
56+
inner_expression = self.filter.as_mql(compiler, connection, as_expr=True)
57+
except NotSupportedError:
58+
conditions.append(self.filter.condition)
59+
condition = When(
60+
WhereNode(conditions),
61+
then=agg_expression if self.distinct else Value(1),
62+
)
63+
inner_expression = Case(condition)
64+
except FullResultSet:
65+
inner_expression = agg_expression if self.distinct else Value(1)
66+
except EmptyResultSet:
67+
inner_expression = Remove() if self.distinct else Value(0)
68+
else:
69+
inner_expression = Case(
70+
When(WhereNode(conditions), then=agg_expression if self.distinct else Value(1)),
71+
# Skip rows that don't meet the criteria.
72+
default=Remove(),
73+
)
74+
lhs_mql = inner_expression.as_mql(compiler, connection, as_expr=True)
4475
if resolve_inner_expression:
45-
return inner_expression
46-
return {"$sum": inner_expression}
76+
return lhs_mql
77+
return {"$sum": lhs_mql}
4778
# If distinct=True or resolve_inner_expression=False, sum the size of the
4879
# set.
4980
return {"$size": agg_expression.as_mql(compiler, connection, as_expr=True)}
@@ -57,8 +88,13 @@ def stddev_variance(self, compiler, connection):
5788
return aggregate(self, compiler, connection, operator=operator)
5889

5990

91+
def string_agg(self, compiler, connection): # noqa: ARG001
92+
raise NotSupportedError("StringAgg is not supported.")
93+
94+
6095
def register_aggregates():
6196
Aggregate.as_mql_expr = aggregate
6297
Count.as_mql_expr = count
6398
StdDev.as_mql_expr = stddev_variance
99+
StringAgg.as_mql_expr = string_agg
64100
Variance.as_mql_expr = stddev_variance

0 commit comments

Comments
 (0)