11from http .cookies import SimpleCookie
2- from enum import Enum
3- import fnmatch
2+ from collections . abc import Awaitable , Callable , Container , Mapping , MutableMapping
3+ from enum import IntEnum
44from functools import wraps
55from python_multipart import FormParser
66import os
7+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
78from urllib .parse import parse_qsl
89from itsdangerous .url_safe import URLSafeSerializer
910from itsdangerous import BadSignature
1011import secrets
1112
13+ if TYPE_CHECKING :
14+ from python_multipart .multipart import OnFileCallback
15+
16+ Scope = MutableMapping [str , Any ]
17+ Message = MutableMapping [str , Any ]
18+ Receive = Callable [[], Awaitable [Message ]]
19+ Send = Callable [[Message ], Awaitable [None ]]
20+ ASGIApp = Callable [[Scope , Receive , Send ], Awaitable [None ]]
21+
1222DEFAULT_COOKIE_NAME = "csrftoken"
1323DEFAULT_COOKIE_PATH = "/"
1424DEFAULT_COOKIE_DOMAIN = None
2131ENV_SECRET = "ASGI_CSRF_SECRET"
2232
2333
24- class Errors (Enum ):
34+
35+ class Errors (IntEnum ):
2536 FORM_URLENCODED_MISMATCH = 1
2637 MULTIPART_MISMATCH = 2
2738 FILE_BEFORE_TOKEN = 3
2839 UNKNOWN_CONTENT_TYPE = 4
2940
3041
31- error_messages = {
42+ error_messages : dict [ int , str ] = {
3243 Errors .FORM_URLENCODED_MISMATCH : "form-urlencoded POST field did not match cookie" ,
3344 Errors .MULTIPART_MISMATCH : "multipart/form-data POST field did not match cookie" ,
3445 Errors .FILE_BEFORE_TOKEN : "File encountered before csrftoken - make sure csrftoken is first in the HTML" ,
@@ -37,30 +48,30 @@ class Errors(Enum):
3748
3849
3950def asgi_csrf_decorator (
40- cookie_name = DEFAULT_COOKIE_NAME ,
41- http_header = DEFAULT_HTTP_HEADER ,
42- form_input = DEFAULT_FORM_INPUT ,
43- signing_secret = None ,
44- signing_namespace = DEFAULT_SIGNING_NAMESPACE ,
45- always_protect = None ,
46- always_set_cookie = False ,
47- skip_if_scope = None ,
48- cookie_path = DEFAULT_COOKIE_PATH ,
49- cookie_domain = DEFAULT_COOKIE_DOMAIN ,
50- cookie_secure = DEFAULT_COOKIE_SECURE ,
51- cookie_samesite = DEFAULT_COOKIE_SAMESITE ,
52- send_csrf_failed = None ,
53- ):
51+ cookie_name : str = DEFAULT_COOKIE_NAME ,
52+ http_header : str = DEFAULT_HTTP_HEADER ,
53+ form_input : str = DEFAULT_FORM_INPUT ,
54+ signing_secret : Optional [ str ] = None ,
55+ signing_namespace : str = DEFAULT_SIGNING_NAMESPACE ,
56+ always_protect : Optional [ Container [ str ]] = None ,
57+ always_set_cookie : bool = False ,
58+ skip_if_scope : Optional [ Callable [[ "Scope" ], bool ]] = None ,
59+ cookie_path : str = DEFAULT_COOKIE_PATH ,
60+ cookie_domain : Optional [ str ] = DEFAULT_COOKIE_DOMAIN ,
61+ cookie_secure : bool = DEFAULT_COOKIE_SECURE ,
62+ cookie_samesite : str = DEFAULT_COOKIE_SAMESITE ,
63+ send_csrf_failed : Optional [ Callable [[ "Scope" , "Send" , int ], Awaitable [ None ]]] = None ,
64+ ) -> Callable [[ "ASGIApp" ], "ASGIApp" ] :
5465 send_csrf_failed = send_csrf_failed or default_send_csrf_failed
5566 if signing_secret is None :
5667 signing_secret = os .environ .get (ENV_SECRET , None )
5768 if signing_secret is None :
5869 signing_secret = make_secret (128 )
5970 signer = URLSafeSerializer (signing_secret )
6071
61- def _asgi_csrf_decorator (app ) :
72+ def _asgi_csrf_decorator (app : "ASGIApp" ) -> "ASGIApp" :
6273 @wraps (app )
63- async def app_wrapped_with_csrf (scope , receive , send ) :
74+ async def app_wrapped_with_csrf (scope : "Scope" , receive : "Receive" , send : "Send" ) -> None :
6475 if scope ["type" ] != "http" :
6576 await app (scope , receive , send )
6677 return
@@ -84,7 +95,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
8495 if not has_csrftoken_cookie :
8596 csrftoken = signer .dumps (make_secret (16 ), signing_namespace )
8697
87- def get_csrftoken ():
98+ def get_csrftoken () -> Optional [ str ] :
8899 nonlocal should_set_cookie
89100 nonlocal page_needs_vary_header
90101 page_needs_vary_header = True
@@ -94,7 +105,7 @@ def get_csrftoken():
94105
95106 scope = {** scope , ** {SCOPE_KEY : get_csrftoken }}
96107
97- async def wrapped_send (event ) :
108+ async def wrapped_send (event : "Message" ) -> None :
98109 if event ["type" ] == "http.response.start" :
99110 original_headers = event .get ("headers" ) or []
100111 new_headers = []
@@ -144,7 +155,7 @@ async def wrapped_send(event):
144155 await app (scope , receive , wrapped_send )
145156 else :
146157 # Check for CSRF token in various places
147- headers = dict (scope .get ("headers" or []) )
158+ headers = dict (scope .get ("headers" ) or [])
148159 if secrets .compare_digest (
149160 headers .get (http_header .encode ("latin-1" ), b"" ).decode ("latin-1" ),
150161 csrftoken ,
@@ -174,7 +185,7 @@ async def wrapped_send(event):
174185 if content_type == b"application/x-www-form-urlencoded" :
175186 # Consume entire POST body and check for csrftoken field
176187 post_data , replay_receive = await _parse_form_urlencoded (receive )
177- if secrets .compare_digest (post_data .get (form_input , "" ), csrftoken ):
188+ if secrets .compare_digest (post_data .get (form_input , "" ), csrftoken or "" ):
178189 # All is good! Forward on the request and replay the body
179190 await app (scope , replay_receive , wrapped_send )
180191 return
@@ -186,7 +197,7 @@ async def wrapped_send(event):
186197 elif content_type == b"multipart/form-data" :
187198 # Consume non-file items until we see a csrftoken
188199 # If we see a file item first, it's an error
189- boundary = headers .get (b"content-type" ).split (b"; boundary=" )[1 ]
200+ boundary = headers .get (b"content-type" , "" ).split (b"; boundary=" )[1 ]
190201 assert boundary is not None , "missing 'boundary' header: {}" .format (
191202 repr (headers )
192203 )
@@ -197,7 +208,7 @@ async def wrapped_send(event):
197208 replay_receive ,
198209 ) = await _parse_multipart_form_data (boundary , receive )
199210 if not secrets .compare_digest (
200- csrftoken_from_body or "" , csrftoken
211+ csrftoken_from_body or "" , csrftoken or ""
201212 ):
202213 await send_csrf_failed (
203214 scope ,
@@ -226,7 +237,7 @@ async def wrapped_send(event):
226237 return _asgi_csrf_decorator
227238
228239
229- async def _parse_form_urlencoded (receive ) :
240+ async def _parse_form_urlencoded (receive : "Receive" ) -> Tuple [ Dict [ str , str ], "Receive" ] :
230241 # Returns {key: value}, replay_receive
231242 # where replay_receive is an awaitable that can replay what was received
232243 # We ignore cases like foo=one&foo=two because we do not need to
@@ -241,7 +252,7 @@ async def _parse_form_urlencoded(receive):
241252 body += message .get ("body" , b"" )
242253 more_body = message .get ("more_body" , False )
243254
244- async def replay_receive ():
255+ async def replay_receive () -> "Message" :
245256 if messages :
246257 return messages .pop (0 )
247258 else :
@@ -262,7 +273,7 @@ class FileBeforeToken(Exception):
262273 pass
263274
264275
265- async def _parse_multipart_form_data (boundary , receive ) :
276+ async def _parse_multipart_form_data (boundary : bytes , receive : "Receive" ) -> Tuple [ Optional [ str ], "Receive" ] :
266277 # Returns (csrftoken, replay_receive) - or raises an exception
267278 csrftoken = None
268279
@@ -272,26 +283,34 @@ def on_field(field):
272283 raise TokenFound (csrftoken )
273284
274285 class ErrorOnWrite :
275- def __init__ (self , file_name , field_name , config ) :
286+ def __init__ (self , file_name : bytes | None , field_name : bytes | None , config : Mapping [ str , Any ]) -> None :
276287 pass
277288
278- def write (self , data ) :
289+ def write (self , data : bytes ) -> int :
279290 raise FileBeforeToken
280291
292+ def finalize (self ) -> None : ...
293+
294+ def close (self ) -> None : ...
295+
296+ def set_none (self ) -> None : ...
297+
281298 body = b""
282299 more_body = True
283- messages = []
300+ messages : List [ "Message" ] = []
284301
285- async def replay_receive ():
302+ async def replay_receive () -> "Message" :
286303 if messages :
287304 return messages .pop (0 )
288305 else :
289306 return await receive ()
290307
308+ on_file : OnFileCallback = lambda _ : None
309+
291310 form_parser = FormParser (
292311 "multipart/form-data" ,
293- on_field ,
294- lambda : None ,
312+ on_field = on_field ,
313+ on_file = on_file ,
295314 boundary = boundary ,
296315 FileClass = ErrorOnWrite ,
297316 )
@@ -308,7 +327,7 @@ async def replay_receive():
308327 return None , replay_receive
309328
310329
311- async def default_send_csrf_failed (scope , send , message_id ) :
330+ async def default_send_csrf_failed (scope : "Scope" , send : "Send" , message_id : int ) -> None :
312331 assert scope ["type" ] == "http"
313332 await send (
314333 {
@@ -323,19 +342,19 @@ async def default_send_csrf_failed(scope, send, message_id):
323342
324343def asgi_csrf (
325344 app ,
326- cookie_name = DEFAULT_COOKIE_NAME ,
327- http_header = DEFAULT_HTTP_HEADER ,
328- signing_secret = None ,
329- signing_namespace = DEFAULT_SIGNING_NAMESPACE ,
330- always_protect = None ,
331- always_set_cookie = False ,
332- skip_if_scope = None ,
333- cookie_path = DEFAULT_COOKIE_PATH ,
334- cookie_domain = DEFAULT_COOKIE_DOMAIN ,
335- cookie_secure = DEFAULT_COOKIE_SECURE ,
336- cookie_samesite = DEFAULT_COOKIE_SAMESITE ,
337- send_csrf_failed = None ,
338- ):
345+ cookie_name : str = DEFAULT_COOKIE_NAME ,
346+ http_header : str = DEFAULT_HTTP_HEADER ,
347+ signing_secret : Optional [ str ] = None ,
348+ signing_namespace : str = DEFAULT_SIGNING_NAMESPACE ,
349+ always_protect : Optional [ Container [ str ]] = None ,
350+ always_set_cookie : bool = False ,
351+ skip_if_scope : Optional [ Callable [[ "Scope" ], bool ]] = None ,
352+ cookie_path : str = DEFAULT_COOKIE_PATH ,
353+ cookie_domain : Optional [ str ] = DEFAULT_COOKIE_DOMAIN ,
354+ cookie_secure : bool = DEFAULT_COOKIE_SECURE ,
355+ cookie_samesite : str = DEFAULT_COOKIE_SAMESITE ,
356+ send_csrf_failed : Optional [ Callable [[ "Scope" , "Send" , int ], Awaitable [ None ]]] = None ,
357+ ) -> "ASGIApp" :
339358 return asgi_csrf_decorator (
340359 cookie_name ,
341360 http_header ,
@@ -352,7 +371,7 @@ def asgi_csrf(
352371 )(app )
353372
354373
355- def cookies_from_scope (scope ) :
374+ def cookies_from_scope (scope : "Scope" ) -> Dict [ str , str ] :
356375 cookie = dict (scope .get ("headers" ) or {}).get (b"cookie" )
357376 if not cookie :
358377 return {}
@@ -364,5 +383,5 @@ def cookies_from_scope(scope):
364383allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
365384
366385
367- def make_secret (length ) :
386+ def make_secret (length : int ) -> str :
368387 return "" .join (secrets .choice (allowed_chars ) for i in range (length ))
0 commit comments