|
2 | 2 | from collections import defaultdict |
3 | 3 | from contextlib import contextmanager |
4 | 4 | from typing import ( |
5 | | - TYPE_CHECKING, Dict, Iterator, Optional, Set, Tuple, Type, Union, |
| 5 | + TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union, |
6 | 6 | ) |
7 | 7 |
|
8 | 8 | from django.core.exceptions import FieldError |
|
11 | 11 | from django.db.models.fields import AutoField, CharField, Field |
12 | 12 | from django.db.models.fields.related import ForeignKey, RelatedField |
13 | 13 | from django.db.models.fields.reverse_related import ForeignObjectRel |
| 14 | +from django.db.models.lookups import Exact |
14 | 15 | from django.db.models.sql.query import Query |
15 | 16 | from django.utils.functional import cached_property |
16 | 17 | from mypy.checker import TypeChecker |
| 18 | +from mypy.plugin import MethodContext |
17 | 19 | from mypy.types import AnyType, Instance |
18 | 20 | from mypy.types import Type as MypyType |
19 | | -from mypy.types import TypeOfAny |
| 21 | +from mypy.types import TypeOfAny, UnionType |
20 | 22 |
|
21 | | -from mypy_django_plugin.lib import helpers |
| 23 | +from mypy_django_plugin.lib import fullnames, helpers |
22 | 24 |
|
23 | 25 | try: |
24 | 26 | from django.contrib.postgres.fields import ArrayField |
@@ -153,33 +155,87 @@ def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> |
153 | 155 | return related_model_cls |
154 | 156 |
|
155 | 157 |
|
| 158 | +class LookupsAreUnsupported(Exception): |
| 159 | + pass |
| 160 | + |
| 161 | + |
156 | 162 | class DjangoLookupsContext: |
157 | 163 | def __init__(self, django_context: 'DjangoContext'): |
158 | 164 | self.django_context = django_context |
159 | 165 |
|
160 | | - def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Field: |
| 166 | + def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field: |
| 167 | + currently_observed_model = model_cls |
| 168 | + field = None |
| 169 | + for field_part in field_parts: |
| 170 | + if field_part == 'pk': |
| 171 | + field = self.django_context.get_primary_key_field(currently_observed_model) |
| 172 | + continue |
| 173 | + |
| 174 | + field = currently_observed_model._meta.get_field(field_part) |
| 175 | + if isinstance(field, RelatedField): |
| 176 | + currently_observed_model = field.related_model |
| 177 | + model_name = currently_observed_model._meta.model_name |
| 178 | + if (model_name is not None |
| 179 | + and field_part == (model_name + '_id')): |
| 180 | + field = self.django_context.get_primary_key_field(currently_observed_model) |
| 181 | + |
| 182 | + if isinstance(field, ForeignObjectRel): |
| 183 | + currently_observed_model = field.related_model |
| 184 | + |
| 185 | + assert field is not None |
| 186 | + return field |
| 187 | + |
| 188 | + def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field: |
161 | 189 | query = Query(model_cls) |
162 | 190 | lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) |
163 | 191 | if lookup_parts: |
164 | | - raise FieldError('Lookups not supported yet') |
| 192 | + raise LookupsAreUnsupported() |
165 | 193 |
|
166 | | - currently_observed_model = model_cls |
167 | | - current_field = None |
168 | | - for field_part in field_parts: |
169 | | - if field_part == 'pk': |
170 | | - return self.django_context.get_primary_key_field(currently_observed_model) |
| 194 | + return self._resolve_field_from_parts(field_parts, model_cls) |
171 | 195 |
|
172 | | - current_field = currently_observed_model._meta.get_field(field_part) |
173 | | - if not isinstance(current_field, (ForeignObjectRel, RelatedField)): |
174 | | - continue |
| 196 | + def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType: |
| 197 | + query = Query(model_cls) |
| 198 | + try: |
| 199 | + lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) |
| 200 | + if is_expression: |
| 201 | + return AnyType(TypeOfAny.explicit) |
| 202 | + except FieldError as exc: |
| 203 | + ctx.api.fail(exc.args[0], ctx.context) |
| 204 | + return AnyType(TypeOfAny.from_error) |
| 205 | + |
| 206 | + field = self._resolve_field_from_parts(field_parts, model_cls) |
| 207 | + |
| 208 | + lookup_cls = None |
| 209 | + if lookup_parts: |
| 210 | + lookup = lookup_parts[-1] |
| 211 | + lookup_cls = field.get_lookup(lookup) |
| 212 | + if lookup_cls is None: |
| 213 | + # unknown lookup |
| 214 | + return AnyType(TypeOfAny.explicit) |
| 215 | + |
| 216 | + if lookup_cls is None or isinstance(lookup_cls, Exact): |
| 217 | + return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) |
| 218 | + |
| 219 | + assert lookup_cls is not None |
175 | 220 |
|
176 | | - currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field) |
177 | | - if isinstance(current_field, ForeignObjectRel): |
178 | | - current_field = self.django_context.get_primary_key_field(currently_observed_model) |
| 221 | + lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) |
| 222 | + if lookup_info is None: |
| 223 | + return AnyType(TypeOfAny.explicit) |
179 | 224 |
|
180 | | - # if it is None, solve_lookup_type() will fail earlier |
181 | | - assert current_field is not None |
182 | | - return current_field |
| 225 | + for lookup_base in helpers.iter_bases(lookup_info): |
| 226 | + if lookup_base.args and isinstance(lookup_base.args[0], Instance): |
| 227 | + lookup_type: MypyType = lookup_base.args[0] |
| 228 | + # if it's Field, consider lookup_type a __get__ of current field |
| 229 | + if (isinstance(lookup_type, Instance) |
| 230 | + and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME): |
| 231 | + field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) |
| 232 | + if field_info is None: |
| 233 | + return AnyType(TypeOfAny.explicit) |
| 234 | + lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', |
| 235 | + is_nullable=field.null) |
| 236 | + return lookup_type |
| 237 | + |
| 238 | + return AnyType(TypeOfAny.explicit) |
183 | 239 |
|
184 | 240 |
|
185 | 241 | class DjangoContext: |
@@ -228,6 +284,27 @@ def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectR |
228 | 284 | if isinstance(field, ForeignObjectRel): |
229 | 285 | yield field |
230 | 286 |
|
| 287 | + def get_field_lookup_exact_type(self, api: TypeChecker, field: Field) -> MypyType: |
| 288 | + if isinstance(field, (RelatedField, ForeignObjectRel)): |
| 289 | + related_model_cls = field.related_model |
| 290 | + primary_key_field = self.get_primary_key_field(related_model_cls) |
| 291 | + primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init') |
| 292 | + |
| 293 | + rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) |
| 294 | + if rel_model_info is None: |
| 295 | + return AnyType(TypeOfAny.explicit) |
| 296 | + |
| 297 | + model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), |
| 298 | + primary_key_type]) |
| 299 | + return helpers.make_optional(model_and_primary_key_type) |
| 300 | + # return helpers.make_optional(Instance(rel_model_info, [])) |
| 301 | + |
| 302 | + field_info = helpers.lookup_class_typeinfo(api, field.__class__) |
| 303 | + if field_info is None: |
| 304 | + return AnyType(TypeOfAny.explicit) |
| 305 | + return helpers.get_private_descriptor_type(field_info, '_pyi_lookup_exact_type', |
| 306 | + is_nullable=field.null) |
| 307 | + |
231 | 308 | def get_primary_key_field(self, model_cls: Type[Model]) -> Field: |
232 | 309 | for field in model_cls._meta.get_fields(): |
233 | 310 | if isinstance(field, Field): |
|
0 commit comments