Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions django_mongodb_backend/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
from django.db.models.expressions import Case, Value, When
from django.db.models.lookups import IsNull
from django.db.models.sql.where import WhereNode

from .query_utils import process_lhs
from django_mongodb_backend.expressions import Remove

# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
MONGO_AGGREGATIONS = {Count: "sum"}


def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
agg_expression, *_ = self.get_source_expressions()
if self.filter:
node = self.copy()
node.filter = None
source_expressions = node.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
node.set_source_expressions([Case(condition), *source_expressions[1:]])
else:
node = self
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
agg_expression = Case(
When(self.filter, then=agg_expression),
# Skip rows that don't meet the criteria.
default=Remove(),
)
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
if resolve_inner_expression:
return lhs_mql
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
Expand All @@ -30,31 +30,23 @@ def count(self, compiler, connection, resolve_inner_expression=False):
value. This is used to count different elements, so the inner values are
returned to be pushed into a set.
"""
agg_expression, *_ = self.get_source_expressions()
if not self.distinct or resolve_inner_expression:
conditions = [IsNull(agg_expression, False)]
if self.filter:
node = self.copy()
node.filter = None
source_expressions = node.get_source_expressions()
condition = When(
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
)
node.set_source_expressions([Case(condition), *source_expressions[1:]])
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
else:
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
inner_expression = {
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
}
conditions.append(self.filter)
inner_expression = Case(
When(WhereNode(conditions), then=agg_expression if self.distinct else Value(1)),
# Skip rows that don't meet the criteria.
default=Remove(),
)
inner_expression = inner_expression.as_mql(compiler, connection, as_expr=True)
if resolve_inner_expression:
return inner_expression
return {"$sum": inner_expression}
# If distinct=True or resolve_inner_expression=False, sum the size of the
# set.
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
# None shouldn't be counted, so subtract 1 if it's present.
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
return {"$add": [{"$size": lhs_mql}, exits_null]}
return {"$size": agg_expression.as_mql(compiler, connection, as_expr=True)}


def stddev_variance(self, compiler, connection):
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb_backend/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .expressions import Remove
from .search import (
CombinedSearchExpression,
CompoundExpression,
Expand All @@ -21,6 +22,7 @@
__all__ = [
"CombinedSearchExpression",
"CompoundExpression",
"Remove",
"SearchAutocomplete",
"SearchEquals",
"SearchExists",
Expand Down
6 changes: 6 additions & 0 deletions django_mongodb_backend/expressions/expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.db.models.expressions import Func


class Remove(Func):
def as_mql(self, compiler, connection, as_expr=False):
return "$$REMOVE"
3 changes: 0 additions & 3 deletions django_mongodb_backend/query_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from django.core.exceptions import FullResultSet
from django.db.models import F
from django.db.models.aggregates import Aggregate
from django.db.models.expressions import CombinedExpression, Func, Value
from django.db.models.sql.query import Query

Expand All @@ -20,8 +19,6 @@ def process_lhs(node, compiler, connection, as_expr=False):
result.append(expr.as_mql(compiler, connection, as_expr=as_expr))
except FullResultSet:
result.append(Value(True).as_mql(compiler, connection, as_expr=as_expr))
if isinstance(node, Aggregate):
return result[0]
return result
# node is a Transform with just one source expression, aliased as "lhs".
if is_direct_value(node.lhs):
Expand Down
Loading