Skip to content

Commit 70f964d

Browse files
committed
Refactor auth cache with type hints, RateLimiterService, token cleanup, Lua preload, and expanded test coverage.
1 parent 05c0e55 commit 70f964d

File tree

8 files changed

+1207
-172
lines changed

8 files changed

+1207
-172
lines changed

examples/auth/cache.py

Lines changed: 364 additions & 92 deletions
Large diffs are not rendered by default.

examples/auth/lifespan.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi import FastAPI
88
from fastapi_limiter import FastAPILimiter
99

10+
from examples.auth.cache import _DEFAULT_SERVICE
1011
from examples.auth.database import engine
1112
from examples.auth.jwt_scheduler import start_jwt_scheduler
1213
from examples.auth.models import Base
@@ -26,10 +27,10 @@ async def global_lifespan(app: FastAPI) -> AsyncGenerator[None]:
2627
None: Control is yielded back to the application
2728
after performing startup tasks.
2829
"""
29-
# Step 1: Start the scheduler (e.g., for rotating JWT secret keys).
30+
# Start the scheduler (e.g., for rotating JWT secret keys).
3031
scheduler = start_jwt_scheduler(app)
3132

32-
# Step 2: Initialise Redis connection and rate limiter.
33+
# Initialise Redis connection and rate limiter.
3334
redis_host: str = os.getenv('REDIS_HOST', '127.0.0.1')
3435
redis_port: str = os.getenv('REDIS_PORT', '6379')
3536
redis_password: str = os.getenv('REDIS_PASSWORD', '')
@@ -39,7 +40,14 @@ async def global_lifespan(app: FastAPI) -> AsyncGenerator[None]:
3940
redis_conn = await app.state.redis_client.connect()
4041
await FastAPILimiter.init(redis_conn)
4142

42-
# Step 3: Optionally create database tables on startup
43+
# Preload Lua scripts into Redis (if any).
44+
try:
45+
await _DEFAULT_SERVICE.preload_script(redis_conn)
46+
except Exception:
47+
# Log the error or handle it as needed
48+
pass
49+
50+
# Optionally create database tables on startup
4351
async with engine.begin() as conn:
4452
await conn.run_sync(Base.metadata.create_all)
4553

examples/auth/token_cleanup.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
import jwt
6+
from redis.asyncio import Redis
7+
8+
from examples.auth.cache import get_user_data
9+
from examples.auth.cache import set_user_data
10+
from examples.auth.config import Settings
11+
12+
13+
settings = Settings()
14+
15+
16+
async def prune_user_cache(
17+
redis_pool: Redis,
18+
username: str,
19+
) -> dict[str, object] | None:
20+
"""
21+
Prune a user's cached authentication data in Redis.
22+
23+
Args:
24+
redis_pool: Asynchronous Redis client/connection.
25+
username: Username whose cache entry should be pruned.
26+
27+
Returns:
28+
The updated cache dictionary if present, otherwise ``None`` when no
29+
cache entry exists.
30+
"""
31+
cache: dict[str, object] | None = await get_user_data(redis_pool, username)
32+
if not cache:
33+
return None
34+
35+
changed: bool = False
36+
now: int = int(time.time())
37+
38+
# Refresh tokens pruning
39+
# The cache may contain a list of refresh tokens; keep those that are
40+
# valid according to ``jwt.decode`` and drop expired/invalid ones.
41+
refresh_raw = cache.get('refresh_tokens', [])
42+
refresh_tokens: list[str]
43+
if isinstance(refresh_raw, list):
44+
# Ensure a typed list of strings only
45+
refresh_tokens = [t for t in refresh_raw if isinstance(t, str)]
46+
else:
47+
refresh_tokens = []
48+
new_refresh_tokens: list[str] = []
49+
for tok in refresh_tokens:
50+
try:
51+
# Decode to verify validity and expiry
52+
jwt.decode(
53+
tok,
54+
settings.authjwt_secret_key,
55+
algorithms=[settings.ALGORITHM],
56+
)
57+
# Keep token only if decode succeeds
58+
new_refresh_tokens.append(tok)
59+
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
60+
changed = True
61+
# Drop expired/invalid tokens silently
62+
continue
63+
if new_refresh_tokens != refresh_tokens:
64+
cache['refresh_tokens'] = new_refresh_tokens
65+
66+
# JTI metadata pruning
67+
# jti_list holds active JWT IDs; jti_meta maps JTI -> expiry timestamp.
68+
jti_list_raw = cache.get('jti_list', [])
69+
jti_list: list[str]
70+
if isinstance(jti_list_raw, list):
71+
jti_list = [j for j in jti_list_raw if isinstance(j, str)]
72+
else:
73+
jti_list = []
74+
75+
jti_meta_raw = cache.get('jti_meta', {})
76+
jti_meta: dict[str, int]
77+
if isinstance(jti_meta_raw, dict):
78+
# Build a strictly typed mapping of str -> int
79+
jti_meta = {
80+
k: int(v)
81+
for k, v in jti_meta_raw.items()
82+
if isinstance(k, str) and isinstance(v, int)
83+
}
84+
else:
85+
jti_meta = {}
86+
if jti_meta:
87+
new_jti_list: list[str] = []
88+
for j in jti_list:
89+
exp_ts: int = int(jti_meta.get(j, 0))
90+
# Keep if no expiry is tracked (0) or it is still in the future
91+
if exp_ts == 0 or exp_ts > now:
92+
new_jti_list.append(j)
93+
else:
94+
changed = True
95+
96+
# Remove stale jti_meta entries not in list or already expired
97+
new_jti_meta: dict[str, int] = {}
98+
for j, exp in jti_meta.items():
99+
if j in new_jti_list and exp > now:
100+
new_jti_meta[j] = int(exp)
101+
else:
102+
changed = True
103+
104+
if new_jti_list != jti_list:
105+
cache['jti_list'] = new_jti_list
106+
cache['jti_meta'] = new_jti_meta
107+
108+
# Persist only if an actual change occurred
109+
if changed:
110+
await set_user_data(redis_pool, username, cache)
111+
112+
return cache

examples/auth/user_service.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import time
4+
from typing import TypeAlias
45

56
from fastapi import HTTPException
67
from sqlalchemy import select
@@ -9,23 +10,40 @@
910

1011
from examples.auth.models import User
1112

12-
# Cache mechanism for storing user site information
13-
_user_sites_cache: dict[str, tuple[list[str], float]] = {}
14-
_cache_ttl: int = 300 # Cache time-to-live in seconds (5 minutes)
13+
# A cache entry stores: (list of site names, cached-at epoch seconds).
14+
CacheEntry: TypeAlias = tuple[list[str], float]
15+
16+
# Process-local cache for storing user site information.
17+
_user_sites_cache: dict[str, CacheEntry] = {}
18+
19+
# Cache time-to-live in seconds (5 minutes).
20+
_cache_ttl: int = 300
1521

1622

1723
async def get_user_sites_cached(username: str, db: AsyncSession) -> list[str]:
18-
"""Return site names the user can access, with simple in-memory caching.
24+
"""
25+
Return site names the user may access, with simple in-memory caching.
26+
27+
Args:
28+
username: The unique username to resolve.
29+
db: An asynchronous SQLAlchemy session used for the lookup.
30+
31+
Returns:
32+
A list of site names that the user may access. The list order follows
33+
the ORM relationship ordering as returned by the database.
1934
20-
Raises HTTPException(404) if the user is not found.
35+
Raises:
36+
HTTPException: With status code 404 if the user is not found.
2137
"""
2238
current_time: float = time.time()
2339

2440
if username in _user_sites_cache:
41+
# Fast path: honour TTL and return cached site names when still fresh.
2542
cached_names, cached_time = _user_sites_cache[username]
2643
if current_time - cached_time < _cache_ttl:
2744
return cached_names
2845

46+
# Query the user and their sites in one round-trip.
2947
stmt_user = (
3048
select(User)
3149
.where(User.username == username)
@@ -35,6 +53,40 @@ async def get_user_sites_cached(username: str, db: AsyncSession) -> list[str]:
3553
if not user_obj:
3654
raise HTTPException(status_code=404, detail='User not found')
3755

56+
# Extract and cache the site names with the current timestamp.
3857
site_names: list[str] = [site.name for site in user_obj.sites]
3958
_user_sites_cache[username] = (site_names, current_time)
4059
return site_names
60+
61+
62+
async def get_user_and_sites(
63+
db: AsyncSession, username: str,
64+
) -> tuple[User, list[str], str]:
65+
"""
66+
Fetch the user, their site names, and role from the database.
67+
68+
Args:
69+
db: An asynchronous SQLAlchemy session.
70+
username: The username to query.
71+
72+
Returns:
73+
A 3-tuple of ``(user, site_names, role)`` where:
74+
- ``user`` is the fully loaded ``User`` ORM instance,
75+
- ``site_names`` is a list of the user's site names, and
76+
- ``role`` is the user's role as a string.
77+
78+
Raises:
79+
HTTPException: With status code 401 if the user cannot be found.
80+
"""
81+
stmt_user = (
82+
select(User)
83+
.where(User.username == username)
84+
.options(selectinload(User.sites))
85+
)
86+
result = await db.execute(stmt_user)
87+
user: User | None = result.scalars().first()
88+
if not user:
89+
raise HTTPException(status_code=401, detail='Invalid user')
90+
user_role: str = user.role
91+
user_site_names: list[str] = [site.name for site in user.sites]
92+
return user, user_site_names, user_role

0 commit comments

Comments
 (0)