Skip to content

Commit 65a0514

Browse files
committed
misc: Add module type hints
This commit adds type hints to the module. It's missing the addition of a `py.typed` file to be PEP 561 compliant [1]. However, it seems that will require moving away from the `py_modules` setup in `setup.py`, according to the PEP: > This PEP does not support distributing typing information as part of > module-only distributions or single-file modules within namespace > packages. > The single-file module should be refactored into a package and > indicate that the package supports typing as described above. [1] https://peps.python.org/pep-0561/
1 parent 7d73410 commit 65a0514

File tree

2 files changed

+55
-46
lines changed

2 files changed

+55
-46
lines changed

asgi_csrf.py

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from http.cookies import SimpleCookie
2-
from enum import Enum
3-
import fnmatch
2+
from collections.abc import Awaitable, Callable, Container, Mapping, MutableMapping
3+
from enum import IntEnum
44
from functools import wraps
55
from multipart import FormParser
66
import os
7+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
78
from urllib.parse import parse_qsl
89
from itsdangerous.url_safe import URLSafeSerializer
910
from itsdangerous import BadSignature
1011
import secrets
1112

13+
if TYPE_CHECKING:
14+
Scope = MutableMapping[str, Any]
15+
Message = MutableMapping[str, Any]
16+
Receive = Callable[[], Awaitable[Message]]
17+
Send = Callable[[Message], Awaitable[None]]
18+
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
19+
1220
DEFAULT_COOKIE_NAME = "csrftoken"
1321
DEFAULT_COOKIE_PATH = "/"
1422
DEFAULT_COOKIE_DOMAIN = None
@@ -21,7 +29,8 @@
2129
ENV_SECRET = "ASGI_CSRF_SECRET"
2230

2331

24-
class Errors(Enum):
32+
33+
class Errors(IntEnum):
2534
FORM_URLENCODED_MISMATCH = 1
2635
MULTIPART_MISMATCH = 2
2736
FILE_BEFORE_TOKEN = 3
@@ -37,30 +46,30 @@ class Errors(Enum):
3746

3847

3948
def asgi_csrf_decorator(
40-
cookie_name=DEFAULT_COOKIE_NAME,
41-
http_header=DEFAULT_HTTP_HEADER,
42-
form_input=DEFAULT_FORM_INPUT,
43-
signing_secret=None,
44-
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
45-
always_protect=None,
46-
always_set_cookie=False,
47-
skip_if_scope=None,
48-
cookie_path=DEFAULT_COOKIE_PATH,
49-
cookie_domain=DEFAULT_COOKIE_DOMAIN,
50-
cookie_secure=DEFAULT_COOKIE_SECURE,
51-
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
52-
send_csrf_failed=None,
53-
):
49+
cookie_name: str = DEFAULT_COOKIE_NAME,
50+
http_header: str = DEFAULT_HTTP_HEADER,
51+
form_input: str = DEFAULT_FORM_INPUT,
52+
signing_secret: Optional[str] = None,
53+
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
54+
always_protect: Optional[Container[str]] = None,
55+
always_set_cookie: bool = False,
56+
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
57+
cookie_path: str = DEFAULT_COOKIE_PATH,
58+
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
59+
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
60+
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
61+
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
62+
) -> Callable[["ASGIApp"], "ASGIApp"]:
5463
send_csrf_failed = send_csrf_failed or default_send_csrf_failed
5564
if signing_secret is None:
5665
signing_secret = os.environ.get(ENV_SECRET, None)
5766
if signing_secret is None:
5867
signing_secret = make_secret(128)
5968
signer = URLSafeSerializer(signing_secret)
6069

61-
def _asgi_csrf_decorator(app):
70+
def _asgi_csrf_decorator(app: "ASGIApp") -> "ASGIApp":
6271
@wraps(app)
63-
async def app_wrapped_with_csrf(scope, receive, send):
72+
async def app_wrapped_with_csrf(scope: "Scope", receive: "Receive", send: "Send") -> None:
6473
if scope["type"] != "http":
6574
await app(scope, receive, send)
6675
return
@@ -84,7 +93,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
8493
if not has_csrftoken_cookie:
8594
csrftoken = signer.dumps(make_secret(16), signing_namespace)
8695

87-
def get_csrftoken():
96+
def get_csrftoken() -> Optional[str]:
8897
nonlocal should_set_cookie
8998
nonlocal page_needs_vary_header
9099
page_needs_vary_header = True
@@ -94,7 +103,7 @@ def get_csrftoken():
94103

95104
scope = {**scope, **{SCOPE_KEY: get_csrftoken}}
96105

97-
async def wrapped_send(event):
106+
async def wrapped_send(event: "Message") -> None:
98107
if event["type"] == "http.response.start":
99108
original_headers = event.get("headers") or []
100109
new_headers = []
@@ -144,7 +153,7 @@ async def wrapped_send(event):
144153
await app(scope, receive, wrapped_send)
145154
else:
146155
# Check for CSRF token in various places
147-
headers = dict(scope.get("headers" or []))
156+
headers = dict(scope.get("headers") or [])
148157
if secrets.compare_digest(
149158
headers.get(http_header.encode("latin-1"), b"").decode("latin-1"),
150159
csrftoken,
@@ -226,7 +235,7 @@ async def wrapped_send(event):
226235
return _asgi_csrf_decorator
227236

228237

229-
async def _parse_form_urlencoded(receive):
238+
async def _parse_form_urlencoded(receive: "Receive") -> Tuple[Dict[str, str], "Receive"]:
230239
# Returns {key: value}, replay_receive
231240
# where replay_receive is an awaitable that can replay what was received
232241
# We ignore cases like foo=one&foo=two because we do not need to
@@ -241,7 +250,7 @@ async def _parse_form_urlencoded(receive):
241250
body += message.get("body", b"")
242251
more_body = message.get("more_body", False)
243252

244-
async def replay_receive():
253+
async def replay_receive() -> "Message":
245254
if messages:
246255
return messages.pop(0)
247256
else:
@@ -262,7 +271,7 @@ class FileBeforeToken(Exception):
262271
pass
263272

264273

265-
async def _parse_multipart_form_data(boundary, receive):
274+
async def _parse_multipart_form_data(boundary: bytes, receive: "Receive") -> Tuple[Optional[str], "Receive"]:
266275
# Returns (csrftoken, replay_receive) - or raises an exception
267276
csrftoken = None
268277

@@ -272,17 +281,17 @@ def on_field(field):
272281
raise TokenFound(csrftoken)
273282

274283
class ErrorOnWrite:
275-
def __init__(self, file_name, field_name, config):
284+
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Mapping[str, Any]) -> None:
276285
pass
277286

278-
def write(self, data):
287+
def write(self, data: bytes) -> int:
279288
raise FileBeforeToken
280289

281290
body = b""
282291
more_body = True
283-
messages = []
292+
messages: List["Message"] = []
284293

285-
async def replay_receive():
294+
async def replay_receive() -> "Message":
286295
if messages:
287296
return messages.pop(0)
288297
else:
@@ -308,7 +317,7 @@ async def replay_receive():
308317
return None, replay_receive
309318

310319

311-
async def default_send_csrf_failed(scope, send, message_id):
320+
async def default_send_csrf_failed(scope: "Scope", send: "Send", message_id: int) -> None:
312321
assert scope["type"] == "http"
313322
await send(
314323
{
@@ -323,19 +332,19 @@ async def default_send_csrf_failed(scope, send, message_id):
323332

324333
def asgi_csrf(
325334
app,
326-
cookie_name=DEFAULT_COOKIE_NAME,
327-
http_header=DEFAULT_HTTP_HEADER,
328-
signing_secret=None,
329-
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
330-
always_protect=None,
331-
always_set_cookie=False,
332-
skip_if_scope=None,
333-
cookie_path=DEFAULT_COOKIE_PATH,
334-
cookie_domain=DEFAULT_COOKIE_DOMAIN,
335-
cookie_secure=DEFAULT_COOKIE_SECURE,
336-
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
337-
send_csrf_failed=None,
338-
):
335+
cookie_name: str = DEFAULT_COOKIE_NAME,
336+
http_header: str = DEFAULT_HTTP_HEADER,
337+
signing_secret: Optional[str] = None,
338+
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
339+
always_protect: Optional[Container[str]] = None,
340+
always_set_cookie: bool = False,
341+
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
342+
cookie_path: str = DEFAULT_COOKIE_PATH,
343+
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
344+
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
345+
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
346+
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
347+
) -> "ASGIApp":
339348
return asgi_csrf_decorator(
340349
cookie_name,
341350
http_header,
@@ -352,7 +361,7 @@ def asgi_csrf(
352361
)(app)
353362

354363

355-
def cookies_from_scope(scope):
364+
def cookies_from_scope(scope: "Scope") -> Dict[str, str]:
356365
cookie = dict(scope.get("headers") or {}).get(b"cookie")
357366
if not cookie:
358367
return {}
@@ -364,5 +373,5 @@ def cookies_from_scope(scope):
364373
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
365374

366375

367-
def make_secret(length):
376+
def make_secret(length: int) -> str:
368377
return "".join(secrets.choice(allowed_chars) for i in range(length))

test_asgi_csrf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):
186186

187187
@pytest.mark.asyncio
188188
@pytest.mark.parametrize("custom_errors", (False, True))
189-
async def test_prevents_post_if_cookie_not_sent_in_post(
189+
async def test_prevents_post_if_cookie_different_than_data(
190190
custom_errors, app_csrf, app_csrf_custom_errors, csrftoken
191191
):
192192
async with httpx.AsyncClient(

0 commit comments

Comments
 (0)