Skip to content

Commit dd8d93f

Browse files
Support GenericPrefetch (#2851)
1 parent 7fe6ea5 commit dd8d93f

File tree

4 files changed

+141
-9
lines changed

4 files changed

+141
-9
lines changed
Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
1-
from typing import Any
1+
from typing import Any, Generic, TypeVar
22

3-
from django.db.models import Prefetch
3+
from django.db.models import Model, Prefetch
44
from django.db.models.query import QuerySet
55

6-
class GenericPrefetch(Prefetch):
7-
def __init__(self, lookup: str, querysets: list[QuerySet], to_attr: str | None = None) -> None: ...
6+
# The type of the lookup passed to Prefetch(...)
7+
# This will be specialized to a `LiteralString` in the plugin for further processing and validation
8+
_LookupT = TypeVar("_LookupT", bound=str, covariant=True)
9+
# The type of the querysets passed to GenericPrefetch(...)
10+
_PrefetchedQuerySetsT = TypeVar(
11+
"_PrefetchedQuerySetsT", bound=list[QuerySet[Model]], covariant=True, default=list[QuerySet[Model]]
12+
)
13+
# The attribute name to store the prefetched list[_PrefetchedQuerySetT]
14+
# This will be specialized to a `LiteralString` in the plugin for further processing and validation
15+
_ToAttrT = TypeVar("_ToAttrT", bound=str, covariant=True, default=str)
16+
17+
class GenericPrefetch(Prefetch, Generic[_LookupT, _PrefetchedQuerySetsT, _ToAttrT]):
18+
def __init__(self, lookup: _LookupT, querysets: _PrefetchedQuerySetsT, to_attr: _ToAttrT | None = None) -> None: ...
819
def __getstate__(self) -> dict[str, Any]: ...
9-
def get_current_querysets(self, level: int) -> list[QuerySet] | None: ...
20+
def get_current_querysets(self, level: int) -> _PrefetchedQuerySetsT | None: ...

mypy_django_plugin/lib/fullnames.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
1818
RELATED_MANAGER_CLASS = "django.db.models.fields.related_descriptors.RelatedManager"
1919
PREFETCH_CLASS_FULLNAME = "django.db.models.query.Prefetch"
20+
GENERIC_PREFETCH_CLASS_FULLNAME = "django.contrib.contenttypes.prefetch.GenericPrefetch"
2021

2122
CHOICES_CLASS_FULLNAME = "django.db.models.enums.Choices"
2223
CHOICES_TYPE_METACLASS_FULLNAME = "django.db.models.enums.ChoicesType"

mypy_django_plugin/transformers/querysets.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,12 @@ def check_valid_attr_value(
512512

513513

514514
def check_valid_prefetch_related_lookup(
515-
ctx: MethodContext, lookup: str, django_model: DjangoModel, django_context: DjangoContext
515+
ctx: MethodContext,
516+
lookup: str,
517+
django_model: DjangoModel,
518+
django_context: DjangoContext,
519+
*,
520+
is_generic_prefetch: bool = False,
516521
) -> bool:
517522
"""Check if a lookup string resolve to something that can be prefetched"""
518523
current_model_cls = django_model.cls
@@ -528,7 +533,17 @@ def check_valid_prefetch_related_lookup(
528533
ctx.context,
529534
)
530535
return False
531-
if isinstance(rel_obj_descriptor, ForwardManyToOneDescriptor):
536+
if contenttypes_installed and is_generic_prefetch:
537+
from django.contrib.contenttypes.fields import GenericForeignKey
538+
539+
if not isinstance(rel_obj_descriptor, GenericForeignKey):
540+
ctx.api.fail(
541+
f'"{through_attr}" on "{current_model_cls.__name__}" is not a GenericForeignKey, '
542+
f"GenericPrefetch can only be used with GenericForeignKey fields",
543+
ctx.context,
544+
)
545+
return True
546+
elif isinstance(rel_obj_descriptor, ForwardManyToOneDescriptor):
532547
current_model_cls = rel_obj_descriptor.field.remote_field.model
533548
elif isinstance(rel_obj_descriptor, ReverseOneToOneDescriptor):
534549
current_model_cls = rel_obj_descriptor.related.related_model # type:ignore[assignment] # Can't be 'self' for non abstract models
@@ -563,7 +578,10 @@ def check_valid_prefetch_related_lookup(
563578

564579

565580
def check_conflicting_lookups(
566-
ctx: MethodContext, observed_attr: str, qs_types: dict[str, Instance | None], queryset_type: Instance | None
581+
ctx: MethodContext,
582+
observed_attr: str,
583+
qs_types: dict[str, Instance | None],
584+
queryset_type: Instance | None,
567585
) -> bool:
568586
is_conflicting_lookup = bool(observed_attr in qs_types and qs_types[observed_attr] != queryset_type)
569587
if is_conflicting_lookup:
@@ -641,7 +659,13 @@ def extract_prefetch_related_annotations(ctx: MethodContext, django_context: Dja
641659
)
642660
qs_types[to_attr] = queryset_type
643661
if not to_attr and lookup:
644-
check_valid_prefetch_related_lookup(ctx, lookup, qs_model, django_context)
662+
check_valid_prefetch_related_lookup(
663+
ctx,
664+
lookup,
665+
qs_model,
666+
django_context,
667+
is_generic_prefetch=typ.type.has_base(fullnames.GENERIC_PREFETCH_CLASS_FULLNAME),
668+
)
645669
check_conflicting_lookups(ctx, lookup, qs_types, queryset_type)
646670
qs_types[lookup] = queryset_type
647671

tests/typecheck/managers/querysets/test_prefetch_related.yml

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,99 @@
384384
subject_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
385385
subject_id = models.PositiveIntegerField()
386386
subject = GenericForeignKey("subject_content_type", "subject_id")
387+
388+
- case: django_contrib_contenttypes_generic_prefetch
389+
installed_apps:
390+
- django.contrib.contenttypes
391+
- myapp
392+
main: |
393+
from django.contrib.contenttypes.prefetch import GenericPrefetch
394+
from myapp.models import Bookmark, Animal, TaggedItem
395+
from typing_extensions import reveal_type
396+
397+
# Basic GenericPrefetch usage
398+
prefetch = GenericPrefetch(
399+
"content_object", [Bookmark.objects.all(), Animal.objects.only("name")]
400+
)
401+
reveal_type(prefetch) # N: Revealed type is "django.contrib.contenttypes.prefetch.GenericPrefetch[Literal['content_object'], builtins.list[django.db.models.query.QuerySet[django.db.models.base.Model, django.db.models.base.Model]], builtins.str]"
402+
403+
# Using GenericPrefetch with prefetch_related
404+
qs = TaggedItem.objects.prefetch_related(prefetch).all()
405+
reveal_type(qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.TaggedItem, myapp.models.TaggedItem]"
406+
407+
# GenericPrefetch with to_attr
408+
prefetch_with_attr = GenericPrefetch(
409+
"content_object",
410+
[Bookmark.objects.all(), Animal.objects.only("name")],
411+
to_attr="prefetched_object"
412+
)
413+
qs_with_attr = TaggedItem.objects.prefetch_related(prefetch_with_attr).all()
414+
reveal_type(qs_with_attr) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.TaggedItem@AnnotatedWith[TypedDict({'prefetched_object': builtins.list[Any]})], myapp.models.TaggedItem@AnnotatedWith[TypedDict({'prefetched_object': builtins.list[Any]})]]"
415+
reveal_type(qs_with_attr.get().prefetched_object) # N: Revealed type is "builtins.list[Any]"
416+
417+
# GenericPrefetch on invalid field (not a GenericForeignKey)
418+
regular_fk_prefetch = GenericPrefetch(
419+
"content_type",
420+
[Bookmark.objects.all(), Animal.objects.only("name")],
421+
)
422+
TaggedItem.objects.prefetch_related(regular_fk_prefetch).all() # E: "content_type" on "TaggedItem" is not a GenericForeignKey, GenericPrefetch can only be used with GenericForeignKey fields [misc]
423+
regular_field_prefetch = GenericPrefetch(
424+
"tag",
425+
[Bookmark.objects.all(), Animal.objects.only("name")],
426+
)
427+
TaggedItem.objects.prefetch_related(regular_field_prefetch).all() # E: "tag" on "TaggedItem" is not a GenericForeignKey, GenericPrefetch can only be used with GenericForeignKey fields [misc]
428+
429+
files:
430+
- path: myapp/__init__.py
431+
- path: myapp/models.py
432+
content: |
433+
from django.db import models
434+
from django.contrib.contenttypes.models import ContentType
435+
from django.contrib.contenttypes.fields import GenericForeignKey
436+
437+
class Bookmark(models.Model):
438+
url = models.URLField()
439+
440+
class Animal(models.Model):
441+
name = models.CharField(max_length=100)
442+
443+
class TaggedItem(models.Model):
444+
tag = models.CharField(max_length=100)
445+
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
446+
object_id = models.PositiveIntegerField()
447+
content_object = GenericForeignKey('content_type', 'object_id')
448+
449+
450+
- case: uninstalled_django_contrib_contenttypes_generic_prefetch
451+
installed_apps:
452+
- myapp
453+
main: |
454+
from django.contrib.contenttypes.prefetch import GenericPrefetch
455+
from myapp.models import Bookmark, Animal, TaggedItem
456+
from typing_extensions import reveal_type
457+
458+
# Basic GenericPrefetch usage
459+
prefetch = GenericPrefetch(
460+
"content_object", [Bookmark.objects.all(), Animal.objects.only("name")]
461+
)
462+
reveal_type(prefetch) # N: Revealed type is "django.contrib.contenttypes.prefetch.GenericPrefetch[Literal['content_object'], builtins.list[django.db.models.query.QuerySet[django.db.models.base.Model, django.db.models.base.Model]], builtins.str]"
463+
464+
# Using GenericPrefetch with prefetch_related
465+
qs = TaggedItem.objects.prefetch_related(prefetch).all() # E: Cannot find "content_object" on "TaggedItem" object, "content_object" is an invalid parameter to "prefetch_related()" [misc]
466+
reveal_type(qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.TaggedItem, myapp.models.TaggedItem]"
467+
468+
469+
files:
470+
- path: myapp/__init__.py
471+
- path: myapp/models.py
472+
content: |
473+
from django.db import models
474+
475+
class Bookmark(models.Model):
476+
url = models.URLField()
477+
478+
class Animal(models.Model):
479+
name = models.CharField(max_length=100)
480+
481+
class TaggedItem(models.Model):
482+
tag = models.CharField(max_length=100)

0 commit comments

Comments
 (0)