Skip to content

Commit b34e7ed

Browse files
committed
Protect GraphQL view
1 parent 5c4f009 commit b34e7ed

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

app/views.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
from typing import Optional
3+
4+
from starlette.datastructures import Headers
5+
from starlette.requests import Request
6+
from starlette.responses import Response
7+
from starlette.types import Receive, Scope, Send
8+
from strawberry.asgi import GraphQL as BaseGraphQL
9+
10+
11+
def is_token_valid(token: str) -> bool:
12+
return token == os.environ["ALLOWED_TOKEN"]
13+
14+
15+
def get_token_from_authorization_header(value: Optional[str]) -> Optional[str]:
16+
if not value:
17+
return None
18+
19+
prefix, token = value.split(" ")
20+
21+
assert prefix == "Bearer"
22+
23+
return token
24+
25+
26+
class GraphQL(BaseGraphQL):
27+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
28+
request = Request(scope=scope, receive=receive)
29+
30+
token = self._get_token(scope)
31+
32+
# Allow GraphiQL
33+
if request.method != "GET":
34+
if token is None:
35+
response = Response(status_code=401)
36+
37+
return await response(scope, receive, send)
38+
39+
if not is_token_valid(token):
40+
response = Response(status_code=403)
41+
42+
return await response(scope, receive, send)
43+
44+
await super().__call__(scope=scope, receive=receive, send=send)
45+
46+
def _get_token(self, scope: Scope) -> Optional[str]:
47+
headers = Headers(scope=scope)
48+
authorization_header = headers.get("Authorization")
49+
50+
return get_token_from_authorization_header(
51+
authorization_header,
52+
)

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import uvicorn
22
from starlette.applications import Starlette
3-
from strawberry.asgi import GraphQL
43

54
from app.schema import schema
5+
from app.views import GraphQL
66

77
app = Starlette(debug=False)
88
app.add_route("/graphql", GraphQL(schema))

0 commit comments

Comments
 (0)