Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=120

# List of origins that can access this API, separated by a comma, eg:
# CORS_ORIGINS=http://localhost,https://www.gnramsay.com
# If you want all origins to access (the default), use * or comment out:
# For public APIs using Bearer tokens, * is acceptable but will log a warning.
# Use explicit origins if serving browser clients.
CORS_ORIGINS=*

# Email Settings - OPTIONAL for development, REQUIRED for production
Expand Down
8 changes: 8 additions & 0 deletions SECURITY-REVIEW.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

### 1. Refresh Token Authentication Bypass

> [!NOTE]
> ✅ **Done**: Access tokens now include `typ="access"` and `get_jwt_user`
> enforces it; refresh or typ-less tokens are rejected.

**Location**: `app/managers/auth.py:478-544` (`get_jwt_user`)

- **Issue**: Refresh tokens can be used as access tokens because access tokens
Expand All @@ -19,6 +23,10 @@

### 2. CORS Wildcard with Credentials - SEVERE

> [!NOTE]
> ✅ **Done**: CORS credentials are disabled for the API and startup now warns
> when `CORS_ORIGINS=*` is used.

**Location**: `app/main.py:169`, `app/config/settings.py:56`

- **Issue**: Default CORS configuration allows ALL origins (`cors_origins="*"`)
Expand Down
21 changes: 17 additions & 4 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@

BLIND_USER_ERROR = 66

# set up CORS
cors_list = [
origin.strip()
for origin in get_settings().cors_origins.split(",")
if origin.strip()
]

# gatekeeper to ensure the user has read the docs and noted the major changes
# since the last version.
if not get_settings().i_read_the_damn_docs:
Expand Down Expand Up @@ -62,6 +69,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[Any, None]:
# Initialize loguru logging within the server process.
get_log_config()

if "*" in cors_list:
warning_msg = (
"CORS_ORIGINS is set to '*', allowing any origin to access the "
"API. This is fine for public APIs with bearer tokens, but you "
"should set explicit origins if serving browser clients."
)
logger.warning(warning_msg) # Console via uvicorn
loguru_logger.warning(warning_msg) # File via loguru

redis_client = None

# Test database connection
Expand Down Expand Up @@ -160,13 +176,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[Any, None]:
name="static",
)

# set up CORS
cors_list = (get_settings().cors_origins).split(",")

app.add_middleware(
CORSMiddleware,
allow_origins=cors_list,
allow_credentials=True,
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
Expand Down
11 changes: 11 additions & 0 deletions app/managers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def encode_token(user: User) -> str:
try:
payload = {
"sub": user.id,
"typ": "access",
"exp": datetime.datetime.now(tz=datetime.timezone.utc)
+ datetime.timedelta(
minutes=get_settings().access_token_expire_minutes
Expand Down Expand Up @@ -492,6 +493,16 @@ async def get_jwt_user(
algorithms=["HS256"],
options={"verify_sub": False},
)
if payload.get("typ") != "access":
increment_auth_failure("invalid_token", "jwt")
category_logger.warning(
"Authentication attempted with non-access token",
LogCategory.AUTH,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ResponseMessages.INVALID_TOKEN,
)
user_data = await get_user_by_id_(payload["sub"], db)

# Check user validity - user must exist, be verified, and not banned
Expand Down
4 changes: 2 additions & 2 deletions docs/deployment/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ Before deploying to production, ensure you've configured all critical settings p
- Used for JWT token signing and session security

- **CORS_ORIGINS**
- **Don't** leave as `*` in production
- Set to your actual frontend domain(s)
- For browser clients, set to your actual frontend domain(s)
- For public APIs using Bearer tokens, `*` is acceptable but will log a warning
- Example: `CORS_ORIGINS=https://app.example.com,https://www.example.com`
- Multiple origins separated by commas

Expand Down
7 changes: 6 additions & 1 deletion docs/important.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ API.

## Breaking Changes in `HEAD`

None.
### CORS credentials disabled by default

Credentialed CORS requests are now disabled for the API. If you relied on
cookie-based auth from browser clients, you must switch to Bearer tokens or
explicitly re-enable credentials and restrict `CORS_ORIGINS` to your frontend
domains.

## Breaking Changes in `0.8.0`

Expand Down
8 changes: 5 additions & 3 deletions docs/usage/configuration/environment.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,11 @@ is an HTTP-header based mechanism that allows a server to indicate any origins
(domain, scheme, or port) other than its own from which a browser should permit
loading resources.

For a **PUBLIC API** (unless its going through an API gateway!), set
`CORS_ORIGINS=*`, otherwise list the domains (**and ports**) required. If you
use an API gateway of some nature, that will probably need to be listed.
For a **PUBLIC API** using Bearer tokens, `CORS_ORIGINS=*` is acceptable.
If you serve browser clients, list the required domains (**and ports**) instead
to restrict access. The app will log a warning when `CORS_ORIGINS=*` is set so
you can confirm the intent. If you use an API gateway, that will probably need
to be listed.

```ini
CORS_ORIGINS=*
Expand Down
60 changes: 60 additions & 0 deletions tests/integration/test_protected_user_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Integration tests for user routes."""

import datetime

import jwt
import pytest
from fastapi import status

from app.config.settings import get_settings
from app.database.helpers import hash_password
from app.managers.auth import AuthManager
from app.models.user import User


@pytest.mark.integration
Expand Down Expand Up @@ -64,3 +70,57 @@ async def test_routes_bad_auth(self, client, route) -> None:

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.json() == {"detail": "That token is Invalid"}

@pytest.mark.asyncio
@pytest.mark.parametrize(
"route",
test_routes,
)
async def test_routes_refresh_token_rejected(
self, client, test_db, route
) -> None:
"""Test that refresh tokens are rejected on protected routes."""
test_user = User(**self.test_user)
test_db.add(test_user)
await test_db.commit()
refresh_token = AuthManager.encode_refresh_token(test_user)

route_name, method = route
fn = getattr(client, method)
response = await fn(
route_name, headers={"Authorization": f"Bearer {refresh_token}"}
)

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.json() == {"detail": "That token is Invalid"}

@pytest.mark.asyncio
@pytest.mark.parametrize(
"route",
test_routes,
)
async def test_routes_missing_typ_rejected(
self, client, test_db, route
) -> None:
"""Test that tokens without typ are rejected on protected routes."""
test_user = User(**self.test_user)
test_db.add(test_user)
await test_db.commit()
token = jwt.encode(
{
"sub": test_user.id,
"exp": datetime.datetime.now(tz=datetime.timezone.utc)
+ datetime.timedelta(minutes=10),
},
get_settings().secret_key,
algorithm="HS256",
)

route_name, method = route
fn = getattr(client, method)
response = await fn(
route_name, headers={"Authorization": f"Bearer {token}"}
)

assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.json() == {"detail": "That token is Invalid"}
2 changes: 2 additions & 0 deletions tests/unit/test_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_encode_token(self) -> None:
options={"verify_sub": False},
)
assert payload["sub"] == 1
assert payload["typ"] == "access"
assert isinstance(payload["exp"], int)
# TODO(seapagan): better comparison to ensure the exp is in the future
# but close to the expected expiry time taking into account the setting
Expand All @@ -68,6 +69,7 @@ def test_encode_refresh_token(self) -> None:
)

assert payload["sub"] == 1
assert payload["typ"] == "refresh"
assert isinstance(payload["exp"], int)
# TODO(seapagan): better comparison to ensure the exp is in the future
# but close to the expected expiry time taking into account the expiry
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_cors_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for CORS middleware configuration."""

from typing import Any, cast

import pytest
from fastapi.middleware.cors import CORSMiddleware

from app.main import app


@pytest.mark.unit
def test_cors_middleware_disables_credentials() -> None:
"""Ensure the API does not allow credentialed CORS by default."""
cors_middleware = next(
middleware
for middleware in app.user_middleware
if cast("Any", middleware.cls) is CORSMiddleware
)

kwargs = cast("dict[str, Any]", cors_middleware.kwargs)
allow_origins = cast("list[str]", kwargs["allow_origins"])

assert kwargs["allow_credentials"] is False
assert "*" in allow_origins
47 changes: 47 additions & 0 deletions tests/unit/test_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import datetime

import jwt
import pytest
from fastapi import BackgroundTasks, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials

from app.config.settings import get_settings
from app.managers.auth import ResponseMessages, get_jwt_user
from app.managers.user import UserManager
from app.models.user import User
Expand Down Expand Up @@ -59,6 +61,51 @@ async def test_jwt_auth_invalid_token(self, test_db, mocker) -> None:
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc.value.detail == ResponseMessages.INVALID_TOKEN

async def test_jwt_auth_refresh_token_rejected(
self, test_db, mocker
) -> None:
"""Test with a refresh token used as access token."""
_, refresh = await UserManager.register(self.test_user, test_db)
mock_req = mocker.patch(self.mock_request_path)
mock_req.headers = {"Authorization": f"Bearer {refresh}"}
mock_credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials=refresh
)

with pytest.raises(HTTPException) as exc:
await get_jwt_user(
request=mock_req, db=test_db, credentials=mock_credentials
)

assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc.value.detail == ResponseMessages.INVALID_TOKEN

async def test_jwt_auth_missing_typ_rejected(self, test_db, mocker) -> None:
"""Test with a token missing typ claim."""
await UserManager.register(self.test_user, test_db)
token = jwt.encode(
{
"sub": 1,
"exp": datetime.datetime.now(tz=datetime.timezone.utc)
+ datetime.timedelta(minutes=10),
},
get_settings().secret_key,
algorithm="HS256",
)
mock_req = mocker.patch(self.mock_request_path)
mock_req.headers = {"Authorization": f"Bearer {token}"}
mock_credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials=token
)

with pytest.raises(HTTPException) as exc:
await get_jwt_user(
request=mock_req, db=test_db, credentials=mock_credentials
)

assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc.value.detail == ResponseMessages.INVALID_TOKEN

async def test_jwt_auth_no_auth_header(self, test_db, mocker) -> None:
"""Test with no authorization header."""
mock_req = mocker.patch(self.mock_request_path)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ async def test_lifespan_warns_on_missing_redis_password(
for record in caplog.records
)

async def test_lifespan_warns_on_cors_wildcard(
self, caplog, mocker
) -> None:
"""Ensure a warning is logged when CORS is set to '*'."""
app = FastAPI()
mock_session = mocker.patch(self.mock_session)
mock_connection = (
mock_session.return_value.__aenter__.return_value.connection
)
mock_connection.return_value = None

mocker.patch("app.main.cors_list", ["*"])
mocker.patch("app.main.get_settings").return_value.cache_enabled = False
loguru_warning = mocker.patch("app.main.loguru_logger.warning")

caplog.set_level(logging.WARNING)

async with lifespan(app):
pass # NOSONAR

assert any(
"CORS_ORIGINS is set to '*'" in record.message
for record in caplog.records
)
loguru_warning.assert_called_once()

async def test_lifespan_initializes_redis_and_closes_client(
self, caplog, mocker
) -> None:
Expand Down