Skip to content

Commit 3c5936c

Browse files
authored
Merge branch 'main' into mdb-8-auth-ssl-tests
2 parents 1ec9ce9 + 37af20e commit 3c5936c

File tree

4 files changed

+125
-14
lines changed

4 files changed

+125
-14
lines changed

django_mongodb_backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_column_from_expression(self, expr, alias):
5353
Create a column named `alias` from the given expression to hold the
5454
aggregate value.
5555
"""
56-
column_target = expr.output_field.__class__()
56+
column_target = expr.output_field.clone()
5757
column_target.db_column = alias
5858
column_target.set_attributes_from_name(alias)
5959
return Col(self.collection_name, column_target)
@@ -81,7 +81,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
8181
alias = (
8282
f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target
8383
)
84-
column_target = sub_expr.output_field.__class__()
84+
column_target = sub_expr.output_field.clone()
8585
column_target.db_column = alias
8686
column_target.set_attributes_from_name(alias)
8787
inner_column = Col(self.collection_name, column_target)

django_mongodb_backend/fields/embedded_model.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import difflib
2+
13
from django.core import checks
4+
from django.core.exceptions import FieldDoesNotExist
25
from django.db import models
36
from django.db.models.fields.related import lazy_related_operation
47
from django.db.models.lookups import Transform
@@ -123,7 +126,8 @@ def get_transform(self, name):
123126
transform = super().get_transform(name)
124127
if transform:
125128
return transform
126-
return KeyTransformFactory(name)
129+
field = self.embedded_model._meta.get_field(name)
130+
return KeyTransformFactory(name, field)
127131

128132
def validate(self, value, model_instance):
129133
super().validate(value, model_instance)
@@ -145,9 +149,36 @@ def formfield(self, **kwargs):
145149

146150

147151
class KeyTransform(Transform):
148-
def __init__(self, key_name, *args, **kwargs):
152+
def __init__(self, key_name, ref_field, *args, **kwargs):
149153
super().__init__(*args, **kwargs)
150154
self.key_name = str(key_name)
155+
self.ref_field = ref_field
156+
157+
def get_transform(self, name):
158+
"""
159+
Validate that `name` is either a field of an embedded model or a
160+
lookup on an embedded model's field.
161+
"""
162+
result = None
163+
if isinstance(self.ref_field, EmbeddedModelField):
164+
opts = self.ref_field.embedded_model._meta
165+
new_field = opts.get_field(name)
166+
result = KeyTransformFactory(name, new_field)
167+
else:
168+
if self.ref_field.get_transform(name) is None:
169+
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
170+
if suggested_lookups:
171+
suggested_lookups = " or ".join(suggested_lookups)
172+
suggestion = f", perhaps you meant {suggested_lookups}?"
173+
else:
174+
suggestion = "."
175+
raise FieldDoesNotExist(
176+
f"Unsupported lookup '{name}' for "
177+
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
178+
f"{suggestion}"
179+
)
180+
result = KeyTransformFactory(name, self.ref_field)
181+
return result
151182

152183
def preprocess_lhs(self, compiler, connection):
153184
key_transforms = [self.key_name]
@@ -165,8 +196,9 @@ def as_mql(self, compiler, connection):
165196

166197

167198
class KeyTransformFactory:
168-
def __init__(self, key_name):
199+
def __init__(self, key_name, ref_field):
169200
self.key_name = key_name
201+
self.ref_field = ref_field
170202

171203
def __call__(self, *args, **kwargs):
172-
return KeyTransform(self.key_name, *args, **kwargs)
204+
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)

tests/model_fields_/test_embedded_model.py

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
from django.core.exceptions import ValidationError
1+
import operator
2+
3+
from django.core.exceptions import FieldDoesNotExist, ValidationError
24
from django.db import models
5+
from django.db.models import ExpressionWrapper, F, Max, Sum
36
from django.test import SimpleTestCase, TestCase
47
from django.test.utils import isolate_apps
58

@@ -13,6 +16,7 @@
1316
Data,
1417
Holder,
1518
)
19+
from .utils import truncate_ms
1620

1721

1822
class MethodTests(SimpleTestCase):
@@ -38,10 +42,6 @@ def test_validate(self):
3842

3943

4044
class ModelTests(TestCase):
41-
def truncate_ms(self, value):
42-
"""Truncate microseconds to milliseconds as supported by MongoDB."""
43-
return value.replace(microsecond=(value.microsecond // 1000) * 1000)
44-
4545
def test_save_load(self):
4646
Holder.objects.create(data=Data(integer="5"))
4747
obj = Holder.objects.get()
@@ -64,12 +64,12 @@ def test_save_load_null(self):
6464
def test_pre_save(self):
6565
"""Field.pre_save() is called on embedded model fields."""
6666
obj = Holder.objects.create(data=Data())
67-
auto_now = self.truncate_ms(obj.data.auto_now)
68-
auto_now_add = self.truncate_ms(obj.data.auto_now_add)
67+
auto_now = truncate_ms(obj.data.auto_now)
68+
auto_now_add = truncate_ms(obj.data.auto_now_add)
6969
self.assertEqual(auto_now, auto_now_add)
7070
# save() updates auto_now but not auto_now_add.
7171
obj.save()
72-
self.assertEqual(self.truncate_ms(obj.data.auto_now_add), auto_now_add)
72+
self.assertEqual(truncate_ms(obj.data.auto_now_add), auto_now_add)
7373
auto_now_two = obj.data.auto_now
7474
self.assertGreater(auto_now_two, obj.data.auto_now_add)
7575
# And again, save() updates auto_now but not auto_now_add.
@@ -99,13 +99,89 @@ def test_gt(self):
9999
def test_gte(self):
100100
self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:])
101101

102+
def test_order_by_embedded_field(self):
103+
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
104+
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
105+
106+
def test_order_and_group_by_embedded_field(self):
107+
# Create and sort test data by `data__integer`.
108+
expected_objs = sorted(
109+
(Holder.objects.create(data=Data(integer=x)) for x in range(6)),
110+
key=lambda x: x.data.integer,
111+
)
112+
# Group by `data__integer + 5` and get the latest `data__auto_now`
113+
# datetime.
114+
qs = (
115+
Holder.objects.annotate(
116+
group=ExpressionWrapper(F("data__integer") + 5, output_field=models.IntegerField()),
117+
)
118+
.values("group")
119+
.annotate(max_auto_now=Max("data__auto_now"))
120+
.order_by("data__integer")
121+
)
122+
# Each unique `data__integer` is correctly grouped and annotated.
123+
self.assertSequenceEqual(
124+
[{**e, "max_auto_now": e["max_auto_now"]} for e in qs],
125+
[
126+
{"group": e.data.integer + 5, "max_auto_now": truncate_ms(e.data.auto_now)}
127+
for e in expected_objs
128+
],
129+
)
130+
131+
def test_order_and_group_by_embedded_field_annotation(self):
132+
# Create repeated `data__integer` values.
133+
[Holder.objects.create(data=Data(integer=x)) for x in range(6)]
134+
# Group by `data__integer` and compute the sum of occurrences.
135+
qs = (
136+
Holder.objects.values("data__integer")
137+
.annotate(sum=Sum("data__integer"))
138+
.order_by("sum")
139+
)
140+
# The sum is twice the integer values since each appears twice.
141+
self.assertQuerySetEqual(qs, [0, 2, 4, 6, 8, 10], operator.itemgetter("sum"))
142+
102143
def test_nested(self):
103144
obj = Book.objects.create(
104145
author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
105146
)
106147
self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj])
107148

108149

150+
class InvalidLookupTests(SimpleTestCase):
151+
def test_invalid_field(self):
152+
msg = "Author has no field named 'first_name'"
153+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
154+
Book.objects.filter(author__first_name="Bob")
155+
156+
def test_invalid_field_nested(self):
157+
msg = "Address has no field named 'floor'"
158+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
159+
Book.objects.filter(author__address__floor="NYC")
160+
161+
def test_invalid_lookup(self):
162+
msg = "Unsupported lookup 'foo' for CharField 'city'."
163+
with self.assertRaisesMessage(FieldDoesNotExist, msg):
164+
Book.objects.filter(author__address__city__foo="NYC")
165+
166+
def test_invalid_lookup_with_suggestions(self):
167+
msg = (
168+
"Unsupported lookup '{lookup}' for CharField 'name', "
169+
"perhaps you meant {suggested_lookups}?"
170+
)
171+
with self.assertRaisesMessage(
172+
FieldDoesNotExist, msg.format(lookup="exactly", suggested_lookups="exact or iexact")
173+
):
174+
Book.objects.filter(author__name__exactly="NYC")
175+
with self.assertRaisesMessage(
176+
FieldDoesNotExist, msg.format(lookup="gti", suggested_lookups="gt or gte")
177+
):
178+
Book.objects.filter(author__name__gti="NYC")
179+
with self.assertRaisesMessage(
180+
FieldDoesNotExist, msg.format(lookup="is_null", suggested_lookups="isnull")
181+
):
182+
Book.objects.filter(author__name__is_null="NYC")
183+
184+
109185
@isolate_apps("model_fields_")
110186
class CheckTests(SimpleTestCase):
111187
def test_no_relational_fields(self):

tests/model_fields_/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def truncate_ms(value):
2+
"""Truncate microseconds to milliseconds as supported by MongoDB."""
3+
return value.replace(microsecond=(value.microsecond // 1000) * 1000)

0 commit comments

Comments
 (0)