Skip to content

Commit 4db11ae

Browse files
authored
Custom errors (#30)
* send_csrf_failed for customizing errors, closes #28
1 parent 0d25087 commit 4db11ae

File tree

3 files changed

+151
-24
lines changed

3 files changed

+151
-24
lines changed

README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,45 @@ app = asgi_csrf(
100100
skip_if_scope=skip_api_paths
101101
)
102102
```
103+
104+
### send_csrf_failed
105+
106+
By default, when a CSRF token is missing or invalid, the middleware will return a 403 Forbidden response page with a short error message.
107+
108+
You can customize this behavior by passing a `send_csrf_failed` function to the middleware. This function should accept the ASGI `scope` and `send` functions, and the `message_id` of the error that occurred.
109+
110+
The `message_id` will be an integer representing an item from the `asgi_csrf.Errors` enum.
111+
112+
This example shows how you could customize the error message based on that `message_id`:
113+
114+
```python
115+
async def custom_csrf_failed(scope, send, message_id):
116+
assert scope["type"] == "http"
117+
await send(
118+
{
119+
"type": "http.response.start",
120+
"status": 403,
121+
"headers": [[b"content-type", b"text/html; charset=utf-8"]],
122+
}
123+
)
124+
await send(
125+
{
126+
"type": "http.response.body",
127+
"body": {
128+
Errors.FORM_URLENCODED_MISMATCH: "custom form-urlencoded error",
129+
Errors.MULTIPART_MISMATCH: "custom multipart error",
130+
Errors.FILE_BEFORE_TOKEN: "custom file before token error",
131+
Errors.UNKNOWN_CONTENT_TYPE: "custom unknown content type error",
132+
}
133+
.get(message_id, "")
134+
.encode("utf-8"),
135+
}
136+
)
137+
138+
139+
app = asgi_csrf(
140+
app,
141+
signing_secret="secret-goes-here",
142+
send_csrf_failed=custom_csrf_failed
143+
)
144+
```

asgi_csrf.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from http.cookies import SimpleCookie
2+
from enum import Enum
23
import fnmatch
34
from functools import wraps
45
from multipart import FormParser
@@ -16,6 +17,21 @@
1617
ENV_SECRET = "ASGI_CSRF_SECRET"
1718

1819

20+
class Errors(Enum):
21+
FORM_URLENCODED_MISMATCH = 1
22+
MULTIPART_MISMATCH = 2
23+
FILE_BEFORE_TOKEN = 3
24+
UNKNOWN_CONTENT_TYPE = 4
25+
26+
27+
error_messages = {
28+
Errors.FORM_URLENCODED_MISMATCH: "form-urlencoded POST field did not match cookie",
29+
Errors.MULTIPART_MISMATCH: "multipart/form-data POST field did not match cookie",
30+
Errors.FILE_BEFORE_TOKEN: "File encountered before csrftoken - make sure csrftoken is first in the HTML",
31+
Errors.UNKNOWN_CONTENT_TYPE: "Unknown content-type",
32+
}
33+
34+
1935
def asgi_csrf_decorator(
2036
cookie_name=DEFAULT_COOKIE_NAME,
2137
http_header=DEFAULT_HTTP_HEADER,
@@ -25,7 +41,9 @@ def asgi_csrf_decorator(
2541
always_protect=None,
2642
always_set_cookie=False,
2743
skip_if_scope=None,
44+
send_csrf_failed=None,
2845
):
46+
send_csrf_failed = send_csrf_failed or default_send_csrf_failed
2947
if signing_secret is None:
3048
signing_secret = os.environ.get(ENV_SECRET, None)
3149
if signing_secret is None:
@@ -144,9 +162,7 @@ async def wrapped_send(event):
144162
return
145163
else:
146164
await send_csrf_failed(
147-
scope,
148-
wrapped_send,
149-
"form-urlencoded POST field did not match cookie",
165+
scope, wrapped_send, Errors.FORM_URLENCODED_MISMATCH
150166
)
151167
return
152168
elif content_type == b"multipart/form-data":
@@ -168,22 +184,22 @@ async def wrapped_send(event):
168184
await send_csrf_failed(
169185
scope,
170186
wrapped_send,
171-
"multipart/form-data POST field did not match cookie",
187+
Errors.MULTIPART_MISMATCH,
172188
)
173189
return
174190
except FileBeforeToken:
175191
await send_csrf_failed(
176192
scope,
177193
wrapped_send,
178-
"File encountered before csrftoken - make sure csrftoken is first in the HTML",
194+
Errors.FILE_BEFORE_TOKEN,
179195
)
180196
return
181197
# Now replay the body
182198
await app(scope, replay_receive, wrapped_send)
183199
return
184200
else:
185201
await send_csrf_failed(
186-
scope, wrapped_send, message="Unknown content-type"
202+
scope, wrapped_send, Errors.UNKNOWN_CONTENT_TYPE
187203
)
188204
return
189205

@@ -271,7 +287,7 @@ async def replay_receive():
271287
return None, replay_receive
272288

273289

274-
async def send_csrf_failed(scope, send, message="CSRF check failed"):
290+
async def default_send_csrf_failed(scope, send, message_id):
275291
assert scope["type"] == "http"
276292
await send(
277293
{
@@ -280,6 +296,7 @@ async def send_csrf_failed(scope, send, message="CSRF check failed"):
280296
"headers": [[b"content-type", b"text/html; charset=utf-8"]],
281297
}
282298
)
299+
message = error_messages.get(message_id) or "CSRF validation failed"
283300
await send({"type": "http.response.body", "body": message.encode("utf-8")})
284301

285302

@@ -292,6 +309,7 @@ def asgi_csrf(
292309
always_protect=None,
293310
always_set_cookie=False,
294311
skip_if_scope=None,
312+
send_csrf_failed=None,
295313
):
296314
return asgi_csrf_decorator(
297315
cookie_name,
@@ -301,6 +319,7 @@ def asgi_csrf(
301319
always_protect=always_protect,
302320
always_set_cookie=always_set_cookie,
303321
skip_if_scope=skip_if_scope,
322+
send_csrf_failed=send_csrf_failed,
304323
)(app)
305324

306325

test_asgi_csrf.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from starlette.applications import Starlette
33
from starlette.responses import JSONResponse
44
from starlette.routing import Route
5-
from asgi_csrf import asgi_csrf
5+
from asgi_csrf import asgi_csrf, Errors
66
from itsdangerous.url_safe import URLSafeSerializer
77
import httpx
88
import json
@@ -45,6 +45,39 @@ def app_csrf():
4545
return asgi_csrf(hello_world_app, signing_secret=SECRET)
4646

4747

48+
async def custom_csrf_failed(scope, send, message_id):
49+
assert scope["type"] == "http"
50+
await send(
51+
{
52+
"type": "http.response.start",
53+
"status": 403,
54+
"headers": [[b"content-type", b"text/html; charset=utf-8"]],
55+
}
56+
)
57+
await send(
58+
{
59+
"type": "http.response.body",
60+
"body": {
61+
Errors.FORM_URLENCODED_MISMATCH: "custom form-urlencoded error",
62+
Errors.MULTIPART_MISMATCH: "custom multipart error",
63+
Errors.FILE_BEFORE_TOKEN: "custom file before token error",
64+
Errors.UNKNOWN_CONTENT_TYPE: "custom unknown content type error",
65+
}
66+
.get(message_id, "")
67+
.encode("utf-8"),
68+
}
69+
)
70+
71+
72+
@pytest.fixture
73+
def app_csrf_custom_errors():
74+
return asgi_csrf(
75+
hello_world_app,
76+
signing_secret=SECRET,
77+
send_csrf_failed=custom_csrf_failed,
78+
)
79+
80+
4881
@pytest.fixture
4982
def csrftoken():
5083
return URLSafeSerializer(SECRET).dumps("token", "csrftoken")
@@ -60,11 +93,11 @@ async def test_hello_world_app():
6093
def test_signing_secret_if_none_provided(monkeypatch):
6194
app = asgi_csrf(hello_world_app)
6295
# Should be randomly generated
63-
assert isinstance(app.__closure__[6].cell_contents.secret_key, bytes)
96+
assert isinstance(app.__closure__[7].cell_contents.secret_key, bytes)
6497
# Should pick up `ASGI_CSRF_SECRET` if available
6598
monkeypatch.setenv("ASGI_CSRF_SECRET", "secret-from-environment")
6699
app2 = asgi_csrf(hello_world_app)
67-
assert app2.__closure__[6].cell_contents.secret_key == b"secret-from-environment"
100+
assert app2.__closure__[7].cell_contents.secret_key == b"secret-from-environment"
68101

69102

70103
@pytest.mark.asyncio
@@ -142,15 +175,24 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):
142175

143176

144177
@pytest.mark.asyncio
145-
async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):
146-
async with httpx.AsyncClient(app=app_csrf) as client:
178+
@pytest.mark.parametrize("custom_errors", (False, True))
179+
async def test_prevents_post_if_cookie_not_sent_in_post(
180+
custom_errors, app_csrf, app_csrf_custom_errors, csrftoken
181+
):
182+
async with httpx.AsyncClient(
183+
app=app_csrf_custom_errors if custom_errors else app_csrf
184+
) as client:
147185
response = await client.post(
148186
"http://localhost/",
149187
cookies={"csrftoken": csrftoken},
150188
data={"csrftoken": csrftoken[-1]},
151189
)
152190
assert 403 == response.status_code
153-
assert response.text == "form-urlencoded POST field did not match cookie"
191+
assert (
192+
response.text == "custom form-urlencoded error"
193+
if custom_errors
194+
else "form-urlencoded POST field did not match cookie"
195+
)
154196

155197

156198
@pytest.mark.asyncio
@@ -194,9 +236,14 @@ async def test_multipart(csrftoken):
194236

195237

196238
@pytest.mark.asyncio
197-
async def test_multipart_failure_wrong_token(csrftoken):
239+
@pytest.mark.parametrize("custom_errors", (False, True))
240+
async def test_multipart_failure_wrong_token(csrftoken, custom_errors):
198241
async with httpx.AsyncClient(
199-
app=asgi_csrf(hello_world_app, signing_secret=SECRET)
242+
app=asgi_csrf(
243+
hello_world_app,
244+
signing_secret=SECRET,
245+
send_csrf_failed=custom_csrf_failed if custom_errors else None,
246+
)
200247
) as client:
201248
response = await client.post(
202249
"http://localhost/",
@@ -205,7 +252,11 @@ async def test_multipart_failure_wrong_token(csrftoken):
205252
cookies={"csrftoken": csrftoken[:-1]},
206253
)
207254
assert response.status_code == 403
208-
assert response.text == "multipart/form-data POST field did not match cookie"
255+
assert (
256+
response.text == "custom multipart error"
257+
if custom_errors
258+
else "multipart/form-data POST field did not match cookie"
259+
)
209260

210261

211262
class TrickEmptyDictionary(dict):
@@ -215,9 +266,14 @@ def __bool__(self):
215266

216267

217268
@pytest.mark.asyncio
218-
async def test_multipart_failure_missing_token(csrftoken):
269+
@pytest.mark.parametrize("custom_errors", (False, True))
270+
async def test_multipart_failure_missing_token(csrftoken, custom_errors):
219271
async with httpx.AsyncClient(
220-
app=asgi_csrf(hello_world_app, signing_secret=SECRET)
272+
app=asgi_csrf(
273+
hello_world_app,
274+
signing_secret=SECRET,
275+
send_csrf_failed=custom_csrf_failed if custom_errors else None,
276+
)
221277
) as client:
222278
response = await client.post(
223279
"http://localhost/",
@@ -226,18 +282,27 @@ async def test_multipart_failure_missing_token(csrftoken):
226282
cookies={"csrftoken": csrftoken},
227283
)
228284
assert response.status_code == 403
229-
assert response.text == "multipart/form-data POST field did not match cookie"
285+
assert response.text == (
286+
"custom multipart error"
287+
if custom_errors
288+
else "multipart/form-data POST field did not match cookie"
289+
)
230290

231291

232292
@pytest.mark.asyncio
233-
async def test_multipart_failure_file_comes_before_token(csrftoken):
293+
@pytest.mark.parametrize("custom_errors", (False, True))
294+
async def test_multipart_failure_file_comes_before_token(csrftoken, custom_errors):
234295
async with httpx.AsyncClient(
235-
app=asgi_csrf(hello_world_app, signing_secret=SECRET)
296+
app=asgi_csrf(
297+
hello_world_app,
298+
signing_secret=SECRET,
299+
send_csrf_failed=custom_csrf_failed if custom_errors else None,
300+
)
236301
) as client:
237302
request = httpx.Request(
238303
url="http://localhost/",
239304
method="POST",
240-
data=(
305+
content=(
241306
b"--boo\r\n"
242307
b'Content-Disposition: form-data; name="csv"; filename="data.csv"'
243308
b"\r\nContent-Type: text/csv\r\n\r\n"
@@ -255,8 +320,9 @@ async def test_multipart_failure_file_comes_before_token(csrftoken):
255320
response = await client.send(request)
256321
assert response.status_code == 403
257322
assert (
258-
response.text
259-
== "File encountered before csrftoken - make sure csrftoken is first in the HTML"
323+
response.text == "custom file before token error"
324+
if custom_errors
325+
else "File encountered before csrftoken - make sure csrftoken is first in the HTML"
260326
)
261327

262328

0 commit comments

Comments
 (0)