Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ def from_model_type(cls, model_type: Instance, django_context: "DjangoContext")

return cls(cls=model_cls, typ=model_type, is_annotated=is_annotated)

def is_annotated_field(self, field_name: str) -> bool:
"""Whether this field name was annotated via annotate/values"""
return bool(self.typ.extra_attrs and field_name in self.typ.extra_attrs.attrs)


def extract_model_type_from_queryset(queryset_type: Instance, api: TypeChecker) -> Instance | None:
"""Extract the django model `Instance` associated to a queryset `Instance`"""
Expand Down
5 changes: 5 additions & 0 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext
"alatest": partial(querysets.validate_order_by, django_context=self.django_context),
"defer": partial(querysets.validate_defer_only, django_context=self.django_context, is_defer=True),
"only": partial(querysets.validate_defer_only, django_context=self.django_context, is_defer=False),
"distinct": partial(querysets.validate_distinct, django_context=self.django_context),
"update": partial(querysets.validate_update, django_context=self.django_context),
"aupdate": partial(querysets.validate_update, django_context=self.django_context),
"in_bulk": partial(querysets.validate_in_bulk, django_context=self.django_context),
"ain_bulk": partial(querysets.validate_in_bulk, django_context=self.django_context),
}

def get_method_hook(self, fullname: str) -> Callable[[MethodContext], MypyType] | None:
Expand Down
67 changes: 61 additions & 6 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,16 @@ def _check_field_concrete(ctx: MethodContext, field: "_AnyField", field_name: st
return True


def _check_field_unique(
ctx: MethodContext, model_cls: type[Model], field: "_AnyField", field_name: str, method: str
) -> bool:
unique_fields = [c.fields[0] for c in model_cls._meta.total_unique_constraints if len(c.fields) == 1]
if not getattr(field, "unique", None) and field_name not in unique_fields:
ctx.api.fail(f'"{method}()"\'s field_name must be a unique field but "{field_name}" isn\'t', ctx.context)
return False
return True


def _check_field_not_pk(
ctx: MethodContext,
model_cls: type[Model],
Expand Down Expand Up @@ -907,25 +917,28 @@ def validate_bulk_create(
return ctx.default_return_type


def _validate_order_by_lookup(ctx: MethodContext, model_cls: type[Model], parts: list[str]) -> None:
if len(parts) == 1 and parts[0] == "?":
return

def _validate_lookup(ctx: MethodContext, model_cls: type[Model], parts: list[str]) -> None:
try:
Query(model_cls).names_to_path(parts, model_cls._meta, fail_on_missing=True)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)


def _validate_order_by_lookup(ctx: MethodContext, model_cls: type[Model], parts: list[str]) -> None:
if len(parts) == 1 and parts[0] == "?":
return

_validate_lookup(ctx, model_cls, parts)


def validate_order_by(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
if (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is None:
return ctx.default_return_type

for lookup_value in _extract_field_names_from_varargs(ctx):
parts = lookup_value.removeprefix("-").split(LOOKUP_SEP)

if django_model.typ.extra_attrs and parts[0] in django_model.typ.extra_attrs.attrs:
# Skip validation for annotated fields
if django_model.is_annotated_field(field_name=parts[0]):
continue
_validate_order_by_lookup(ctx, django_model.cls, parts)

Expand Down Expand Up @@ -954,3 +967,45 @@ def validate_defer_only(ctx: MethodContext, django_context: DjangoContext, *, is
_validate_defer_only_fields(ctx, django_model.cls, field_names, is_defer=is_defer)

return ctx.default_return_type


def validate_distinct(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
if (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is None:
return ctx.default_return_type

for lookup_value in _extract_field_names_from_varargs(ctx):
parts = lookup_value.split(LOOKUP_SEP)
if django_model.is_annotated_field(field_name=parts[0]):
continue
_validate_lookup(ctx, django_model.cls, parts)

return ctx.default_return_type


def validate_update(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
if (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is None or not (
kwargs := gather_kwargs(ctx)
):
return ctx.default_return_type

for field_name in kwargs:
field = _try_get_field(ctx, django_model.cls, field_name)
if field is not None:
_check_field_concrete(ctx, field, field_name, method="update")

return ctx.default_return_type


def validate_in_bulk(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
if (
(django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is None
or (field_name_expr := helpers.get_call_argument_by_name(ctx, "field_name")) is None
or (field_name := helpers.resolve_string_attribute_value(field_name_expr, django_context)) is None
or field_name == "pk"
or (field := _try_get_field(ctx, django_model.cls, field_name)) is None
):
return ctx.default_return_type

_check_field_unique(ctx, django_model.cls, field, field_name, method="in_bulk")

return ctx.default_return_type
90 changes: 90 additions & 0 deletions tests/typecheck/managers/querysets/test_distinct.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
- case: distinct_valid_fields
installed_apps:
- myapp
main: |
from myapp.models import Article, Author
from django.db.models import F

# Bare distinct (no field names)
Article.objects.distinct()

# Simple fields
Article.objects.distinct("title")
Article.objects.distinct("pk")
Article.objects.distinct("id")

# FK traversal
Article.objects.distinct("author__name")

# Multiple fields
Article.objects.distinct("title", "author__name")

# FK id field
Article.objects.distinct("author_id")

# Annotated fields
Article.objects.annotate(foo=F("id")).distinct("foo")

# Non-literal strings (skipped by validation)
field: str = "foo"
Article.objects.distinct(field)

# Reverse FK traversal
Author.objects.distinct("article__title")

# Chained
Article.objects.filter(published=True).distinct("title")
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Author(models.Model):
name = models.CharField(max_length=100)
email = models.EmailField()

class Article(models.Model):
title = models.CharField(max_length=200)
content = models.TextField()
published = models.BooleanField(default=False)
author = models.ForeignKey(Author, on_delete=models.CASCADE, related_name="article")


- case: distinct_invalid_fields
installed_apps:
- myapp
main: |
from myapp.models import Article, Author
from typing import Literal

# Non-existent field
Article.objects.distinct("nonexistent") # E: Cannot resolve keyword 'nonexistent' into field. Choices are: author, author_id, content, id, published, title [misc]

# Invalid chained lookup
Article.objects.distinct("author__nonexistent") # E: Cannot resolve keyword 'nonexistent' into field. Choices are: article, email, id, name [misc]

# Lookup suffix on a valid field
Article.objects.distinct("title__exact") # E: Cannot resolve keyword 'exact' into field. Join on 'title' not permitted. [misc]

# Literal typed invalid
field: Literal["nonexistent"] = "nonexistent"
Article.objects.distinct(field) # E: Cannot resolve keyword 'nonexistent' into field. Choices are: author, author_id, content, id, published, title [misc]

# Invalid field at depth > 2
Article.objects.distinct("author__name__nonexistent") # E: Cannot resolve keyword 'nonexistent' into field. Join on 'name' not permitted. [misc]
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Author(models.Model):
name = models.CharField(max_length=100)
email = models.EmailField()

class Article(models.Model):
title = models.CharField(max_length=200)
content = models.TextField()
published = models.BooleanField(default=False)
author = models.ForeignKey(Author, on_delete=models.CASCADE, related_name="article")
64 changes: 64 additions & 0 deletions tests/typecheck/managers/querysets/test_in_bulk.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
- case: in_bulk_valid
installed_apps:
- myapp
main: |
from myapp.models import Article

# Default (no field_name — uses pk)
Article.objects.in_bulk()

# Explicit pk
Article.objects.in_bulk(field_name="pk")

# Unique field
Article.objects.in_bulk(field_name="slug")

# Non-literal field_name (skipped by validation)
field: str = "title"
Article.objects.in_bulk(field_name=field)

# With id_list
Article.objects.in_bulk([1, 2, 3])
Article.objects.in_bulk([1, 2], field_name="pk")

# Async variant with valid field
async def test_async() -> None:
await Article.objects.ain_bulk(field_name="slug")
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Article(models.Model):
title = models.CharField(max_length=200)
slug = models.SlugField(unique=True)
published = models.BooleanField(default=False)


- case: in_bulk_invalid
installed_apps:
- myapp
main: |
from myapp.models import Article

# Non-existent field
Article.objects.in_bulk(field_name="nonexistent") # E: Article has no field named 'nonexistent' [misc]

# Non-unique field
Article.objects.in_bulk(field_name="title") # E: "in_bulk()"'s field_name must be a unique field but "title" isn't [misc]

# Async variant with invalid field
async def test_async() -> None:
await Article.objects.ain_bulk(field_name="nonexistent") # E: Article has no field named 'nonexistent' [misc]
await Article.objects.ain_bulk(field_name="title") # E: "in_bulk()"'s field_name must be a unique field but "title" isn't [misc]
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Article(models.Model):
title = models.CharField(max_length=200)
slug = models.SlugField(unique=True)
published = models.BooleanField(default=False)
74 changes: 74 additions & 0 deletions tests/typecheck/managers/querysets/test_update.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
- case: update_valid_fields
installed_apps:
- myapp
main: |
from myapp.models import Article, Author

# Simple field
Article.objects.update(title="new title")

# Multiple fields
Article.objects.update(title="new title", published=True)

# FK id field (concrete)
Article.objects.update(author_id=1)

# Chained with filter
Article.objects.filter(published=False).update(published=True)

# Async variant
async def test_async() -> None:
await Article.objects.aupdate(title="new title")
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Tag(models.Model):
name = models.CharField(max_length=50)

class Author(models.Model):
name = models.CharField(max_length=100)

class Article(models.Model):
title = models.CharField(max_length=200)
content = models.TextField()
published = models.BooleanField(default=False)
author = models.ForeignKey(Author, on_delete=models.CASCADE)
tags = models.ManyToManyField(Tag)


- case: update_invalid_fields
installed_apps:
- myapp
main: |
from myapp.models import Article, Author

# Non-existent field
Article.objects.update(nonexistent=1) # E: Article has no field named 'nonexistent' [misc]

# M2M field (not concrete)
Article.objects.update(tags=1) # E: "update()" can only be used with concrete fields. Got "tags" [misc]

# Async variant with invalid field
async def test_async() -> None:
await Article.objects.aupdate(nonexistent=1) # E: Article has no field named 'nonexistent' [misc]
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models

class Tag(models.Model):
name = models.CharField(max_length=50)

class Author(models.Model):
name = models.CharField(max_length=100)

class Article(models.Model):
title = models.CharField(max_length=200)
content = models.TextField()
published = models.BooleanField(default=False)
author = models.ForeignKey(Author, on_delete=models.CASCADE)
tags = models.ManyToManyField(Tag)
Loading