22from starlette .applications import Starlette
33from starlette .responses import JSONResponse
44from starlette .routing import Route
5- from asgi_csrf import asgi_csrf
5+ from asgi_csrf import asgi_csrf , Errors
66from itsdangerous .url_safe import URLSafeSerializer
77import httpx
88import json
@@ -45,6 +45,39 @@ def app_csrf():
4545 return asgi_csrf (hello_world_app , signing_secret = SECRET )
4646
4747
48+ async def custom_csrf_failed (scope , send , message_id ):
49+ assert scope ["type" ] == "http"
50+ await send (
51+ {
52+ "type" : "http.response.start" ,
53+ "status" : 403 ,
54+ "headers" : [[b"content-type" , b"text/html; charset=utf-8" ]],
55+ }
56+ )
57+ await send (
58+ {
59+ "type" : "http.response.body" ,
60+ "body" : {
61+ Errors .FORM_URLENCODED_MISMATCH : "custom form-urlencoded error" ,
62+ Errors .MULTIPART_MISMATCH : "custom multipart error" ,
63+ Errors .FILE_BEFORE_TOKEN : "custom file before token error" ,
64+ Errors .UNKNOWN_CONTENT_TYPE : "custom unknown content type error" ,
65+ }
66+ .get (message_id , "" )
67+ .encode ("utf-8" ),
68+ }
69+ )
70+
71+
72+ @pytest .fixture
73+ def app_csrf_custom_errors ():
74+ return asgi_csrf (
75+ hello_world_app ,
76+ signing_secret = SECRET ,
77+ send_csrf_failed = custom_csrf_failed ,
78+ )
79+
80+
4881@pytest .fixture
4982def csrftoken ():
5083 return URLSafeSerializer (SECRET ).dumps ("token" , "csrftoken" )
@@ -60,11 +93,11 @@ async def test_hello_world_app():
6093def test_signing_secret_if_none_provided (monkeypatch ):
6194 app = asgi_csrf (hello_world_app )
6295 # Should be randomly generated
63- assert isinstance (app .__closure__ [6 ].cell_contents .secret_key , bytes )
96+ assert isinstance (app .__closure__ [7 ].cell_contents .secret_key , bytes )
6497 # Should pick up `ASGI_CSRF_SECRET` if available
6598 monkeypatch .setenv ("ASGI_CSRF_SECRET" , "secret-from-environment" )
6699 app2 = asgi_csrf (hello_world_app )
67- assert app2 .__closure__ [6 ].cell_contents .secret_key == b"secret-from-environment"
100+ assert app2 .__closure__ [7 ].cell_contents .secret_key == b"secret-from-environment"
68101
69102
70103@pytest .mark .asyncio
@@ -142,15 +175,24 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):
142175
143176
144177@pytest .mark .asyncio
145- async def test_prevents_post_if_cookie_not_sent_in_post (app_csrf , csrftoken ):
146- async with httpx .AsyncClient (app = app_csrf ) as client :
178+ @pytest .mark .parametrize ("custom_errors" , (False , True ))
179+ async def test_prevents_post_if_cookie_not_sent_in_post (
180+ custom_errors , app_csrf , app_csrf_custom_errors , csrftoken
181+ ):
182+ async with httpx .AsyncClient (
183+ app = app_csrf_custom_errors if custom_errors else app_csrf
184+ ) as client :
147185 response = await client .post (
148186 "http://localhost/" ,
149187 cookies = {"csrftoken" : csrftoken },
150188 data = {"csrftoken" : csrftoken [- 1 ]},
151189 )
152190 assert 403 == response .status_code
153- assert response .text == "form-urlencoded POST field did not match cookie"
191+ assert (
192+ response .text == "custom form-urlencoded error"
193+ if custom_errors
194+ else "form-urlencoded POST field did not match cookie"
195+ )
154196
155197
156198@pytest .mark .asyncio
@@ -194,9 +236,14 @@ async def test_multipart(csrftoken):
194236
195237
196238@pytest .mark .asyncio
197- async def test_multipart_failure_wrong_token (csrftoken ):
239+ @pytest .mark .parametrize ("custom_errors" , (False , True ))
240+ async def test_multipart_failure_wrong_token (csrftoken , custom_errors ):
198241 async with httpx .AsyncClient (
199- app = asgi_csrf (hello_world_app , signing_secret = SECRET )
242+ app = asgi_csrf (
243+ hello_world_app ,
244+ signing_secret = SECRET ,
245+ send_csrf_failed = custom_csrf_failed if custom_errors else None ,
246+ )
200247 ) as client :
201248 response = await client .post (
202249 "http://localhost/" ,
@@ -205,7 +252,11 @@ async def test_multipart_failure_wrong_token(csrftoken):
205252 cookies = {"csrftoken" : csrftoken [:- 1 ]},
206253 )
207254 assert response .status_code == 403
208- assert response .text == "multipart/form-data POST field did not match cookie"
255+ assert (
256+ response .text == "custom multipart error"
257+ if custom_errors
258+ else "multipart/form-data POST field did not match cookie"
259+ )
209260
210261
211262class TrickEmptyDictionary (dict ):
@@ -215,9 +266,14 @@ def __bool__(self):
215266
216267
217268@pytest .mark .asyncio
218- async def test_multipart_failure_missing_token (csrftoken ):
269+ @pytest .mark .parametrize ("custom_errors" , (False , True ))
270+ async def test_multipart_failure_missing_token (csrftoken , custom_errors ):
219271 async with httpx .AsyncClient (
220- app = asgi_csrf (hello_world_app , signing_secret = SECRET )
272+ app = asgi_csrf (
273+ hello_world_app ,
274+ signing_secret = SECRET ,
275+ send_csrf_failed = custom_csrf_failed if custom_errors else None ,
276+ )
221277 ) as client :
222278 response = await client .post (
223279 "http://localhost/" ,
@@ -226,18 +282,27 @@ async def test_multipart_failure_missing_token(csrftoken):
226282 cookies = {"csrftoken" : csrftoken },
227283 )
228284 assert response .status_code == 403
229- assert response .text == "multipart/form-data POST field did not match cookie"
285+ assert response .text == (
286+ "custom multipart error"
287+ if custom_errors
288+ else "multipart/form-data POST field did not match cookie"
289+ )
230290
231291
232292@pytest .mark .asyncio
233- async def test_multipart_failure_file_comes_before_token (csrftoken ):
293+ @pytest .mark .parametrize ("custom_errors" , (False , True ))
294+ async def test_multipart_failure_file_comes_before_token (csrftoken , custom_errors ):
234295 async with httpx .AsyncClient (
235- app = asgi_csrf (hello_world_app , signing_secret = SECRET )
296+ app = asgi_csrf (
297+ hello_world_app ,
298+ signing_secret = SECRET ,
299+ send_csrf_failed = custom_csrf_failed if custom_errors else None ,
300+ )
236301 ) as client :
237302 request = httpx .Request (
238303 url = "http://localhost/" ,
239304 method = "POST" ,
240- data = (
305+ content = (
241306 b"--boo\r \n "
242307 b'Content-Disposition: form-data; name="csv"; filename="data.csv"'
243308 b"\r \n Content-Type: text/csv\r \n \r \n "
@@ -255,8 +320,9 @@ async def test_multipart_failure_file_comes_before_token(csrftoken):
255320 response = await client .send (request )
256321 assert response .status_code == 403
257322 assert (
258- response .text
259- == "File encountered before csrftoken - make sure csrftoken is first in the HTML"
323+ response .text == "custom file before token error"
324+ if custom_errors
325+ else "File encountered before csrftoken - make sure csrftoken is first in the HTML"
260326 )
261327
262328
0 commit comments