Skip to content

Commit 247f89b

Browse files
committed
Resolve queryset annotate types for expressions with static ClassVar
Instead of always typing annotated fields as Any, resolve the concrete Python type when the expression has a static ClassVar output_field (e.g. Count → int, Exists → bool, Length → int, Now → datetime).
1 parent 3d55124 commit 247f89b

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

mypy_django_plugin/transformers/querysets.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from django.db.models.sql.query import Query
1717
from mypy.checker import TypeChecker
1818
from mypy.errorcodes import NO_REDEF
19-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr, SetExpr, TupleExpr
19+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr, SetExpr, TupleExpr, Var
2020
from mypy.plugin import FunctionContext, MethodContext
2121
from mypy.types import AnyType, Instance, LiteralType, ProperType, TupleType, TypedDictType, TypeOfAny, get_proper_type
2222
from mypy.types import Type as MypyType
@@ -198,12 +198,39 @@ def gather_kwargs(ctx: MethodContext) -> dict[str, MypyType] | None:
198198
return kwargs
199199

200200

201+
def _resolve_output_field_type(expr_type: MypyType) -> MypyType | None:
202+
"""Try to resolve the Python type for an expression's output_field ClassVar.
203+
204+
Returns None if the output_field can't be statically resolved.
205+
"""
206+
proper = get_proper_type(expr_type)
207+
if not isinstance(proper, Instance):
208+
return None
209+
210+
output_field_sym = proper.type.get("output_field")
211+
if output_field_sym is None or output_field_sym.node is None:
212+
return None
213+
214+
node = output_field_sym.node
215+
if not isinstance(node, Var) or node.type is None:
216+
return None
217+
218+
field_type = get_proper_type(node.type)
219+
if not isinstance(field_type, Instance):
220+
return None
221+
222+
return helpers.get_private_descriptor_type(field_type.type, "_pyi_private_get_type", is_nullable=False)
223+
224+
201225
def gather_expression_types(ctx: MethodContext) -> dict[str, MypyType]:
202226
kwargs = gather_kwargs(ctx)
203227
if not kwargs:
204228
return {}
205229

206-
# For now, we don't try to resolve the output_field of the field would be, but use Any.
230+
# Try to resolve the output_field type for each expression. For expressions
231+
# with a static ClassVar output_field (e.g. Count → IntegerField → int),
232+
# we can infer the concrete Python type. Otherwise, fall back to Any.
233+
#
207234
# NOTE: It's important that we use 'special_form' for 'Any' as otherwise we can
208235
# get stuck with mypy interpreting an overload ambiguity towards the
209236
# overloaded 'Field.__get__' method when its 'model' argument gets matched. This
@@ -230,7 +257,14 @@ def gather_expression_types(ctx: MethodContext) -> dict[str, MypyType]:
230257
# select due to the 'Any' in 'TypedDict({"foo": Any})'. But if we specify the
231258
# 'Any' as 'TypeOfAny.special_form' mypy doesn't consider the model instance to
232259
# contain 'Any' and the ambiguity goes away.
233-
return {name: AnyType(TypeOfAny.special_form) for name, _ in kwargs.items()}
260+
result: dict[str, MypyType] = {}
261+
for name, expr_type in kwargs.items():
262+
resolved = _resolve_output_field_type(expr_type)
263+
if resolved is not None and not isinstance(get_proper_type(resolved), AnyType):
264+
result[name] = resolved
265+
else:
266+
result[name] = AnyType(TypeOfAny.special_form)
267+
return result
234268

235269

236270
def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType:

tests/typecheck/managers/querysets/test_annotate.yml

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,8 @@
561561
562562
qs = MyModel.objects.all().annotate(num_items=Count("id"))
563563
obj = qs.get()
564-
reveal_type(obj.num_items) # N: Revealed type is "Any"
565-
obj.nonexistent # E: "MyModel@AnnotatedWith[TypedDict({'num_items': Any})]" has no attribute "nonexistent" [attr-defined]
564+
reveal_type(obj.num_items) # N: Revealed type is "builtins.int"
565+
obj.nonexistent # E: "MyModel@AnnotatedWith[TypedDict({'num_items': int})]" has no attribute "nonexistent" [attr-defined]
566566
567567
# Custom queryset methods remain available after annotate
568568
qs.custom_method()
@@ -584,3 +584,49 @@
584584
class MyModel(models.Model):
585585
name = models.CharField(max_length=100)
586586
objects = MyManager()
587+
588+
- case: annotate_resolves_output_field_type
589+
main: |
590+
from typing_extensions import reveal_type
591+
from myapp.models import User
592+
from django.db.models import Count, Exists, OuterRef
593+
from django.db.models.functions import Length, Now, Right, TruncDate, TruncTime
594+
595+
# Count has output_field: ClassVar[IntegerField] → int
596+
user_count = User.objects.annotate(user_count=Count('id')).get().user_count
597+
reveal_type(user_count) # N: Revealed type is "builtins.int"
598+
599+
# Exists has output_field: ClassVar[BooleanField] → bool
600+
has_users = User.objects.annotate(has_users=Exists(User.objects.filter(pk=OuterRef('pk')))).get().has_users
601+
reveal_type(has_users) # N: Revealed type is "builtins.bool"
602+
603+
# Length has output_field: ClassVar[IntegerField] → int
604+
name_len = User.objects.annotate(name_len=Length('username')).get().name_len
605+
reveal_type(name_len) # N: Revealed type is "builtins.int"
606+
607+
# Now has output_field: ClassVar[DateTimeField] → datetime
608+
current_time = User.objects.annotate(current_time=Now()).get().current_time
609+
reveal_type(current_time) # N: Revealed type is "datetime.datetime"
610+
611+
# Right (inherits from Left) has output_field: ClassVar[CharField] → str
612+
suffix = User.objects.annotate(suffix=Right('username', 3)).get().suffix
613+
reveal_type(suffix) # N: Revealed type is "builtins.str"
614+
615+
# TruncDate has output_field: ClassVar[DateField] → date
616+
created_date = User.objects.annotate(created_date=TruncDate('created_at')).get().created_date
617+
reveal_type(created_date) # N: Revealed type is "datetime.date"
618+
619+
# TruncTime has output_field: ClassVar[TimeField] → time
620+
created_time = User.objects.annotate(created_time=TruncTime('created_at')).get().created_time
621+
reveal_type(created_time) # N: Revealed type is "datetime.time"
622+
623+
installed_apps:
624+
- myapp
625+
files:
626+
- path: myapp/__init__.py
627+
- path: myapp/models.py
628+
content: |
629+
from django.db import models
630+
class User(models.Model):
631+
username = models.CharField(max_length=100)
632+
created_at = models.DateTimeField()

0 commit comments

Comments
 (0)