Skip to content

Commit cd99f4e

Browse files
committed
Implement basic token auth
1 parent 2647e0a commit cd99f4e

File tree

12 files changed

+1359
-16
lines changed

12 files changed

+1359
-16
lines changed

agent_memory_server/auth.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import secrets
12
import threading
23
import time
34
from datetime import UTC, datetime
45
from typing import Any
56

7+
import bcrypt
68
import httpx
79
import structlog
810
from fastapi import Depends, HTTPException, status
@@ -11,6 +13,8 @@
1113
from pydantic import BaseModel
1214

1315
from agent_memory_server.config import settings
16+
from agent_memory_server.utils.keys import Keys
17+
from agent_memory_server.utils.redis import get_redis_conn
1418

1519

1620
logger = structlog.get_logger()
@@ -27,6 +31,15 @@ class UserInfo(BaseModel):
2731
roles: list[str] | None = None
2832

2933

34+
class TokenInfo(BaseModel):
35+
"""Token information stored in Redis."""
36+
37+
description: str
38+
created_at: datetime
39+
expires_at: datetime | None = None
40+
token_hash: str
41+
42+
3043
class JWKSCache:
3144
def __init__(self, cache_duration: int = 3600):
3245
self._cache: dict[str, Any] = {}
@@ -245,10 +258,98 @@ def verify_jwt(token: str) -> UserInfo:
245258
) from e
246259

247260

261+
def generate_token() -> str:
262+
"""Generate a secure random token."""
263+
return secrets.token_urlsafe(32)
264+
265+
266+
def hash_token(token: str) -> str:
267+
"""Hash a token using bcrypt."""
268+
return bcrypt.hashpw(token.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
269+
270+
271+
def verify_token_hash(token: str, token_hash: str) -> bool:
272+
"""Verify a token against its hash."""
273+
try:
274+
return bcrypt.checkpw(token.encode("utf-8"), token_hash.encode("utf-8"))
275+
except Exception as e:
276+
logger.warning("Token hash verification failed", error=str(e))
277+
return False
278+
279+
280+
async def verify_token(token: str) -> UserInfo:
281+
"""Verify a token and return user info."""
282+
try:
283+
redis = await get_redis_conn()
284+
285+
# Get all auth tokens and check each one
286+
# This is not the most efficient approach, but it works for now
287+
# In a production system, you might want to store a mapping of token prefixes
288+
pattern = Keys.auth_token_key("*")
289+
token_keys = []
290+
291+
async for key in redis.scan_iter(pattern):
292+
token_keys.append(key)
293+
294+
for key in token_keys:
295+
token_data = await redis.get(key)
296+
if not token_data:
297+
continue
298+
299+
try:
300+
token_info = TokenInfo.model_validate_json(token_data)
301+
302+
# Check if token matches
303+
if verify_token_hash(token, token_info.token_hash):
304+
# Check if token is expired
305+
if (
306+
token_info.expires_at
307+
and datetime.now(UTC) > token_info.expires_at
308+
):
309+
logger.warning("Token has expired")
310+
raise HTTPException(
311+
status_code=status.HTTP_401_UNAUTHORIZED,
312+
detail="Token has expired",
313+
)
314+
315+
# Return user info for valid token
316+
return UserInfo(
317+
sub="token-user",
318+
aud="token-auth",
319+
scope="admin",
320+
roles=["admin"],
321+
exp=int(token_info.expires_at.timestamp())
322+
if token_info.expires_at
323+
else None,
324+
iat=int(token_info.created_at.timestamp()),
325+
)
326+
327+
except HTTPException:
328+
# Re-raise HTTP exceptions (like token expired)
329+
raise
330+
except Exception as e:
331+
logger.warning("Error processing token", error=str(e))
332+
continue
333+
334+
# If no token matched, authentication failed
335+
raise HTTPException(
336+
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
337+
)
338+
339+
except HTTPException:
340+
raise
341+
except Exception as e:
342+
logger.error("Unexpected error during token verification", error=str(e))
343+
raise HTTPException(
344+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
345+
detail="Internal server error during authentication",
346+
) from e
347+
348+
248349
def get_current_user(
249350
credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme),
250351
) -> UserInfo:
251-
if settings.disable_auth:
352+
if settings.disable_auth or settings.auth_mode == "disabled":
252353
logger.debug("Authentication disabled, returning default user")
253354
return UserInfo(
254355
sub="local-dev-user", aud="local-dev", scope="admin", roles=["admin"]
@@ -268,6 +369,14 @@ def get_current_user(
268369
headers={"WWW-Authenticate": "Bearer"},
269370
)
270371

372+
# Determine authentication mode
373+
if settings.auth_mode == "token" or settings.token_auth_enabled:
374+
import asyncio
375+
376+
return asyncio.run(verify_token(credentials.credentials))
377+
if settings.auth_mode == "oauth2":
378+
return verify_jwt(credentials.credentials)
379+
# Default to OAuth2 for backward compatibility
271380
return verify_jwt(credentials.credentials)
272381

273382

@@ -304,18 +413,42 @@ def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
304413

305414

306415
def verify_auth_config():
307-
if settings.disable_auth:
416+
if settings.disable_auth or settings.auth_mode == "disabled":
308417
logger.warning("Authentication is DISABLED - suitable for development only")
309418
return
310419

420+
if settings.auth_mode == "token" or settings.token_auth_enabled:
421+
logger.info("Token authentication configured")
422+
return
423+
424+
if settings.auth_mode == "oauth2":
425+
if not settings.oauth2_issuer_url:
426+
raise ValueError(
427+
"OAUTH2_ISSUER_URL must be set when OAuth2 authentication is enabled"
428+
)
429+
430+
if not settings.oauth2_audience:
431+
logger.warning(
432+
"OAUTH2_AUDIENCE not set - audience validation will be skipped"
433+
)
434+
435+
logger.info(
436+
"OAuth2 authentication configured",
437+
issuer=settings.oauth2_issuer_url,
438+
audience=settings.oauth2_audience or "not-set",
439+
algorithms=settings.oauth2_algorithms,
440+
)
441+
return
442+
443+
# Default to OAuth2 for backward compatibility
311444
if not settings.oauth2_issuer_url:
312445
raise ValueError("OAUTH2_ISSUER_URL must be set when authentication is enabled")
313446

314447
if not settings.oauth2_audience:
315448
logger.warning("OAUTH2_AUDIENCE not set - audience validation will be skipped")
316449

317450
logger.info(
318-
"OAuth2 authentication configured",
451+
"OAuth2 authentication configured (default)",
319452
issuer=settings.oauth2_issuer_url,
320453
audience=settings.oauth2_audience or "not-set",
321454
algorithms=settings.oauth2_algorithms,

0 commit comments

Comments
 (0)