diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index fb41ce4fc..1262f14b4 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -1,23 +1,23 @@ from django.db.models.aggregates import Aggregate, Count, StdDev, Variance from django.db.models.expressions import Case, Value, When from django.db.models.lookups import IsNull +from django.db.models.sql.where import WhereNode -from .query_utils import process_lhs +from django_mongodb_backend.expressions import Remove # Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower(). MONGO_AGGREGATIONS = {Count: "sum"} def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False): + agg_expression, *_ = self.get_source_expressions() if self.filter: - node = self.copy() - node.filter = None - source_expressions = node.get_source_expressions() - condition = When(self.filter, then=source_expressions[0]) - node.set_source_expressions([Case(condition), *source_expressions[1:]]) - else: - node = self - lhs_mql = process_lhs(node, compiler, connection, as_expr=True) + agg_expression = Case( + When(self.filter, then=agg_expression), + # Skip rows that don't meet the criteria. + default=Remove(), + ) + lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True) if resolve_inner_expression: return lhs_mql operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) @@ -30,31 +30,23 @@ def count(self, compiler, connection, resolve_inner_expression=False): value. This is used to count different elements, so the inner values are returned to be pushed into a set. """ + agg_expression, *_ = self.get_source_expressions() if not self.distinct or resolve_inner_expression: + conditions = [IsNull(agg_expression, False)] if self.filter: - node = self.copy() - node.filter = None - source_expressions = node.get_source_expressions() - condition = When( - self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1))) - ) - node.set_source_expressions([Case(condition), *source_expressions[1:]]) - inner_expression = process_lhs(node, compiler, connection, as_expr=True) - else: - lhs_mql = process_lhs(self, compiler, connection, as_expr=True) - null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} - inner_expression = { - "$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1} - } + conditions.append(self.filter) + inner_expression = Case( + When(WhereNode(conditions), then=agg_expression if self.distinct else Value(1)), + # Skip rows that don't meet the criteria. + default=Remove(), + ) + inner_expression = inner_expression.as_mql(compiler, connection, as_expr=True) if resolve_inner_expression: return inner_expression return {"$sum": inner_expression} # If distinct=True or resolve_inner_expression=False, sum the size of the # set. - lhs_mql = process_lhs(self, compiler, connection, as_expr=True) - # None shouldn't be counted, so subtract 1 if it's present. - exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} - return {"$add": [{"$size": lhs_mql}, exits_null]} + return {"$size": agg_expression.as_mql(compiler, connection, as_expr=True)} def stddev_variance(self, compiler, connection): diff --git a/django_mongodb_backend/expressions/__init__.py b/django_mongodb_backend/expressions/__init__.py index 46fd0b018..cd92a7fe1 100644 --- a/django_mongodb_backend/expressions/__init__.py +++ b/django_mongodb_backend/expressions/__init__.py @@ -1,3 +1,4 @@ +from .expressions import Remove from .search import ( CombinedSearchExpression, CompoundExpression, @@ -21,6 +22,7 @@ __all__ = [ "CombinedSearchExpression", "CompoundExpression", + "Remove", "SearchAutocomplete", "SearchEquals", "SearchExists", diff --git a/django_mongodb_backend/expressions/expressions.py b/django_mongodb_backend/expressions/expressions.py new file mode 100644 index 000000000..4c86c2072 --- /dev/null +++ b/django_mongodb_backend/expressions/expressions.py @@ -0,0 +1,6 @@ +from django.db.models.expressions import Func + + +class Remove(Func): + def as_mql(self, compiler, connection, as_expr=False): + return "$$REMOVE" diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index ea892ec9f..ccb55d49b 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,6 +1,5 @@ from django.core.exceptions import FullResultSet from django.db.models import F -from django.db.models.aggregates import Aggregate from django.db.models.expressions import CombinedExpression, Func, Value from django.db.models.sql.query import Query @@ -20,8 +19,6 @@ def process_lhs(node, compiler, connection, as_expr=False): result.append(expr.as_mql(compiler, connection, as_expr=as_expr)) except FullResultSet: result.append(Value(True).as_mql(compiler, connection, as_expr=as_expr)) - if isinstance(node, Aggregate): - return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs):