Skip to content

Commit 53cdbe4

Browse files
authored
Improve ManyToManyDescriptor and fix Model.<manytomany>.through typing (#1805)
`ManyToManyDescriptor` is now extended to take 1 new type argument, which is the target model/other side of the relation. The plugin is updated to: - Set a `ManyToManyDescriptor` instance instead of a related manager, for reverse relations of a `ManyToManyField` - Produce a manager class with `ManyRelatedManager` and a model's default manager as bases for both sides of a many-to-many relation
1 parent 0a61d81 commit 53cdbe4

File tree

12 files changed

+174
-53
lines changed

12 files changed

+174
-53
lines changed

django-stubs/db/models/fields/related.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ RECURSIVE_RELATIONSHIP_CONSTANT: Literal["self"]
2525

2626
def resolve_relation(scope_model: type[Model], relation: str | type[Model]) -> str | type[Model]: ...
2727

28-
_M = TypeVar("_M", bound=Model)
2928
# __set__ value type
3029
_ST = TypeVar("_ST")
3130
# __get__ return type
@@ -232,9 +231,10 @@ class OneToOneField(ForeignKey[_ST, _GT]):
232231
@overload
233232
def __get__(self, instance: Any, owner: Any) -> Self: ...
234233

234+
_Through = TypeVar("_Through", bound=Model)
235235
_To = TypeVar("_To", bound=Model)
236236

237-
class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
237+
class ManyToManyField(RelatedField[Any, Any], Generic[_To, _Through]):
238238
description: str
239239
has_null_arg: bool
240240
swappable: bool
@@ -253,7 +253,7 @@ class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
253253
related_query_name: str | None = ...,
254254
limit_choices_to: _AllLimitChoicesTo | None = ...,
255255
symmetrical: bool | None = ...,
256-
through: type[_M] | str | None = ...,
256+
through: type[_Through] | str | None = ...,
257257
through_fields: tuple[str, str] | None = ...,
258258
db_constraint: bool = ...,
259259
db_table: str | None = ...,
@@ -282,7 +282,7 @@ class ManyToManyField(RelatedField[Any, Any], Generic[_To, _M]):
282282
) -> None: ...
283283
# class access
284284
@overload
285-
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_M]: ...
285+
def __get__(self, instance: None, owner: Any) -> ManyToManyDescriptor[_To, _Through]: ...
286286
# Model instance access
287287
@overload
288288
def __get__(self, instance: Model, owner: Any) -> ManyRelatedManager[_To]: ...

django-stubs/db/models/fields/related_descriptors.pyi

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ from typing_extensions import Self
1515
_M = TypeVar("_M", bound=Model)
1616
_F = TypeVar("_F", bound=Field)
1717
_From = TypeVar("_From", bound=Model)
18+
_Through = TypeVar("_Through", bound=Model)
1819
_To = TypeVar("_To", bound=Model)
1920

2021
class ForeignKeyDeferredAttribute(DeferredAttribute):
@@ -84,7 +85,7 @@ class ReverseManyToOneDescriptor:
8485
@overload
8586
def __get__(self, instance: None, cls: Any = ...) -> Self: ...
8687
@overload
87-
def __get__(self, instance: Model, cls: Any = ...) -> type[RelatedManager[Any]]: ...
88+
def __get__(self, instance: Model, cls: Any = ...) -> RelatedManager[Any]: ...
8889
def __set__(self, instance: Any, value: Any) -> NoReturn: ...
8990

9091
# Fake class, Django defines 'RelatedManager' inside a function body
@@ -104,7 +105,7 @@ def create_reverse_many_to_one_manager(
104105
superclass: type[BaseManager[_M]], rel: ManyToOneRel
105106
) -> type[RelatedManager[_M]]: ...
106107

107-
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]):
108+
class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_To, _Through]):
108109
"""
109110
In the example::
110111
@@ -117,13 +118,17 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor, Generic[_M]):
117118

118119
# 'field' here is 'rel.field'
119120
rel: ManyToManyRel # type: ignore[assignment]
120-
field: ManyToManyField[Any, _M] # type: ignore[assignment]
121+
field: ManyToManyField[_To, _Through] # type: ignore[assignment]
121122
reverse: bool
122123
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
123124
@property
124-
def through(self) -> type[_M]: ...
125+
def through(self) -> type[_Through]: ...
125126
@cached_property
126-
def related_manager_cls(self) -> type[ManyRelatedManager[Any]]: ... # type: ignore[override]
127+
def related_manager_cls(self) -> type[ManyRelatedManager[_To]]: ... # type: ignore[override]
128+
@overload # type: ignore[override]
129+
def __get__(self, instance: None, cls: Any = ...) -> Self: ...
130+
@overload
131+
def __get__(self, instance: Model, cls: Any = ...) -> ManyRelatedManager[_To]: ...
127132

128133
# Fake class, Django defines 'ManyRelatedManager' inside a function body
129134
class ManyRelatedManager(Manager[_M], Generic[_M]):

mypy_django_plugin/lib/fullnames.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
}
3535

3636
REVERSE_ONE_TO_ONE_DESCRIPTOR = "django.db.models.fields.related_descriptors.ReverseOneToOneDescriptor"
37+
MANY_TO_MANY_DESCRIPTOR = "django.db.models.fields.related_descriptors.ManyToManyDescriptor"
38+
MANY_RELATED_MANAGER = "django.db.models.fields.related_descriptors.ManyRelatedManager"
3739
RELATED_FIELDS_CLASSES = frozenset(
3840
(
3941
FOREIGN_OBJECT_FULLNAME,

mypy_django_plugin/lib/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,20 @@ def get_django_metadata_bases(
6767
return get_django_metadata(model_info).setdefault(key, cast(Dict[str, int], {}))
6868

6969

70+
def get_reverse_manager_info(
71+
api: Union[TypeChecker, SemanticAnalyzer], model_info: TypeInfo, derived_from: str
72+
) -> Optional[TypeInfo]:
73+
manager_fullname = get_django_metadata(model_info).get("reverse_managers", {}).get(derived_from)
74+
if not manager_fullname:
75+
return None
76+
77+
return lookup_fully_qualified_typeinfo(api, manager_fullname)
78+
79+
80+
def set_reverse_manager_info(model_info: TypeInfo, derived_from: str, fullname: str) -> None:
81+
get_django_metadata(model_info).setdefault("reverse_managers", {})[derived_from] = fullname
82+
83+
7084
class IncompleteDefnException(Exception):
7185
pass
7286

@@ -457,3 +471,10 @@ def resolve_lazy_reference(
457471
else:
458472
api.fail("Could not match lazy reference with any model", ctx)
459473
return None
474+
475+
476+
def is_model_instance(instance: Instance) -> bool:
477+
return (
478+
instance.type.metaclass_type is not None
479+
and instance.type.metaclass_type.type.fullname == fullnames.MODEL_METACLASS_FULLNAME
480+
)

mypy_django_plugin/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from mypy_django_plugin.django.context import DjangoContext
2424
from mypy_django_plugin.exceptions import UnregisteredModelError
2525
from mypy_django_plugin.lib import fullnames, helpers
26-
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
26+
from mypy_django_plugin.transformers import fields, forms, init_create, manytomany, meta, querysets, request, settings
2727
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
2828
from mypy_django_plugin.transformers.managers import (
2929
create_new_manager_class_from_as_manager_method,
@@ -188,6 +188,12 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
188188
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
189189
return forms.extract_proper_type_for_get_form
190190

191+
elif method_name == "__get__" and class_fullname in {
192+
fullnames.MANYTOMANY_FIELD_FULLNAME,
193+
fullnames.MANY_TO_MANY_DESCRIPTOR,
194+
}:
195+
return manytomany.refine_many_to_many_related_manager
196+
191197
manager_classes = self._get_current_manager_bases()
192198

193199
if method_name == "values":

mypy_django_plugin/transformers/manytomany.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import NamedTuple, Optional, Union
1+
from typing import NamedTuple, Optional, Tuple, Union
22

33
from mypy.checker import TypeChecker
44
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, StrExpr, TypeInfo
5-
from mypy.plugin import FunctionContext
5+
from mypy.plugin import FunctionContext, MethodContext
66
from mypy.semanal import SemanticAnalyzer
77
from mypy.types import Instance, ProperType, UninhabitedType
88
from mypy.types import Type as MypyType
@@ -151,3 +151,57 @@ def get_model_from_expression(
151151
if model_info is not None:
152152
return Instance(model_info, [])
153153
return None
154+
155+
156+
def get_related_manager_and_model(ctx: MethodContext) -> Optional[Tuple[Instance, Instance]]:
157+
if (
158+
isinstance(ctx.default_return_type, Instance)
159+
and ctx.default_return_type.type.fullname == fullnames.MANY_RELATED_MANAGER
160+
):
161+
# This is a call to '__get__' overload with a model instance of 'ManyToManyDescriptor'.
162+
# Returning a 'ManyRelatedManager'. Which we want to, just like Django, build from the
163+
# default manager of the related model.
164+
many_related_manager = ctx.default_return_type
165+
# Require first type argument of 'ManyRelatedManager' to be a model
166+
if (
167+
many_related_manager.args
168+
and isinstance(many_related_manager.args[0], Instance)
169+
and helpers.is_model_instance(many_related_manager.args[0])
170+
):
171+
return many_related_manager, many_related_manager.args[0]
172+
173+
return None
174+
175+
176+
def refine_many_to_many_related_manager(ctx: MethodContext) -> MypyType:
177+
"""
178+
Updates the 'ManyRelatedManager' returned by e.g. 'ManyToManyDescriptor' to be a subclass
179+
of 'ManyRelatedManager' and the related model's default manager.
180+
"""
181+
related_objects = get_related_manager_and_model(ctx)
182+
if related_objects is None:
183+
return ctx.default_return_type
184+
185+
many_related_manager, related_model_instance = related_objects
186+
checker = helpers.get_typechecker_api(ctx)
187+
related_model_instance = related_model_instance.copy_modified()
188+
related_manager_info = helpers.get_reverse_manager_info(
189+
checker, related_model_instance.type, derived_from="_default_manager"
190+
)
191+
if related_manager_info is None:
192+
default_manager_node = related_model_instance.type.names.get("_default_manager")
193+
if default_manager_node is None or not isinstance(default_manager_node.type, Instance):
194+
return ctx.default_return_type
195+
196+
related_manager_info = helpers.add_new_class_for_module(
197+
module=checker.modules[related_model_instance.type.module_name],
198+
name=f"{related_model_instance.type.name}_ManyRelatedManager",
199+
bases=[many_related_manager, default_manager_node.type],
200+
)
201+
related_manager_info.metadata["django"] = {"related_manager_to_model": related_model_instance.type.fullname}
202+
helpers.set_reverse_manager_info(
203+
related_model_instance.type,
204+
derived_from="_default_manager",
205+
fullname=related_manager_info.fullname,
206+
)
207+
return Instance(related_manager_info, [])

mypy_django_plugin/transformers/models.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.db.models import Manager, Model
55
from django.db.models.fields import DateField, DateTimeField, Field
6-
from django.db.models.fields.reverse_related import ForeignObjectRel, OneToOneRel
6+
from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel
77
from mypy.checker import TypeChecker
88
from mypy.nodes import (
99
ARG_STAR2,
@@ -448,23 +448,15 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
448448

449449

450450
class AddReverseLookups(ModelClassInitializer):
451-
def get_reverse_manager_info(self, model_info: TypeInfo, derived_from: str) -> Optional[TypeInfo]:
452-
manager_fullname = helpers.get_django_metadata(model_info).get("reverse_managers", {}).get(derived_from)
453-
if not manager_fullname:
454-
return None
455-
456-
symbol = self.api.lookup_fully_qualified_or_none(manager_fullname)
457-
if symbol is None or not isinstance(symbol.node, TypeInfo):
458-
return None
459-
return symbol.node
451+
@cached_property
452+
def reverse_one_to_one_descriptor(self) -> TypeInfo:
453+
return self.lookup_typeinfo_or_incomplete_defn_error(fullnames.REVERSE_ONE_TO_ONE_DESCRIPTOR)
460454

461-
def set_reverse_manager_info(self, model_info: TypeInfo, derived_from: str, fullname: str) -> None:
462-
helpers.get_django_metadata(model_info).setdefault("reverse_managers", {})[derived_from] = fullname
455+
@cached_property
456+
def many_to_many_descriptor(self) -> TypeInfo:
457+
return self.lookup_typeinfo_or_incomplete_defn_error(fullnames.MANY_TO_MANY_DESCRIPTOR)
463458

464459
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
465-
reverse_one_to_one_descriptor = self.lookup_typeinfo_or_incomplete_defn_error(
466-
fullnames.REVERSE_ONE_TO_ONE_DESCRIPTOR
467-
)
468460
# add related managers
469461
for relation in self.django_context.get_model_relations(model_cls):
470462
attname = relation.get_accessor_name()
@@ -487,13 +479,27 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
487479
self.add_new_node_to_model_class(
488480
attname,
489481
Instance(
490-
reverse_one_to_one_descriptor,
482+
self.reverse_one_to_one_descriptor,
491483
[Instance(self.model_classdef.info, []), Instance(related_model_info, [])],
492484
),
493485
)
494486
continue
495487

496-
if isinstance(relation, ForeignObjectRel):
488+
elif isinstance(relation, ManyToManyRel):
489+
# TODO: 'relation' should be based on `TypeInfo` instead of Django runtime.
490+
to_fullname = helpers.get_class_fullname(relation.remote_field.model)
491+
to_model_info = self.lookup_typeinfo_or_incomplete_defn_error(to_fullname)
492+
assert relation.through is not None
493+
through_fullname = helpers.get_class_fullname(relation.through)
494+
through_model_info = self.lookup_typeinfo_or_incomplete_defn_error(through_fullname)
495+
self.add_new_node_to_model_class(
496+
attname,
497+
Instance(
498+
self.many_to_many_descriptor, [Instance(to_model_info, []), Instance(through_model_info, [])]
499+
),
500+
)
501+
502+
else:
497503
related_manager_info = None
498504
try:
499505
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(
@@ -534,8 +540,8 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
534540

535541
# Check if the related model has a related manager subclassed from the default manager
536542
# TODO: Support other reverse managers than `_default_manager`
537-
default_reverse_manager_info = self.get_reverse_manager_info(
538-
model_info=related_model_info, derived_from="_default_manager"
543+
default_reverse_manager_info = helpers.get_reverse_manager_info(
544+
self.api, model_info=related_model_info, derived_from="_default_manager"
539545
)
540546
if default_reverse_manager_info:
541547
self.add_new_node_to_model_class(attname, Instance(default_reverse_manager_info, []))
@@ -564,7 +570,7 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
564570
new_related_manager_info.metadata["django"] = {"related_manager_to_model": related_model_info.fullname}
565571
# Stash the new reverse manager type fullname on the related model, so we don't duplicate
566572
# or have to create it again for other reverse relations
567-
self.set_reverse_manager_info(
573+
helpers.set_reverse_manager_info(
568574
related_model_info,
569575
derived_from="_default_manager",
570576
fullname=new_related_manager_info.fullname,

mypy_django_plugin/transformers/orm_lookups.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
1515
lookup_kwargs = ctx.arg_names[1] if len(ctx.arg_names) >= 2 else []
1616
provided_lookup_types = ctx.arg_types[1] if len(ctx.arg_types) >= 2 else []
1717

18-
assert isinstance(ctx.type, Instance)
19-
20-
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
18+
if not isinstance(ctx.type, Instance) or not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
2119
return ctx.default_return_type
2220

2321
model_cls_fullname = ctx.type.args[0].type.fullname

scripts/stubtest/allowlist.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ django.db.models.fields.related_descriptors.RelatedManager
2626
# _locally/dynamically_ runtime -- Created via
2727
# 'django.db.models.fields.related_descriptors.create_reverse_many_to_one_manager'
2828
django.contrib.admin.models.LogEntry_RelatedManager
29-
django.contrib.auth.models.Group_RelatedManager
3029
django.contrib.auth.models.Permission_RelatedManager
31-
django.contrib.auth.models.User_RelatedManager
3230

3331
# BaseArchive abstract methods that take no argument, but typed with arguments to match the Archive and TarArchive Implementations
3432
django.utils.archive.BaseArchive.list

scripts/stubtest/allowlist_todo.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ django.contrib.auth.models.GroupManager.__slotnames__
167167
django.contrib.auth.models.Permission.codename
168168
django.contrib.auth.models.Permission.content_type
169169
django.contrib.auth.models.Permission.content_type_id
170-
django.contrib.auth.models.Permission.group_set
171170
django.contrib.auth.models.Permission.id
172171
django.contrib.auth.models.Permission.name
173172
django.contrib.auth.models.Permission.user_set

0 commit comments

Comments
 (0)