Skip to content

Commit 73f0a79

Browse files
authored
Fix plugin exception "KeyError: 'model_bases'" and related errors (#1563)
1 parent fb890f8 commit 73f0a79

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

mypy_django_plugin/lib/helpers.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union
2+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Set, Union, cast
33

44
from django.db.models.fields import Field
55
from django.db.models.fields.related import RelatedField
@@ -35,6 +35,7 @@
3535
from mypy.semanal import SemanticAnalyzer
3636
from mypy.types import AnyType, Instance, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
3737
from mypy.types import Type as MypyType
38+
from typing_extensions import TypedDict
3839

3940
from mypy_django_plugin.lib import fullnames
4041
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -43,8 +44,23 @@
4344
from mypy_django_plugin.django.context import DjangoContext
4445

4546

46-
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
47-
return model_info.metadata.setdefault("django", {})
47+
class DjangoTypeMetadata(TypedDict, total=False):
48+
from_queryset_manager: str
49+
reverse_managers: Dict[str, str]
50+
baseform_bases: Dict[str, int]
51+
manager_bases: Dict[str, int]
52+
model_bases: Dict[str, int]
53+
queryset_bases: Dict[str, int]
54+
55+
56+
def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
57+
return cast(DjangoTypeMetadata, model_info.metadata.setdefault("django", {}))
58+
59+
60+
def get_django_metadata_bases(
61+
model_info: TypeInfo, key: Literal["baseform_bases", "manager_bases", "model_bases", "queryset_bases"]
62+
) -> Dict[str, int]:
63+
return get_django_metadata(model_info).setdefault(key, cast(Dict[str, int], {}))
4864

4965

5066
class IncompleteDefnException(Exception):
@@ -376,4 +392,5 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType, no_se
376392
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
377393
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
378394
if sym is not None and isinstance(sym.node, TypeInfo):
379-
get_django_metadata(sym.node)["manager_bases"][fullname] = 1
395+
bases = get_django_metadata_bases(sym.node, "manager_bases")
396+
bases[fullname] = 1

mypy_django_plugin/main.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -
4343
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MODEL_CLASS_FULLNAME)
4444

4545
if sym is not None and isinstance(sym.node, TypeInfo):
46-
helpers.get_django_metadata(sym.node)["model_bases"][ctx.cls.fullname] = 1
46+
bases = helpers.get_django_metadata_bases(sym.node, "model_bases")
47+
bases[ctx.cls.fullname] = 1
4748
else:
4849
if not ctx.api.final_iteration:
4950
ctx.api.defer()
@@ -55,7 +56,8 @@ def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -
5556
def transform_form_class(ctx: ClassDefContext) -> None:
5657
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME)
5758
if sym is not None and isinstance(sym.node, TypeInfo):
58-
helpers.get_django_metadata(sym.node)["baseform_bases"][ctx.cls.fullname] = 1
59+
bases = helpers.get_django_metadata_bases(sym.node, "baseform_bases")
60+
bases[ctx.cls.fullname] = 1
5961

6062
forms.make_meta_nested_class_inherit_from_any(ctx)
6163

@@ -77,41 +79,38 @@ def __init__(self, options: Options) -> None:
7779
def _get_current_queryset_bases(self) -> Dict[str, int]:
7880
model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME)
7981
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
80-
return helpers.get_django_metadata(model_sym.node).setdefault( # type: ignore[no-any-return]
81-
"queryset_bases", {fullnames.QUERYSET_CLASS_FULLNAME: 1}
82-
)
82+
bases = helpers.get_django_metadata_bases(model_sym.node, "queryset_bases")
83+
bases[fullnames.QUERYSET_CLASS_FULLNAME] = 1
84+
return bases
8385
else:
8486
return {}
8587

8688
def _get_current_manager_bases(self) -> Dict[str, int]:
8789
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
8890
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
89-
return helpers.get_django_metadata(model_sym.node).setdefault( # type: ignore[no-any-return]
90-
"manager_bases", {fullnames.MANAGER_CLASS_FULLNAME: 1}
91-
)
91+
bases = helpers.get_django_metadata_bases(model_sym.node, "manager_bases")
92+
bases[fullnames.MANAGER_CLASS_FULLNAME] = 1
93+
return bases
9294
else:
9395
return {}
9496

9597
def _get_current_model_bases(self) -> Dict[str, int]:
9698
model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
9799
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
98-
return helpers.get_django_metadata(model_sym.node).setdefault( # type: ignore[no-any-return]
99-
"model_bases", {fullnames.MODEL_CLASS_FULLNAME: 1}
100-
)
100+
bases = helpers.get_django_metadata_bases(model_sym.node, "model_bases")
101+
bases[fullnames.MODEL_CLASS_FULLNAME] = 1
102+
return bases
101103
else:
102104
return {}
103105

104106
def _get_current_form_bases(self) -> Dict[str, int]:
105107
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
106108
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
107-
return helpers.get_django_metadata(model_sym.node).setdefault( # type: ignore[no-any-return]
108-
"baseform_bases",
109-
{
110-
fullnames.BASEFORM_CLASS_FULLNAME: 1,
111-
fullnames.FORM_CLASS_FULLNAME: 1,
112-
fullnames.MODELFORM_CLASS_FULLNAME: 1,
113-
},
114-
)
109+
bases = helpers.get_django_metadata_bases(model_sym.node, "baseform_bases")
110+
bases[fullnames.BASEFORM_CLASS_FULLNAME] = 1
111+
bases[fullnames.FORM_CLASS_FULLNAME] = 1
112+
bases[fullnames.MODELFORM_CLASS_FULLNAME] = 1
113+
return bases
115114
else:
116115
return {}
117116

0 commit comments

Comments
 (0)