Skip to content

Commit 02bdf5b

Browse files
authored
add support for typechecking of filter/get/exclude arguments (#183)
* add support for typechecking of filter/get/exclude arguments * linting
1 parent 4d4b000 commit 02bdf5b

File tree

10 files changed

+451
-48
lines changed

10 files changed

+451
-48
lines changed

django-stubs/db/models/fields/__init__.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ _GT = TypeVar("_GT")
4343
class Field(RegisterLookupMixin, Generic[_ST, _GT]):
4444
_pyi_private_set_type: Any
4545
_pyi_private_get_type: Any
46+
_pyi_lookup_exact_type: Any
4647

4748
widget: Widget
4849
help_text: str
@@ -131,6 +132,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
131132
class IntegerField(Field[_ST, _GT]):
132133
_pyi_private_set_type: Union[float, int, str, Combinable]
133134
_pyi_private_get_type: int
135+
_pyi_lookup_exact_type: int
134136

135137
class PositiveIntegerRelDbTypeMixin:
136138
def rel_db_type(self, connection: Any): ...
@@ -143,10 +145,12 @@ class BigIntegerField(IntegerField[_ST, _GT]): ...
143145
class FloatField(Field[_ST, _GT]):
144146
_pyi_private_set_type: Union[float, int, str, Combinable]
145147
_pyi_private_get_type: float
148+
_pyi_lookup_exact_type: float
146149

147150
class DecimalField(Field[_ST, _GT]):
148151
_pyi_private_set_type: Union[str, float, decimal.Decimal, Combinable]
149152
_pyi_private_get_type: decimal.Decimal
153+
_pyi_lookup_exact_type: Union[str, decimal.Decimal]
150154
# attributes
151155
max_digits: int = ...
152156
decimal_places: int = ...
@@ -176,10 +180,13 @@ class DecimalField(Field[_ST, _GT]):
176180
class AutoField(Field[_ST, _GT]):
177181
_pyi_private_set_type: Union[Combinable, int, str]
178182
_pyi_private_get_type: int
183+
_pyi_lookup_exact_type: int
179184

180185
class CharField(Field[_ST, _GT]):
181186
_pyi_private_set_type: Union[str, int, Combinable]
182187
_pyi_private_get_type: str
188+
# objects are converted to string before comparison
189+
_pyi_lookup_exact_type: Any
183190
def __init__(
184191
self,
185192
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -238,14 +245,18 @@ class URLField(CharField[_ST, _GT]): ...
238245
class TextField(Field[_ST, _GT]):
239246
_pyi_private_set_type: Union[str, Combinable]
240247
_pyi_private_get_type: str
248+
# objects are converted to string before comparison
249+
_pyi_lookup_exact_type: Any
241250

242251
class BooleanField(Field[_ST, _GT]):
243252
_pyi_private_set_type: Union[bool, Combinable]
244253
_pyi_private_get_type: bool
254+
_pyi_lookup_exact_type: bool
245255

246256
class NullBooleanField(Field[_ST, _GT]):
247257
_pyi_private_set_type: Optional[Union[bool, Combinable]]
248258
_pyi_private_get_type: Optional[bool]
259+
_pyi_lookup_exact_type: Optional[bool]
249260

250261
class IPAddressField(Field[_ST, _GT]):
251262
_pyi_private_set_type: Union[str, Combinable]
@@ -286,6 +297,7 @@ class DateTimeCheckMixin: ...
286297
class DateField(DateTimeCheckMixin, Field[_ST, _GT]):
287298
_pyi_private_set_type: Union[str, date, Combinable]
288299
_pyi_private_get_type: date
300+
_pyi_lookup_exact_type: Union[str, date]
289301
def __init__(
290302
self,
291303
verbose_name: Optional[Union[str, bytes]] = ...,
@@ -338,6 +350,7 @@ class TimeField(DateTimeCheckMixin, Field[_ST, _GT]):
338350

339351
class DateTimeField(DateField[_ST, _GT]):
340352
_pyi_private_get_type: datetime
353+
_pyi_lookup_exact_type: Union[str, datetime]
341354

342355
class UUIDField(Field[_ST, _GT]):
343356
_pyi_private_set_type: Union[str, uuid.UUID]

django-stubs/db/models/lookups.pyi

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Type, Union
2+
from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic
33

44
from django.db.backends.sqlite3.base import DatabaseWrapper
55
from django.db.models.expressions import Expression, Func
@@ -10,7 +10,9 @@ from django.utils.safestring import SafeText
1010

1111
from django.db.models.fields import TextField, related_lookups
1212

13-
class Lookup:
13+
_T = TypeVar("_T")
14+
15+
class Lookup(Generic[_T]):
1416
lookup_name: str = ...
1517
prepare_rhs: bool = ...
1618
can_use_none_as_rhs: bool = ...
@@ -47,7 +49,7 @@ class Transform(RegisterLookupMixin, Func):
4749
def lhs(self) -> Expression: ...
4850
def get_bilateral_transforms(self) -> List[Type[Transform]]: ...
4951

50-
class BuiltinLookup(Lookup):
52+
class BuiltinLookup(Lookup[_T]):
5153
def get_rhs_op(self, connection: DatabaseWrapper, rhs: str) -> str: ...
5254

5355
class FieldGetDbPrepValueMixin:
@@ -62,21 +64,21 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
6264
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ...
6365
class IExact(BuiltinLookup): ...
6466
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup): ...
65-
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): ...
66-
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup): ...
67+
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup[_T]): ...
68+
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup[_T]): ...
6769
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup): ...
6870

6971
class IntegerFieldFloatRounding:
7072
rhs: Any = ...
7173
def get_prep_lookup(self) -> Any: ...
7274

73-
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual): ...
74-
class IntegerLessThan(IntegerFieldFloatRounding, LessThan): ...
75+
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual[Union[int, float]]): ...
76+
class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ...
7577

7678
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
7779
def split_parameter_list_as_sql(self, compiler: Any, connection: Any): ...
7880

79-
class PatternLookup(BuiltinLookup):
81+
class PatternLookup(BuiltinLookup[str]):
8082
param_pattern: str = ...
8183

8284
class Contains(PatternLookup): ...
@@ -86,8 +88,8 @@ class IStartsWith(StartsWith): ...
8688
class EndsWith(PatternLookup): ...
8789
class IEndsWith(EndsWith): ...
8890
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup): ...
89-
class IsNull(BuiltinLookup): ...
90-
class Regex(BuiltinLookup): ...
91+
class IsNull(BuiltinLookup[bool]): ...
92+
class Regex(BuiltinLookup[str]): ...
9193
class IRegex(Regex): ...
9294

9395
class YearLookup(Lookup):

mypy_django_plugin/django/context.py

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import defaultdict
33
from contextlib import contextmanager
44
from typing import (
5-
TYPE_CHECKING, Dict, Iterator, Optional, Set, Tuple, Type, Union,
5+
TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union,
66
)
77

88
from django.core.exceptions import FieldError
@@ -11,14 +11,16 @@
1111
from django.db.models.fields import AutoField, CharField, Field
1212
from django.db.models.fields.related import ForeignKey, RelatedField
1313
from django.db.models.fields.reverse_related import ForeignObjectRel
14+
from django.db.models.lookups import Exact
1415
from django.db.models.sql.query import Query
1516
from django.utils.functional import cached_property
1617
from mypy.checker import TypeChecker
18+
from mypy.plugin import MethodContext
1719
from mypy.types import AnyType, Instance
1820
from mypy.types import Type as MypyType
19-
from mypy.types import TypeOfAny
21+
from mypy.types import TypeOfAny, UnionType
2022

21-
from mypy_django_plugin.lib import helpers
23+
from mypy_django_plugin.lib import fullnames, helpers
2224

2325
try:
2426
from django.contrib.postgres.fields import ArrayField
@@ -153,33 +155,87 @@ def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) ->
153155
return related_model_cls
154156

155157

158+
class LookupsAreUnsupported(Exception):
159+
pass
160+
161+
156162
class DjangoLookupsContext:
157163
def __init__(self, django_context: 'DjangoContext'):
158164
self.django_context = django_context
159165

160-
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Field:
166+
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
167+
currently_observed_model = model_cls
168+
field = None
169+
for field_part in field_parts:
170+
if field_part == 'pk':
171+
field = self.django_context.get_primary_key_field(currently_observed_model)
172+
continue
173+
174+
field = currently_observed_model._meta.get_field(field_part)
175+
if isinstance(field, RelatedField):
176+
currently_observed_model = field.related_model
177+
model_name = currently_observed_model._meta.model_name
178+
if (model_name is not None
179+
and field_part == (model_name + '_id')):
180+
field = self.django_context.get_primary_key_field(currently_observed_model)
181+
182+
if isinstance(field, ForeignObjectRel):
183+
currently_observed_model = field.related_model
184+
185+
assert field is not None
186+
return field
187+
188+
def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field:
161189
query = Query(model_cls)
162190
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
163191
if lookup_parts:
164-
raise FieldError('Lookups not supported yet')
192+
raise LookupsAreUnsupported()
165193

166-
currently_observed_model = model_cls
167-
current_field = None
168-
for field_part in field_parts:
169-
if field_part == 'pk':
170-
return self.django_context.get_primary_key_field(currently_observed_model)
194+
return self._resolve_field_from_parts(field_parts, model_cls)
171195

172-
current_field = currently_observed_model._meta.get_field(field_part)
173-
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
174-
continue
196+
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
197+
query = Query(model_cls)
198+
try:
199+
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
200+
if is_expression:
201+
return AnyType(TypeOfAny.explicit)
202+
except FieldError as exc:
203+
ctx.api.fail(exc.args[0], ctx.context)
204+
return AnyType(TypeOfAny.from_error)
205+
206+
field = self._resolve_field_from_parts(field_parts, model_cls)
207+
208+
lookup_cls = None
209+
if lookup_parts:
210+
lookup = lookup_parts[-1]
211+
lookup_cls = field.get_lookup(lookup)
212+
if lookup_cls is None:
213+
# unknown lookup
214+
return AnyType(TypeOfAny.explicit)
215+
216+
if lookup_cls is None or isinstance(lookup_cls, Exact):
217+
return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
218+
219+
assert lookup_cls is not None
175220

176-
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
177-
if isinstance(current_field, ForeignObjectRel):
178-
current_field = self.django_context.get_primary_key_field(currently_observed_model)
221+
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
222+
if lookup_info is None:
223+
return AnyType(TypeOfAny.explicit)
179224

180-
# if it is None, solve_lookup_type() will fail earlier
181-
assert current_field is not None
182-
return current_field
225+
for lookup_base in helpers.iter_bases(lookup_info):
226+
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
227+
lookup_type: MypyType = lookup_base.args[0]
228+
# if it's Field, consider lookup_type a __get__ of current field
229+
if (isinstance(lookup_type, Instance)
230+
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
231+
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
232+
if field_info is None:
233+
return AnyType(TypeOfAny.explicit)
234+
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
235+
is_nullable=field.null)
236+
return lookup_type
237+
238+
return AnyType(TypeOfAny.explicit)
183239

184240

185241
class DjangoContext:
@@ -228,6 +284,27 @@ def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectR
228284
if isinstance(field, ForeignObjectRel):
229285
yield field
230286

287+
def get_field_lookup_exact_type(self, api: TypeChecker, field: Field) -> MypyType:
288+
if isinstance(field, (RelatedField, ForeignObjectRel)):
289+
related_model_cls = field.related_model
290+
primary_key_field = self.get_primary_key_field(related_model_cls)
291+
primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init')
292+
293+
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
294+
if rel_model_info is None:
295+
return AnyType(TypeOfAny.explicit)
296+
297+
model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []),
298+
primary_key_type])
299+
return helpers.make_optional(model_and_primary_key_type)
300+
# return helpers.make_optional(Instance(rel_model_info, []))
301+
302+
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
303+
if field_info is None:
304+
return AnyType(TypeOfAny.explicit)
305+
return helpers.get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
306+
is_nullable=field.null)
307+
231308
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
232309
for field in model_cls._meta.get_fields():
233310
if isinstance(field, Field):

mypy_django_plugin/lib/helpers.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from collections import OrderedDict
2-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
2+
from typing import (
3+
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union, cast,
4+
)
35

6+
from django.db.models.fields import Field
7+
from django.db.models.fields.related import RelatedField
8+
from django.db.models.fields.reverse_related import ForeignObjectRel
49
from mypy import checker
510
from mypy.checker import TypeChecker
611
from mypy.mro import calculate_mro
@@ -115,29 +120,50 @@ def parse_bool(expr: Expression) -> Optional[bool]:
115120
return None
116121

117122

118-
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
123+
def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool:
119124
for base_fullname in bases:
120125
if info.has_base(base_fullname):
121126
return True
122127
return False
123128

124129

130+
def iter_bases(info: TypeInfo) -> Iterator[Instance]:
131+
for base in info.bases:
132+
yield base
133+
yield from iter_bases(base.type)
134+
135+
125136
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
126137
""" Return declared type of type_info's private_field_name (used for private Field attributes)"""
127138
sym = type_info.get(private_field_name)
128139
if sym is None:
129-
return AnyType(TypeOfAny.unannotated)
140+
return AnyType(TypeOfAny.explicit)
130141

131142
node = sym.node
132143
if isinstance(node, Var):
133144
descriptor_type = node.type
134145
if descriptor_type is None:
135-
return AnyType(TypeOfAny.unannotated)
146+
return AnyType(TypeOfAny.explicit)
136147

137148
if is_nullable:
138149
descriptor_type = make_optional(descriptor_type)
139150
return descriptor_type
140-
return AnyType(TypeOfAny.unannotated)
151+
return AnyType(TypeOfAny.explicit)
152+
153+
154+
def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
155+
if isinstance(field, (RelatedField, ForeignObjectRel)):
156+
lookup_type_class = field.related_model
157+
rel_model_info = lookup_class_typeinfo(api, lookup_type_class)
158+
if rel_model_info is None:
159+
return AnyType(TypeOfAny.from_error)
160+
return make_optional(Instance(rel_model_info, []))
161+
162+
field_info = lookup_class_typeinfo(api, field.__class__)
163+
if field_info is None:
164+
return AnyType(TypeOfAny.explicit)
165+
return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
166+
is_nullable=field.null)
141167

142168

143169
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:

mypy_django_plugin/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def get_method_hook(self, fullname: str
209209
manager_classes = self._get_current_manager_bases()
210210
if class_fullname in manager_classes and method_name == 'create':
211211
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
212+
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
213+
return partial(init_create.typecheck_queryset_filter, django_context=self.django_context)
212214
return None
213215

214216
def get_base_class_hook(self, fullname: str

0 commit comments

Comments
 (0)