Skip to content

Commit e37c022

Browse files
authored
Fix annotation in ORDER BY of DISTINCT query (#1886)
1 parent 29a4216 commit e37c022

File tree

4 files changed

+113
-16
lines changed

4 files changed

+113
-16
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Added
1818
Fixed
1919
^^^^^
2020
- Fix update pk field raises unfriendly error (#1873)
21+
- Using `.distinct()` with an annotation and `.order_by()` produces invalid SQL for PostgreSQL (#1886)
2122

2223
Changed
2324
^^^^^^^

tests/test_order_by.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tortoise.contrib import test
1010
from tortoise.contrib.test.condition import NotEQ
1111
from tortoise.exceptions import ConfigurationError, FieldError
12+
from tortoise.expressions import Q, Case, When
1213
from tortoise.functions import Count, Sum
1314

1415

@@ -94,6 +95,45 @@ async def test_order_by_aggregation_reversed(self):
9495
)
9596
self.assertEqual([t.name for t in tournaments], ["1", "2"])
9697

98+
async def test_distinct_values_with_annotation(self):
99+
await Tournament.create(name="3")
100+
await Tournament.create(name="1")
101+
await Tournament.create(name="2")
102+
103+
tournaments = (
104+
await Tournament.annotate(
105+
name_orderable=Case(
106+
When(Q(name="1"), then="1"),
107+
When(Q(name="2"), then="2"),
108+
When(Q(name="3"), then="3"),
109+
default="-1",
110+
),
111+
)
112+
.distinct()
113+
.order_by("name_orderable", "-created")
114+
.values("name", "name_orderable", "created")
115+
)
116+
self.assertEqual([t["name"] for t in tournaments], ["1", "2", "3"])
117+
118+
async def test_distinct_all_withl_annotation(self):
119+
await Tournament.create(name="3")
120+
await Tournament.create(name="1")
121+
await Tournament.create(name="2")
122+
123+
tournaments = (
124+
await Tournament.annotate(
125+
name_orderable=Case(
126+
When(Q(name="1"), then="1"),
127+
When(Q(name="2"), then="2"),
128+
When(Q(name="3"), then="3"),
129+
default="-1",
130+
),
131+
)
132+
.distinct()
133+
.order_by("name_orderable", "-created")
134+
)
135+
self.assertEqual([t.name for t in tournaments], ["1", "2", "3"])
136+
97137

98138
class TestDefaultOrdering(test.TestCase):
99139
@test.requireCapability(dialect=NotEQ("oracle"))

tests/test_values.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tortoise.contrib import test
55
from tortoise.contrib.test.condition import In, NotEQ
66
from tortoise.exceptions import FieldError
7-
from tortoise.expressions import Function
7+
from tortoise.expressions import Q, Case, Function, When
88
from tortoise.functions import Length, Trim
99

1010

@@ -214,3 +214,41 @@ class TruncMonth(Function):
214214
sql,
215215
'SELECT DATE_FORMAT("created",?) "date" FROM "tournament"',
216216
)
217+
218+
async def test_order_by_annotation_not_in_values(self):
219+
await Tournament.create(name="2")
220+
await Tournament.create(name="3")
221+
await Tournament.create(name="1")
222+
223+
tournaments = (
224+
await Tournament.annotate(
225+
name_orderable=Case(
226+
When(Q(name="1"), then="a"),
227+
When(Q(name="2"), then="b"),
228+
When(Q(name="3"), then="c"),
229+
default="z",
230+
)
231+
)
232+
.order_by("name_orderable")
233+
.values("name")
234+
)
235+
self.assertEqual([t["name"] for t in tournaments], ["1", "2", "3"])
236+
237+
async def test_order_by_annotation_not_in_values_list(self):
238+
await Tournament.create(name="2")
239+
await Tournament.create(name="3")
240+
await Tournament.create(name="1")
241+
242+
tournaments = (
243+
await Tournament.annotate(
244+
name_orderable=Case(
245+
When(Q(name="1"), then="a"),
246+
When(Q(name="2"), then="b"),
247+
When(Q(name="3"), then="c"),
248+
default="z",
249+
)
250+
)
251+
.order_by("name_orderable")
252+
.values_list("name")
253+
)
254+
self.assertEqual(tournaments, [("1",), ("2",), ("3",)])

tortoise/queryset.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from __future__ import annotations
22

33
import types
4-
from collections.abc import AsyncIterator, Callable, Generator, Iterable
4+
from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable
55
from copy import copy
66
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast, overload
77

88
from pypika_tortoise import JoinType, Order, Table
99
from pypika_tortoise.analytics import Count
1010
from pypika_tortoise.functions import Cast
1111
from pypika_tortoise.queries import QueryBuilder
12-
from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper
12+
from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper, PseudoColumn
1313
from typing_extensions import Literal, Protocol
1414

1515
from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities
@@ -179,7 +179,8 @@ def resolve_ordering(
179179
model: type[Model],
180180
table: Table,
181181
orderings: Iterable[tuple[str, str | Order]],
182-
annotations: dict[str, Any],
182+
annotations: dict[str, Term | Expression],
183+
fields_for_select: Collection[str] | None = None,
183184
) -> None:
184185
"""
185186
Applies standard ordering to QuerySet.
@@ -189,6 +190,8 @@ def resolve_ordering(
189190
(to allow self referential joins)
190191
:param orderings: What columns/order to order by
191192
:param annotations: Annotations that may be ordered on
193+
:param fields_for_select: Contains fields that are selected in the SELECT clause if
194+
.only(), .values() or .values_list() are used.
192195
193196
:raises FieldError: If a field provided does not exist in model.
194197
"""
@@ -214,18 +217,27 @@ def resolve_ordering(
214217
{},
215218
)
216219
elif field_name in annotations:
217-
if isinstance(annotation := annotations[field_name], Term):
218-
term: Term = annotation
220+
term: Term
221+
if not fields_for_select or field_name in fields_for_select:
222+
# The annotation is SELECTed, we can just reference it in the following cases:
223+
# - Empty fields_for_select means that all columns and annotations are selected,
224+
# hence we can reference the annotation.
225+
# - The annotation is in fields_for_select, hence we can reference it.
226+
term = PseudoColumn(field_name)
219227
else:
220-
annotation_info = annotation.resolve(
221-
ResolveContext(
222-
model=self.model,
223-
table=table,
224-
annotations=annotations,
225-
custom_filters={},
226-
)
227-
)
228-
term = annotation_info.term
228+
# The annotation is not in SELECT, resolve it
229+
annotation = annotations[field_name]
230+
if isinstance(annotation, Term):
231+
term = annotation
232+
else:
233+
term = annotation.resolve(
234+
ResolveContext(
235+
model=self.model,
236+
table=table,
237+
annotations=annotations,
238+
custom_filters={},
239+
)
240+
).term
229241
self.query = self.query.orderby(term, order=ordering[1])
230242
else:
231243
field_object = model._meta.fields_map.get(field_name)
@@ -1078,7 +1090,11 @@ def _make_query(self) -> None:
10781090
if append_item not in self._select_related_idx:
10791091
self._select_related_idx.append(append_item)
10801092
self.resolve_ordering(
1081-
self.model, self.model._meta.basetable, self._orderings, self._annotations
1093+
self.model,
1094+
self.model._meta.basetable,
1095+
self._orderings,
1096+
self._annotations,
1097+
self._fields_for_select,
10821098
)
10831099
self.resolve_filters()
10841100
if self._limit is not None:
@@ -1562,6 +1578,7 @@ def _make_query(self) -> None:
15621578
table=self.model._meta.basetable,
15631579
orderings=self._orderings,
15641580
annotations=self._annotations,
1581+
fields_for_select=self._fields_for_select_list,
15651582
)
15661583
self.resolve_filters()
15671584
if self._limit:
@@ -1683,6 +1700,7 @@ def _make_query(self) -> None:
16831700
table=self.model._meta.basetable,
16841701
orderings=self._orderings,
16851702
annotations=self._annotations,
1703+
fields_for_select=self._fields_for_select.keys(),
16861704
)
16871705
self.resolve_filters()
16881706

0 commit comments

Comments
 (0)