11from http .cookies import SimpleCookie
2- from enum import Enum
3- import fnmatch
2+ from collections . abc import Awaitable , Callable , Container , Mapping , MutableMapping
3+ from enum import IntEnum
44from functools import wraps
55from multipart import FormParser
66import os
7+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
78from urllib .parse import parse_qsl
89from itsdangerous .url_safe import URLSafeSerializer
910from itsdangerous import BadSignature
1011import secrets
1112
13+ if TYPE_CHECKING :
14+ Scope = MutableMapping [str , Any ]
15+ Message = MutableMapping [str , Any ]
16+ Receive = Callable [[], Awaitable [Message ]]
17+ Send = Callable [[Message ], Awaitable [None ]]
18+ ASGIApp = Callable [[Scope , Receive , Send ], Awaitable [None ]]
19+
1220DEFAULT_COOKIE_NAME = "csrftoken"
1321DEFAULT_COOKIE_PATH = "/"
1422DEFAULT_COOKIE_DOMAIN = None
2129ENV_SECRET = "ASGI_CSRF_SECRET"
2230
2331
24- class Errors (Enum ):
32+
33+ class Errors (IntEnum ):
2534 FORM_URLENCODED_MISMATCH = 1
2635 MULTIPART_MISMATCH = 2
2736 FILE_BEFORE_TOKEN = 3
@@ -37,30 +46,30 @@ class Errors(Enum):
3746
3847
3948def asgi_csrf_decorator (
40- cookie_name = DEFAULT_COOKIE_NAME ,
41- http_header = DEFAULT_HTTP_HEADER ,
42- form_input = DEFAULT_FORM_INPUT ,
43- signing_secret = None ,
44- signing_namespace = DEFAULT_SIGNING_NAMESPACE ,
45- always_protect = None ,
46- always_set_cookie = False ,
47- skip_if_scope = None ,
48- cookie_path = DEFAULT_COOKIE_PATH ,
49- cookie_domain = DEFAULT_COOKIE_DOMAIN ,
50- cookie_secure = DEFAULT_COOKIE_SECURE ,
51- cookie_samesite = DEFAULT_COOKIE_SAMESITE ,
52- send_csrf_failed = None ,
53- ):
49+ cookie_name : str = DEFAULT_COOKIE_NAME ,
50+ http_header : str = DEFAULT_HTTP_HEADER ,
51+ form_input : str = DEFAULT_FORM_INPUT ,
52+ signing_secret : Optional [ str ] = None ,
53+ signing_namespace : str = DEFAULT_SIGNING_NAMESPACE ,
54+ always_protect : Optional [ Container [ str ]] = None ,
55+ always_set_cookie : bool = False ,
56+ skip_if_scope : Optional [ Callable [[ "Scope" ], bool ]] = None ,
57+ cookie_path : str = DEFAULT_COOKIE_PATH ,
58+ cookie_domain : Optional [ str ] = DEFAULT_COOKIE_DOMAIN ,
59+ cookie_secure : bool = DEFAULT_COOKIE_SECURE ,
60+ cookie_samesite : str = DEFAULT_COOKIE_SAMESITE ,
61+ send_csrf_failed : Optional [ Callable [[ "Scope" , "Send" , int ], Awaitable [ None ]]] = None ,
62+ ) -> Callable [[ "ASGIApp" ], "ASGIApp" ] :
5463 send_csrf_failed = send_csrf_failed or default_send_csrf_failed
5564 if signing_secret is None :
5665 signing_secret = os .environ .get (ENV_SECRET , None )
5766 if signing_secret is None :
5867 signing_secret = make_secret (128 )
5968 signer = URLSafeSerializer (signing_secret )
6069
61- def _asgi_csrf_decorator (app ) :
70+ def _asgi_csrf_decorator (app : "ASGIApp" ) -> "ASGIApp" :
6271 @wraps (app )
63- async def app_wrapped_with_csrf (scope , receive , send ) :
72+ async def app_wrapped_with_csrf (scope : "Scope" , receive : "Receive" , send : "Send" ) -> None :
6473 if scope ["type" ] != "http" :
6574 await app (scope , receive , send )
6675 return
@@ -84,7 +93,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
8493 if not has_csrftoken_cookie :
8594 csrftoken = signer .dumps (make_secret (16 ), signing_namespace )
8695
87- def get_csrftoken ():
96+ def get_csrftoken () -> Optional [ str ] :
8897 nonlocal should_set_cookie
8998 nonlocal page_needs_vary_header
9099 page_needs_vary_header = True
@@ -94,7 +103,7 @@ def get_csrftoken():
94103
95104 scope = {** scope , ** {SCOPE_KEY : get_csrftoken }}
96105
97- async def wrapped_send (event ) :
106+ async def wrapped_send (event : "Message" ) -> None :
98107 if event ["type" ] == "http.response.start" :
99108 original_headers = event .get ("headers" ) or []
100109 new_headers = []
@@ -144,7 +153,7 @@ async def wrapped_send(event):
144153 await app (scope , receive , wrapped_send )
145154 else :
146155 # Check for CSRF token in various places
147- headers = dict (scope .get ("headers" or []) )
156+ headers = dict (scope .get ("headers" ) or [])
148157 if secrets .compare_digest (
149158 headers .get (http_header .encode ("latin-1" ), b"" ).decode ("latin-1" ),
150159 csrftoken ,
@@ -226,7 +235,7 @@ async def wrapped_send(event):
226235 return _asgi_csrf_decorator
227236
228237
229- async def _parse_form_urlencoded (receive ) :
238+ async def _parse_form_urlencoded (receive : "Receive" ) -> Tuple [ Dict [ str , str ], "Receive" ] :
230239 # Returns {key: value}, replay_receive
231240 # where replay_receive is an awaitable that can replay what was received
232241 # We ignore cases like foo=one&foo=two because we do not need to
@@ -241,7 +250,7 @@ async def _parse_form_urlencoded(receive):
241250 body += message .get ("body" , b"" )
242251 more_body = message .get ("more_body" , False )
243252
244- async def replay_receive ():
253+ async def replay_receive () -> "Message" :
245254 if messages :
246255 return messages .pop (0 )
247256 else :
@@ -262,7 +271,7 @@ class FileBeforeToken(Exception):
262271 pass
263272
264273
265- async def _parse_multipart_form_data (boundary , receive ) :
274+ async def _parse_multipart_form_data (boundary : bytes , receive : "Receive" ) -> Tuple [ Optional [ str ], "Receive" ] :
266275 # Returns (csrftoken, replay_receive) - or raises an exception
267276 csrftoken = None
268277
@@ -272,17 +281,17 @@ def on_field(field):
272281 raise TokenFound (csrftoken )
273282
274283 class ErrorOnWrite :
275- def __init__ (self , file_name , field_name , config ) :
284+ def __init__ (self , file_name : bytes | None , field_name : bytes | None , config : Mapping [ str , Any ]) -> None :
276285 pass
277286
278- def write (self , data ) :
287+ def write (self , data : bytes ) -> int :
279288 raise FileBeforeToken
280289
281290 body = b""
282291 more_body = True
283- messages = []
292+ messages : List [ "Message" ] = []
284293
285- async def replay_receive ():
294+ async def replay_receive () -> "Message" :
286295 if messages :
287296 return messages .pop (0 )
288297 else :
@@ -308,7 +317,7 @@ async def replay_receive():
308317 return None , replay_receive
309318
310319
311- async def default_send_csrf_failed (scope , send , message_id ) :
320+ async def default_send_csrf_failed (scope : "Scope" , send : "Send" , message_id : int ) -> None :
312321 assert scope ["type" ] == "http"
313322 await send (
314323 {
@@ -323,19 +332,19 @@ async def default_send_csrf_failed(scope, send, message_id):
323332
324333def asgi_csrf (
325334 app ,
326- cookie_name = DEFAULT_COOKIE_NAME ,
327- http_header = DEFAULT_HTTP_HEADER ,
328- signing_secret = None ,
329- signing_namespace = DEFAULT_SIGNING_NAMESPACE ,
330- always_protect = None ,
331- always_set_cookie = False ,
332- skip_if_scope = None ,
333- cookie_path = DEFAULT_COOKIE_PATH ,
334- cookie_domain = DEFAULT_COOKIE_DOMAIN ,
335- cookie_secure = DEFAULT_COOKIE_SECURE ,
336- cookie_samesite = DEFAULT_COOKIE_SAMESITE ,
337- send_csrf_failed = None ,
338- ):
335+ cookie_name : str = DEFAULT_COOKIE_NAME ,
336+ http_header : str = DEFAULT_HTTP_HEADER ,
337+ signing_secret : Optional [ str ] = None ,
338+ signing_namespace : str = DEFAULT_SIGNING_NAMESPACE ,
339+ always_protect : Optional [ Container [ str ]] = None ,
340+ always_set_cookie : bool = False ,
341+ skip_if_scope : Optional [ Callable [[ "Scope" ], bool ]] = None ,
342+ cookie_path : str = DEFAULT_COOKIE_PATH ,
343+ cookie_domain : Optional [ str ] = DEFAULT_COOKIE_DOMAIN ,
344+ cookie_secure : bool = DEFAULT_COOKIE_SECURE ,
345+ cookie_samesite : str = DEFAULT_COOKIE_SAMESITE ,
346+ send_csrf_failed : Optional [ Callable [[ "Scope" , "Send" , int ], Awaitable [ None ]]] = None ,
347+ ) -> "ASGIApp" :
339348 return asgi_csrf_decorator (
340349 cookie_name ,
341350 http_header ,
@@ -352,7 +361,7 @@ def asgi_csrf(
352361 )(app )
353362
354363
355- def cookies_from_scope (scope ) :
364+ def cookies_from_scope (scope : "Scope" ) -> Dict [ str , str ] :
356365 cookie = dict (scope .get ("headers" ) or {}).get (b"cookie" )
357366 if not cookie :
358367 return {}
@@ -364,5 +373,5 @@ def cookies_from_scope(scope):
364373allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
365374
366375
367- def make_secret (length ) :
376+ def make_secret (length : int ) -> str :
368377 return "" .join (secrets .choice (allowed_chars ) for i in range (length ))
0 commit comments