Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions asgi_csrf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from http.cookies import SimpleCookie
from enum import Enum
import fnmatch
from collections.abc import Awaitable, Callable, Container, Mapping, MutableMapping
from enum import IntEnum
from functools import wraps
from python_multipart import FormParser
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import parse_qsl
from itsdangerous.url_safe import URLSafeSerializer
from itsdangerous import BadSignature
import secrets

if TYPE_CHECKING:
from python_multipart.multipart import OnFileCallback

Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]

DEFAULT_COOKIE_NAME = "csrftoken"
DEFAULT_COOKIE_PATH = "/"
DEFAULT_COOKIE_DOMAIN = None
Expand All @@ -21,14 +31,15 @@
ENV_SECRET = "ASGI_CSRF_SECRET"


class Errors(Enum):

class Errors(IntEnum):
FORM_URLENCODED_MISMATCH = 1
MULTIPART_MISMATCH = 2
FILE_BEFORE_TOKEN = 3
UNKNOWN_CONTENT_TYPE = 4


error_messages = {
error_messages: dict[int, str] = {
Errors.FORM_URLENCODED_MISMATCH: "form-urlencoded POST field did not match cookie",
Errors.MULTIPART_MISMATCH: "multipart/form-data POST field did not match cookie",
Errors.FILE_BEFORE_TOKEN: "File encountered before csrftoken - make sure csrftoken is first in the HTML",
Expand All @@ -37,30 +48,30 @@ class Errors(Enum):


def asgi_csrf_decorator(
cookie_name=DEFAULT_COOKIE_NAME,
http_header=DEFAULT_HTTP_HEADER,
form_input=DEFAULT_FORM_INPUT,
signing_secret=None,
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
always_protect=None,
always_set_cookie=False,
skip_if_scope=None,
cookie_path=DEFAULT_COOKIE_PATH,
cookie_domain=DEFAULT_COOKIE_DOMAIN,
cookie_secure=DEFAULT_COOKIE_SECURE,
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
send_csrf_failed=None,
):
cookie_name: str = DEFAULT_COOKIE_NAME,
http_header: str = DEFAULT_HTTP_HEADER,
form_input: str = DEFAULT_FORM_INPUT,
signing_secret: Optional[str] = None,
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
always_protect: Optional[Container[str]] = None,
always_set_cookie: bool = False,
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
cookie_path: str = DEFAULT_COOKIE_PATH,
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
) -> Callable[["ASGIApp"], "ASGIApp"]:
send_csrf_failed = send_csrf_failed or default_send_csrf_failed
if signing_secret is None:
signing_secret = os.environ.get(ENV_SECRET, None)
if signing_secret is None:
signing_secret = make_secret(128)
signer = URLSafeSerializer(signing_secret)

def _asgi_csrf_decorator(app):
def _asgi_csrf_decorator(app: "ASGIApp") -> "ASGIApp":
@wraps(app)
async def app_wrapped_with_csrf(scope, receive, send):
async def app_wrapped_with_csrf(scope: "Scope", receive: "Receive", send: "Send") -> None:
if scope["type"] != "http":
await app(scope, receive, send)
return
Expand All @@ -84,7 +95,7 @@ async def app_wrapped_with_csrf(scope, receive, send):
if not has_csrftoken_cookie:
csrftoken = signer.dumps(make_secret(16), signing_namespace)

def get_csrftoken():
def get_csrftoken() -> Optional[str]:
nonlocal should_set_cookie
nonlocal page_needs_vary_header
page_needs_vary_header = True
Expand All @@ -94,7 +105,7 @@ def get_csrftoken():

scope = {**scope, **{SCOPE_KEY: get_csrftoken}}

async def wrapped_send(event):
async def wrapped_send(event: "Message") -> None:
if event["type"] == "http.response.start":
original_headers = event.get("headers") or []
new_headers = []
Expand Down Expand Up @@ -144,7 +155,7 @@ async def wrapped_send(event):
await app(scope, receive, wrapped_send)
else:
# Check for CSRF token in various places
headers = dict(scope.get("headers" or []))
headers = dict(scope.get("headers") or [])
if secrets.compare_digest(
headers.get(http_header.encode("latin-1"), b"").decode("latin-1"),
csrftoken,
Expand Down Expand Up @@ -174,7 +185,7 @@ async def wrapped_send(event):
if content_type == b"application/x-www-form-urlencoded":
# Consume entire POST body and check for csrftoken field
post_data, replay_receive = await _parse_form_urlencoded(receive)
if secrets.compare_digest(post_data.get(form_input, ""), csrftoken):
if secrets.compare_digest(post_data.get(form_input, ""), csrftoken or ""):
# All is good! Forward on the request and replay the body
await app(scope, replay_receive, wrapped_send)
return
Expand All @@ -186,7 +197,7 @@ async def wrapped_send(event):
elif content_type == b"multipart/form-data":
# Consume non-file items until we see a csrftoken
# If we see a file item first, it's an error
boundary = headers.get(b"content-type").split(b"; boundary=")[1]
boundary = headers.get(b"content-type", "").split(b"; boundary=")[1]
assert boundary is not None, "missing 'boundary' header: {}".format(
repr(headers)
)
Expand All @@ -197,7 +208,7 @@ async def wrapped_send(event):
replay_receive,
) = await _parse_multipart_form_data(boundary, receive)
if not secrets.compare_digest(
csrftoken_from_body or "", csrftoken
csrftoken_from_body or "", csrftoken or ""
):
await send_csrf_failed(
scope,
Expand Down Expand Up @@ -226,7 +237,7 @@ async def wrapped_send(event):
return _asgi_csrf_decorator


async def _parse_form_urlencoded(receive):
async def _parse_form_urlencoded(receive: "Receive") -> Tuple[Dict[str, str], "Receive"]:
# Returns {key: value}, replay_receive
# where replay_receive is an awaitable that can replay what was received
# We ignore cases like foo=one&foo=two because we do not need to
Expand All @@ -241,7 +252,7 @@ async def _parse_form_urlencoded(receive):
body += message.get("body", b"")
more_body = message.get("more_body", False)

async def replay_receive():
async def replay_receive() -> "Message":
if messages:
return messages.pop(0)
else:
Expand All @@ -262,7 +273,7 @@ class FileBeforeToken(Exception):
pass


async def _parse_multipart_form_data(boundary, receive):
async def _parse_multipart_form_data(boundary: bytes, receive: "Receive") -> Tuple[Optional[str], "Receive"]:
# Returns (csrftoken, replay_receive) - or raises an exception
csrftoken = None

Expand All @@ -272,26 +283,34 @@ def on_field(field):
raise TokenFound(csrftoken)

class ErrorOnWrite:
def __init__(self, file_name, field_name, config):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: Mapping[str, Any]) -> None:
pass

def write(self, data):
def write(self, data: bytes) -> int:
raise FileBeforeToken

def finalize(self) -> None: ...

def close(self) -> None: ...

def set_none(self) -> None: ...

body = b""
more_body = True
messages = []
messages: List["Message"] = []

async def replay_receive():
async def replay_receive() -> "Message":
if messages:
return messages.pop(0)
else:
return await receive()

on_file: OnFileCallback = lambda _: None

form_parser = FormParser(
"multipart/form-data",
on_field,
lambda: None,
on_field=on_field,
on_file=on_file,
boundary=boundary,
FileClass=ErrorOnWrite,
)
Expand All @@ -308,7 +327,7 @@ async def replay_receive():
return None, replay_receive


async def default_send_csrf_failed(scope, send, message_id):
async def default_send_csrf_failed(scope: "Scope", send: "Send", message_id: int) -> None:
assert scope["type"] == "http"
await send(
{
Expand All @@ -323,19 +342,19 @@ async def default_send_csrf_failed(scope, send, message_id):

def asgi_csrf(
app,
cookie_name=DEFAULT_COOKIE_NAME,
http_header=DEFAULT_HTTP_HEADER,
signing_secret=None,
signing_namespace=DEFAULT_SIGNING_NAMESPACE,
always_protect=None,
always_set_cookie=False,
skip_if_scope=None,
cookie_path=DEFAULT_COOKIE_PATH,
cookie_domain=DEFAULT_COOKIE_DOMAIN,
cookie_secure=DEFAULT_COOKIE_SECURE,
cookie_samesite=DEFAULT_COOKIE_SAMESITE,
send_csrf_failed=None,
):
cookie_name: str = DEFAULT_COOKIE_NAME,
http_header: str = DEFAULT_HTTP_HEADER,
signing_secret: Optional[str] = None,
signing_namespace: str = DEFAULT_SIGNING_NAMESPACE,
always_protect: Optional[Container[str]] = None,
always_set_cookie: bool = False,
skip_if_scope: Optional[Callable[["Scope"], bool]] = None,
cookie_path: str = DEFAULT_COOKIE_PATH,
cookie_domain: Optional[str] = DEFAULT_COOKIE_DOMAIN,
cookie_secure: bool = DEFAULT_COOKIE_SECURE,
cookie_samesite: str = DEFAULT_COOKIE_SAMESITE,
send_csrf_failed: Optional[Callable[["Scope", "Send", int], Awaitable[None]]] = None,
) -> "ASGIApp":
return asgi_csrf_decorator(
cookie_name,
http_header,
Expand All @@ -352,7 +371,7 @@ def asgi_csrf(
)(app)


def cookies_from_scope(scope):
def cookies_from_scope(scope: "Scope") -> Dict[str, str]:
cookie = dict(scope.get("headers") or {}).get(b"cookie")
if not cookie:
return {}
Expand All @@ -364,5 +383,5 @@ def cookies_from_scope(scope):
allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"


def make_secret(length):
def make_secret(length: int) -> str:
return "".join(secrets.choice(allowed_chars) for i in range(length))
Empty file added py.typed
Empty file.
2 changes: 1 addition & 1 deletion tests/test_asgi_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def test_prevents_post_if_cookie_not_sent_in_post(app_csrf, csrftoken):

@pytest.mark.asyncio
@pytest.mark.parametrize("custom_errors", (False, True))
async def test_prevents_post_if_cookie_not_sent_in_post(
async def test_prevents_post_if_cookie_different_than_data(
custom_errors, app_csrf, app_csrf_custom_errors, csrftoken
):
async with httpx.AsyncClient(
Expand Down