Skip to content

Commit afa16bf

Browse files
authored
Make decorator functions transparent to Mypy (#306)
By declaring return type as -> Callable[[_C], _C], Mypy can infer that the decorated function has also the same arguments and return type as the original. View functions are constrained to return HttpResponseBase (or any subclass of it). Also added typecheck test coverage to most of the cases.
1 parent f770731 commit afa16bf

File tree

8 files changed

+126
-14
lines changed

8 files changed

+126
-14
lines changed
Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1-
from typing import Callable, List, Optional, Set, Union
1+
from typing import Callable, List, Optional, Set, Union, TypeVar, overload
22

33
from django.contrib.auth import REDIRECT_FIELD_NAME as REDIRECT_FIELD_NAME # noqa: F401
4+
from django.http.response import HttpResponseBase
5+
6+
from django.contrib.auth.models import AbstractUser
7+
8+
_VIEW = TypeVar("_VIEW", bound=Callable[..., HttpResponseBase])
49

510
def user_passes_test(
6-
test_func: Callable, login_url: Optional[str] = ..., redirect_field_name: str = ...
7-
) -> Callable: ...
8-
def login_required(
9-
function: Optional[Callable] = ..., redirect_field_name: str = ..., login_url: Optional[str] = ...
10-
) -> Callable: ...
11+
test_func: Callable[[AbstractUser], bool], login_url: Optional[str] = ..., redirect_field_name: str = ...
12+
) -> Callable[[_VIEW], _VIEW]: ...
13+
14+
# There are two ways of calling @login_required: @with(arguments) and @bare
15+
@overload
16+
def login_required(redirect_field_name: str = ..., login_url: Optional[str] = ...) -> Callable[[_VIEW], _VIEW]: ...
17+
@overload
18+
def login_required(function: _VIEW, redirect_field_name: str = ..., login_url: Optional[str] = ...) -> _VIEW: ...
1119
def permission_required(
1220
perm: Union[List[str], Set[str], str], login_url: None = ..., raise_exception: bool = ...
13-
) -> Callable: ...
21+
) -> Callable[[_VIEW], _VIEW]: ...

django-stubs/db/transaction.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,11 @@ def atomic(using: _C) -> _C: ...
3939
# Decorator or context-manager with parameters
4040
@overload
4141
def atomic(using: Optional[str] = ..., savepoint: bool = ...) -> Atomic: ...
42-
def non_atomic_requests(using: Callable = ...) -> Callable: ...
42+
43+
# Bare decorator
44+
@overload
45+
def non_atomic_requests(using: _C) -> _C: ...
46+
47+
# Decorator with arguments
48+
@overload
49+
def non_atomic_requests(using: Optional[str] = ...) -> Callable[[_C], _C]: ...

django-stubs/test/utils.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ from typing import (
1616
Type,
1717
Union,
1818
ContextManager,
19+
TypeVar,
1920
)
2021

2122
from django.apps.registry import Apps
@@ -29,6 +30,7 @@ from django.conf import LazySettings, Settings
2930

3031
_TestClass = Type[SimpleTestCase]
3132
_DecoratedTest = Union[Callable, _TestClass]
33+
_C = TypeVar("_C", bound=Callable) # Any callable
3234

3335
TZ_SUPPORT: bool = ...
3436

@@ -56,7 +58,7 @@ class TestContextDecorator:
5658
def __enter__(self) -> Optional[Apps]: ...
5759
def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: ...
5860
def decorate_class(self, cls: _TestClass) -> _TestClass: ...
59-
def decorate_callable(self, func: Callable) -> Callable: ...
61+
def decorate_callable(self, func: _C) -> _C: ...
6062
def __call__(self, decorated: _DecoratedTest) -> Any: ...
6163

6264
class override_settings(TestContextDecorator):
@@ -146,7 +148,7 @@ def get_unique_databases_and_mirrors() -> Tuple[Dict[_Signature, _TestDatabase],
146148
def teardown_databases(
147149
old_config: Iterable[Tuple[Any, str, bool]], verbosity: int, parallel: int = ..., keepdb: bool = ...
148150
) -> None: ...
149-
def require_jinja2(test_func: Callable) -> Callable: ...
151+
def require_jinja2(test_func: _C) -> _C: ...
150152
@contextmanager
151153
def register_lookup(
152154
field: Type[RegisterLookupMixin], *lookups: Type[Union[Lookup, Transform]], lookup_name: Optional[str] = ...

django-stubs/utils/decorators.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from typing import Any, Callable, Iterable, Optional, Type, Union
1+
from typing import Any, Callable, Iterable, Optional, Type, Union, TypeVar
22

33
from django.utils.deprecation import MiddlewareMixin
4+
from django.views.generic.base import View
5+
6+
_T = TypeVar("_T", bound=Union[View, Callable]) # Any callable
47

58
class classonlymethod(classmethod): ...
69

7-
def method_decorator(decorator: Union[Callable, Iterable[Callable]], name: str = ...) -> Callable: ...
10+
def method_decorator(decorator: Union[Callable, Iterable[Callable]], name: str = ...) -> Callable[[_T], _T]: ...
811
def decorator_from_middleware_with_args(middleware_class: type) -> Callable: ...
912
def decorator_from_middleware(middleware_class: type) -> Callable: ...
10-
def available_attrs(fn: Any): ...
13+
def available_attrs(fn: Callable): ...
1114
def make_middleware_decorator(middleware_class: Type[MiddlewareMixin]) -> Callable: ...
1215

1316
class classproperty:

scripts/enabled_test_modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@
170170
'Incompatible types in assignment (expression has type "Optional[Any]", variable has type "FloatModel")'
171171
],
172172
'decorators': [
173-
'"Type[object]" has no attribute "method"'
173+
'"Type[object]" has no attribute "method"',
174+
'Value of type variable "_T" of function cannot be "descriptor_wrapper"'
174175
],
175176
'expressions_window': [
176177
'has incompatible type "str"'
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
- case: login_required_bare
2+
main: |
3+
from django.contrib.auth.decorators import login_required
4+
@login_required
5+
def view_func(request): ...
6+
reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any'
7+
- case: login_required_fancy
8+
main: |
9+
from django.contrib.auth.decorators import login_required
10+
from django.core.handlers.wsgi import WSGIRequest
11+
from django.http import HttpResponse
12+
@login_required(redirect_field_name='a', login_url='b')
13+
def view_func(request: WSGIRequest, arg: str) -> HttpResponse: ...
14+
reveal_type(view_func) # N: Revealed type is 'def (request: django.core.handlers.wsgi.WSGIRequest, arg: builtins.str) -> django.http.response.HttpResponse'
15+
- case: login_required_weird
16+
main: |
17+
from django.contrib.auth.decorators import login_required
18+
# This is non-conventional usage, but covered in Django tests, so we allow it.
19+
def view_func(request): ...
20+
wrapped_view = login_required(view_func, redirect_field_name='a', login_url='b')
21+
reveal_type(wrapped_view) # N: Revealed type is 'def (request: Any) -> Any'
22+
- case: login_required_incorrect_return
23+
main: |
24+
from django.contrib.auth.decorators import login_required
25+
@login_required() # E: Value of type variable "_VIEW" of function cannot be "Callable[[Any], str]"
26+
def view_func2(request) -> str: ...
27+
- case: user_passes_test
28+
main: |
29+
from django.contrib.auth.decorators import user_passes_test
30+
@user_passes_test(lambda u: u.username.startswith('super'))
31+
def view_func(request): ...
32+
reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any'
33+
- case: user_passes_test_bare_is_error
34+
main: |
35+
from django.http.response import HttpResponse
36+
from django.contrib.auth.decorators import user_passes_test
37+
@user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[Any], HttpResponse]"; expected "Callable[[AbstractUser], bool]"
38+
def view_func(request) -> HttpResponse: ...
39+
- case: permission_required
40+
main: |
41+
from django.contrib.auth.decorators import permission_required
42+
@permission_required('polls.can_vote')
43+
def view_func(request): ...
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
- case: atomic_bare
2+
main: |
3+
from django.db.transaction import atomic
4+
@atomic
5+
def func(x: int) -> list: ...
6+
reveal_type(func) # N: Revealed type is 'def (x: builtins.int) -> builtins.list[Any]'
7+
- case: atomic_args
8+
main: |
9+
from django.db.transaction import atomic
10+
@atomic(using='bla', savepoint=False)
11+
def func(x: int) -> list: ...
12+
reveal_type(func) # N: Revealed type is 'def (x: builtins.int) -> builtins.list[Any]'
13+
- case: non_atomic_requests_bare
14+
main: |
15+
from django.db.transaction import non_atomic_requests
16+
@non_atomic_requests
17+
def view_func(request): ...
18+
reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any'
19+
20+
- case: non_atomic_requests_args
21+
main: |
22+
from django.http.request import HttpRequest
23+
from django.http.response import HttpResponse
24+
from django.db.transaction import non_atomic_requests
25+
@non_atomic_requests
26+
def view_func(request: HttpRequest, arg: str) -> HttpResponse: ...
27+
reveal_type(view_func) # N: Revealed type is 'def (request: django.http.request.HttpRequest, arg: builtins.str) -> django.http.response.HttpResponse'
28+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
- case: method_decorator_class
2+
main: |
3+
from django.views.generic.base import View
4+
from django.utils.decorators import method_decorator
5+
from django.contrib.auth.decorators import login_required
6+
@method_decorator(login_required, name='dispatch')
7+
class TestView(View): ...
8+
reveal_type(TestView()) # N: Revealed type is 'main.TestView'
9+
- case: method_decorator_function
10+
main: |
11+
from django.views.generic.base import View
12+
from django.utils.decorators import method_decorator
13+
from django.contrib.auth.decorators import login_required
14+
from django.http.response import HttpResponse
15+
from django.http.request import HttpRequest
16+
class TestView(View):
17+
@method_decorator(login_required)
18+
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
19+
return super().dispatch(request, *args, **kwargs)
20+
reveal_type(dispatch) # N: Revealed type is 'def (self: main.TestView, request: django.http.request.HttpRequest, *args: Any, **kwargs: Any) -> django.http.response.HttpResponse'

0 commit comments

Comments
 (0)