Skip to content

Commit 78ef3dd

Browse files
authored
Fix annotation propagation for non-filter queries (#1590) (#1593)
* Fix annotation propagation for non-filter queries (#1590) * Fix lint * Fix test
1 parent 3644d67 commit 78ef3dd

File tree

5 files changed

+102
-18
lines changed

5 files changed

+102
-18
lines changed

poetry.lock

Lines changed: 10 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ aiomysql = { version = "*", optional = true }
4747
asyncmy = { version = "^0.2.8", optional = true, allow-prereleases = true }
4848
psycopg = { extras = ["pool", "binary"], version = "^3.0.12", optional = true }
4949
asyncodbc = { version = "^0.1.1", optional = true }
50+
pydantic = { version = "^2.0,!=2.7.0", optional = true }
5051

5152
[tool.poetry.dev-dependencies]
5253
# Linter tools
@@ -72,7 +73,7 @@ sanic = "*"
7273
# Sample integration - Starlette
7374
starlette = "*"
7475
# Pydantic support
75-
pydantic = "^2.0"
76+
pydantic = "^2.0,!=2.7.0"
7677
# FastAPI support
7778
fastapi = "^0.100.0"
7879
asgi_lifespan = "*"

tests/test_aggregation.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ async def test_aggregation(self):
4646
await Event.all().annotate(tournament_test_id=Sum("tournament__id")).first()
4747
)
4848
self.assertEqual(
49-
event_with_annotation.tournament_test_id, event_with_annotation.tournament_id
49+
event_with_annotation.tournament_test_id,
50+
event_with_annotation.tournament_id,
5051
)
5152

5253
with self.assertRaisesRegex(ConfigurationError, "name__id not resolvable"):
@@ -162,3 +163,76 @@ async def test_concat_functions(self):
162163
.values("long_info")
163164
)
164165
self.assertEqual(ret, [{"long_info": "Physics Book(physics)"}])
166+
167+
async def test_count_after_aggregate(self):
168+
author = await Author.create(name="1")
169+
await Book.create(name="First!", author=author, rating=4)
170+
await Book.create(name="Second!", author=author, rating=3)
171+
await Book.create(name="Third!", author=author, rating=3)
172+
173+
author2 = await Author.create(name="2")
174+
await Book.create(name="F-2", author=author2, rating=3)
175+
await Book.create(name="F-3", author=author2, rating=3)
176+
177+
author3 = await Author.create(name="3")
178+
await Book.create(name="F-4", author=author3, rating=3)
179+
await Book.create(name="F-5", author=author3, rating=2)
180+
ret = (
181+
await Author.all()
182+
.annotate(average_rating=Avg("books__rating"))
183+
.filter(average_rating__gte=3)
184+
.count()
185+
)
186+
187+
assert ret == 2
188+
189+
async def test_exist_after_aggregate(self):
190+
author = await Author.create(name="1")
191+
await Book.create(name="First!", author=author, rating=4)
192+
await Book.create(name="Second!", author=author, rating=3)
193+
await Book.create(name="Third!", author=author, rating=3)
194+
195+
ret = (
196+
await Author.all()
197+
.annotate(average_rating=Avg("books__rating"))
198+
.filter(average_rating__gte=3)
199+
.exists()
200+
)
201+
202+
assert ret is True
203+
204+
ret = (
205+
await Author.all()
206+
.annotate(average_rating=Avg("books__rating"))
207+
.filter(average_rating__gte=4)
208+
.exists()
209+
)
210+
assert ret is False
211+
212+
async def test_count_after_aggregate_m2m(self):
213+
tournament = await Tournament.create(name="1")
214+
event1 = await Event.create(name="First!", tournament=tournament)
215+
event2 = await Event.create(name="Second!", tournament=tournament)
216+
event3 = await Event.create(name="Third!", tournament=tournament)
217+
event4 = await Event.create(name="Fourth!", tournament=tournament)
218+
219+
team1 = await Team.create(name="1")
220+
team2 = await Team.create(name="2")
221+
team3 = await Team.create(name="3")
222+
223+
await event1.participants.add(team1, team2, team3)
224+
await event2.participants.add(team1, team2)
225+
await event3.participants.add(team1)
226+
await event4.participants.add(team1, team2, team3)
227+
228+
query = (
229+
Event.filter(participants__id__in=[team1.id, team2.id, team3.id])
230+
.annotate(count=Count("event_id"))
231+
.filter(count=3)
232+
.prefetch_related("participants")
233+
)
234+
result = await query
235+
assert len(result) == 2
236+
237+
res = await query.count()
238+
assert res == 2

tests/test_queryset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ async def test_force_index_available_in_more_query(self):
464464
sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql()
465465
self.assertEqual(
466466
sql_CountQuery,
467-
"SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
467+
"SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
468468
)
469469

470470
sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql()
@@ -504,7 +504,7 @@ async def test_use_index_available_in_more_query(self):
504504
sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql()
505505
self.assertEqual(
506506
sql_CountQuery,
507-
"SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
507+
"SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
508508
)
509509

510510
sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql()

tortoise/queryset.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
)
2222

2323
from pypika import JoinType, Order, Table
24-
from pypika.functions import Cast, Count
24+
from pypika.analytics import Count
25+
from pypika.functions import Cast
2526
from pypika.queries import QueryBuilder
2627
from pypika.terms import Case, Field, Term, ValueWrapper
2728
from typing_extensions import Literal, Protocol
@@ -131,7 +132,7 @@ def resolve_filters(
131132
:param annotations: Extra annotations to add.
132133
:param custom_filters: Pre-resolved filters to be passed through.
133134
"""
134-
has_aggregate = self._resolve_annotate()
135+
has_aggregate = self._resolve_annotate(annotations)
135136

136137
modifier = QueryModifier()
137138
for node in q_objects:
@@ -236,13 +237,14 @@ def resolve_ordering(
236237

237238
self.query = self.query.orderby(field, order=ordering[1])
238239

239-
def _resolve_annotate(self) -> bool:
240-
if not self._annotations:
240+
def _resolve_annotate(self, extra_annotations: Dict[str, Any]) -> bool:
241+
if not self._annotations and not extra_annotations:
241242
return False
242243

243244
table = self.model._meta.basetable
245+
all_annotations = {**self._annotations, **extra_annotations}
244246
annotation_info = {}
245-
for key, annotation in self._annotations.items():
247+
for key, annotation in all_annotations.items():
246248
if isinstance(annotation, Term):
247249
annotation_info[key] = {"joins": [], "field": annotation}
248250
else:
@@ -251,7 +253,8 @@ def _resolve_annotate(self) -> bool:
251253
for key, info in annotation_info.items():
252254
for join in info["joins"]:
253255
self._join_table_by_field(*join)
254-
self.query._select_other(info["field"].as_(key))
256+
if key in self._annotations:
257+
self.query._select_other(info["field"].as_(key))
255258

256259
return any(info["field"].is_aggregate for info in annotation_info.values())
257260

@@ -1282,7 +1285,10 @@ def _make_query(self) -> None:
12821285
annotations=self.annotations,
12831286
custom_filters=self.custom_filters,
12841287
)
1285-
self.query._select_other(Count("*"))
1288+
count_term = Count("*")
1289+
if self.query._groupbys:
1290+
count_term = count_term.over()
1291+
self.query._select_other(count_term)
12861292

12871293
if self.force_indexes:
12881294
self.query._force_indexes = []

0 commit comments

Comments
 (0)