Skip to content

Commit 6273ddc

Browse files
committed
misc: Add module type hints
This commit adds type hints to the module. `mypy` does not report any errors besides missing type hints from other imported libraries.
1 parent e23af32 commit 6273ddc

File tree

3 files changed

+71
-52
lines changed

3 files changed

+71
-52
lines changed

asgi_csrf.py

Lines changed: 70 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
from http.cookies import SimpleCookie
2-
from enum import Enum
3-
import fnmatch
2+
from collections.abc import Awaitable, Callable, Container, Mapping, MutableMapping
3+
from enum import IntEnum
44
from functools import wraps
55
from python_multipart import FormParser
66
import os
7+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
78
from urllib.parse import parse_qsl
89
from itsdangerous.url_safe import URLSafeSerializer
910
from itsdangerous import BadSignature
1011
import secrets
1112

13+
if TYPE_CHECKING:
14+
from python_multipart.multipart import OnFileCallback
15+
16+
Scope = MutableMapping[str, Any]
17+
Message = MutableMapping[str, Any]
18+
Receive = Callable[[], Awaitable[Message]]
19+
Send = Callable[[Message], Awaitable[None]]
20+
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
21+
1222
DEFAULT_COOKIE_NAME = "csrftoken"
1323
DEFAULT_COOKIE_PATH = "/"
1424
DEFAULT_COOKIE_DOMAIN = None
@@ -21,14 +31,15 @@
2131
ENV_SECRET = "ASGI_CSRF_SECRET"
2232

2333

24-
class Errors(Enum):
34+
35+
class Errors(IntEnum):
2536
FORM_URLENCODED_MISMATCH = 1
2637
MULTIPART_MISMATCH = 2
2738
FILE_BEFORE_TOKEN = 3
2839
UNKNOWN_CONTENT_TYPE = 4
2940

3041

31-
error_messages = {
42+
error_messages: dict[int, str] = {
3243
Errors.FORM_URLENCODED_MISMATCH: "form-urlencoded POST field did not match cookie",
3344
Errors.MULTIPART_MISMATCH: "multipart/form-data POST field did not match cookie",
3445
Errors.FILE_BEFORE_TOKEN: "File encountered before csrftoken - make sure csrftoken is first in the HTML",
@@ -37,30 +48,30 @@ class Errors(Enum):
3748

3849

3950
def asgi_csrf_decorator(
40-
cookie_name=DEFAULT_COOKIE_NAME,
41-
http_header=DEFAULT_HTTP_HEADER,
42-
form_input=DEFAULT_FORM_INPUT,
43-
signing_secret=None,
44-
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
45-
always_protect=None,
46-
always_set_cookie=False,
47-
skip_if_scope=None,
48-
cookie_path=DEFAULT_COOKIE_PATH,
49-
cookie_domain=DEFAULT_COOKIE_DOMAIN,
50-
cookie_secure=DEFAULT_COOKIE_SECURE,
51-
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
52-
send_csrf_failed=None,
53-
):
51+
cookie_name: str = DEFAULT_COOKIE_NAME,
52+
http_header: str = DEFAULT_HTTP_HEADER,
53+
form_input: str = DEFAULT_FORM_INPUT,
54+
signing_secret: Optional[str] = None,
55+
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
56+
always_protect: Optional[Container[str]] = None,
57+
always_set_cookie: bool = False,
58+
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
59+
cookie_path: str = DEFAULT_COOKIE_PATH,
60+
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
61+
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
62+
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
63+
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
64+
) -> Callable[["ASGIApp"], "ASGIApp"]:
5465
send_csrf_failed = send_csrf_failed or default_send_csrf_failed
5566
if signing_secret is None:
5667
signing_secret = os.environ.get(ENV_SECRET, None)
5768
if signing_secret is None:
5869
signing_secret = make_secret(128)
5970
signer = URLSafeSerializer(signing_secret)
6071

61-
def _asgi_csrf_decorator(app):
72+
def _asgi_csrf_decorator(app: "ASGIApp") -> "ASGIApp":
6273
@wraps(app)
63-
async def app_wrapped_with_csrf(scope, receive, send):
74+
async def app_wrapped_with_csrf(scope: "Scope", receive: "Receive", send: "Send") -> None:
6475
if scope["type"] != "http":
6576
await app(scope, receive, send)
6677
return
@@ -84,7 +95,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
8495
if not has_csrftoken_cookie:
8596
csrftoken = signer.dumps(make_secret(16), signing_namespace)
8697

87-
def get_csrftoken():
98+
def get_csrftoken() -> Optional[str]:
8899
nonlocal should_set_cookie
89100
nonlocal page_needs_vary_header
90101
page_needs_vary_header = True
@@ -94,7 +105,7 @@ def get_csrftoken():
94105

95106
scope = {**scope, **{SCOPE_KEY: get_csrftoken}}
96107

97-
async def wrapped_send(event):
108+
async def wrapped_send(event: "Message") -> None:
98109
if event["type"] == "http.response.start":
99110
original_headers = event.get("headers") or []
100111
new_headers = []
@@ -144,7 +155,7 @@ async def wrapped_send(event):
144155
await app(scope, receive, wrapped_send)
145156
else:
146157
# Check for CSRF token in various places
147-
headers = dict(scope.get("headers" or []))
158+
headers = dict(scope.get("headers") or [])
148159
if secrets.compare_digest(
149160
headers.get(http_header.encode("latin-1"), b"").decode("latin-1"),
150161
csrftoken,
@@ -174,7 +185,7 @@ async def wrapped_send(event):
174185
if content_type == b"application/x-www-form-urlencoded":
175186
# Consume entire POST body and check for csrftoken field
176187
post_data, replay_receive = await _parse_form_urlencoded(receive)
177-
if secrets.compare_digest(post_data.get(form_input, ""), csrftoken):
188+
if secrets.compare_digest(post_data.get(form_input, ""), csrftoken or ""):
178189
# All is good! Forward on the request and replay the body
179190
await app(scope, replay_receive, wrapped_send)
180191
return
@@ -186,7 +197,7 @@ async def wrapped_send(event):
186197
elif content_type == b"multipart/form-data":
187198
# Consume non-file items until we see a csrftoken
188199
# If we see a file item first, it's an error
189-
boundary = headers.get(b"content-type").split(b"; boundary=")[1]
200+
boundary = headers.get(b"content-type", "").split(b"; boundary=")[1]
190201
assert boundary is not None, "missing 'boundary' header: {}".format(
191202
repr(headers)
192203
)
@@ -197,7 +208,7 @@ async def wrapped_send(event):
197208
replay_receive,
198209
) = await _parse_multipart_form_data(boundary, receive)
199210
if not secrets.compare_digest(
200-
csrftoken_from_body or "", csrftoken
211+
csrftoken_from_body or "", csrftoken or ""
201212
):
202213
await send_csrf_failed(
203214
scope,
@@ -226,7 +237,7 @@ async def wrapped_send(event):
226237
return _asgi_csrf_decorator
227238

228239

229-
async def _parse_form_urlencoded(receive):
240+
async def _parse_form_urlencoded(receive: "Receive") -> Tuple[Dict[str, str], "Receive"]:
230241
# Returns {key: value}, replay_receive
231242
# where replay_receive is an awaitable that can replay what was received
232243
# We ignore cases like foo=one&foo=two because we do not need to
@@ -241,7 +252,7 @@ async def _parse_form_urlencoded(receive):
241252
body += message.get("body", b"")
242253
more_body = message.get("more_body", False)
243254

244-
async def replay_receive():
255+
async def replay_receive() -> "Message":
245256
if messages:
246257
return messages.pop(0)
247258
else:
@@ -262,7 +273,7 @@ class FileBeforeToken(Exception):
262273
pass
263274

264275

265-
async def _parse_multipart_form_data(boundary, receive):
276+
async def _parse_multipart_form_data(boundary: bytes, receive: "Receive") -> Tuple[Optional[str], "Receive"]:
266277
# Returns (csrftoken, replay_receive) - or raises an exception
267278
csrftoken = None
268279

@@ -272,26 +283,34 @@ def on_field(field):
272283
raise TokenFound(csrftoken)
273284

274285
class ErrorOnWrite:
275-
def __init__(self, file_name, field_name, config):
286+
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Mapping[str, Any]) -> None:
276287
pass
277288

278-
def write(self, data):
289+
def write(self, data: bytes) -> int:
279290
raise FileBeforeToken
280291

292+
def finalize(self) -> None: ...
293+
294+
def close(self) -> None: ...
295+
296+
def set_none(self) -> None: ...
297+
281298
body = b""
282299
more_body = True
283-
messages = []
300+
messages: List["Message"] = []
284301

285-
async def replay_receive():
302+
async def replay_receive() -> "Message":
286303
if messages:
287304
return messages.pop(0)
288305
else:
289306
return await receive()
290307

308+
on_file: OnFileCallback = lambda _: None
309+
291310
form_parser = FormParser(
292311
"multipart/form-data",
293-
on_field,
294-
lambda: None,
312+
on_field=on_field,
313+
on_file=on_file,
295314
boundary=boundary,
296315
FileClass=ErrorOnWrite,
297316
)
@@ -308,7 +327,7 @@ async def replay_receive():
308327
return None, replay_receive
309328

310329

311-
async def default_send_csrf_failed(scope, send, message_id):
330+
async def default_send_csrf_failed(scope: "Scope", send: "Send", message_id: int) -> None:
312331
assert scope["type"] == "http"
313332
await send(
314333
{
@@ -323,19 +342,19 @@ async def default_send_csrf_failed(scope, send, message_id):
323342

324343
def asgi_csrf(
325344
app,
326-
cookie_name=DEFAULT_COOKIE_NAME,
327-
http_header=DEFAULT_HTTP_HEADER,
328-
signing_secret=None,
329-
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
330-
always_protect=None,
331-
always_set_cookie=False,
332-
skip_if_scope=None,
333-
cookie_path=DEFAULT_COOKIE_PATH,
334-
cookie_domain=DEFAULT_COOKIE_DOMAIN,
335-
cookie_secure=DEFAULT_COOKIE_SECURE,
336-
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
337-
send_csrf_failed=None,
338-
):
345+
cookie_name: str = DEFAULT_COOKIE_NAME,
346+
http_header: str = DEFAULT_HTTP_HEADER,
347+
signing_secret: Optional[str] = None,
348+
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
349+
always_protect: Optional[Container[str]] = None,
350+
always_set_cookie: bool = False,
351+
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
352+
cookie_path: str = DEFAULT_COOKIE_PATH,
353+
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
354+
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
355+
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
356+
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
357+
) -> "ASGIApp":
339358
return asgi_csrf_decorator(
340359
cookie_name,
341360
http_header,
@@ -352,7 +371,7 @@ def asgi_csrf(
352371
)(app)
353372

354373

355-
def cookies_from_scope(scope):
374+
def cookies_from_scope(scope: "Scope") -> Dict[str, str]:
356375
cookie = dict(scope.get("headers") or {}).get(b"cookie")
357376
if not cookie:
358377
return {}
@@ -364,5 +383,5 @@ def cookies_from_scope(scope):
364383
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
365384

366385

367-
def make_secret(length):
386+
def make_secret(length: int) -> str:
368387
return "".join(secrets.choice(allowed_chars) for i in range(length))

py.typed

Whitespace-only changes.

tests/test_asgi_csrf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):
178178

179179
@pytest.mark.asyncio
180180
@pytest.mark.parametrize("custom_errors", (False, True))
181-
async def test_prevents_post_if_cookie_not_sent_in_post(
181+
async def test_prevents_post_if_cookie_different_than_data(
182182
custom_errors, app_csrf, app_csrf_custom_errors, csrftoken
183183
):
184184
async with httpx.AsyncClient(

0 commit comments

Comments
 (0)