1
+ import secrets
1
2
import threading
2
3
import time
3
4
from datetime import UTC , datetime
4
5
from typing import Any
5
6
7
+ import bcrypt
6
8
import httpx
7
9
import structlog
8
10
from fastapi import Depends , HTTPException , status
11
13
from pydantic import BaseModel
12
14
13
15
from agent_memory_server .config import settings
16
+ from agent_memory_server .utils .keys import Keys
17
+ from agent_memory_server .utils .redis import get_redis_conn
14
18
15
19
16
20
logger = structlog .get_logger ()
@@ -27,6 +31,15 @@ class UserInfo(BaseModel):
27
31
roles : list [str ] | None = None
28
32
29
33
34
+ class TokenInfo (BaseModel ):
35
+ """Token information stored in Redis."""
36
+
37
+ description : str
38
+ created_at : datetime
39
+ expires_at : datetime | None = None
40
+ token_hash : str
41
+
42
+
30
43
class JWKSCache :
31
44
def __init__ (self , cache_duration : int = 3600 ):
32
45
self ._cache : dict [str , Any ] = {}
@@ -245,10 +258,98 @@ def verify_jwt(token: str) -> UserInfo:
245
258
) from e
246
259
247
260
261
+ def generate_token () -> str :
262
+ """Generate a secure random token."""
263
+ return secrets .token_urlsafe (32 )
264
+
265
+
266
+ def hash_token (token : str ) -> str :
267
+ """Hash a token using bcrypt."""
268
+ return bcrypt .hashpw (token .encode ("utf-8" ), bcrypt .gensalt ()).decode ("utf-8" )
269
+
270
+
271
+ def verify_token_hash (token : str , token_hash : str ) -> bool :
272
+ """Verify a token against its hash."""
273
+ try :
274
+ return bcrypt .checkpw (token .encode ("utf-8" ), token_hash .encode ("utf-8" ))
275
+ except Exception as e :
276
+ logger .warning ("Token hash verification failed" , error = str (e ))
277
+ return False
278
+
279
+
280
+ async def verify_token (token : str ) -> UserInfo :
281
+ """Verify a token and return user info."""
282
+ try :
283
+ redis = await get_redis_conn ()
284
+
285
+ # Get all auth tokens and check each one
286
+ # This is not the most efficient approach, but it works for now
287
+ # In a production system, you might want to store a mapping of token prefixes
288
+ pattern = Keys .auth_token_key ("*" )
289
+ token_keys = []
290
+
291
+ async for key in redis .scan_iter (pattern ):
292
+ token_keys .append (key )
293
+
294
+ for key in token_keys :
295
+ token_data = await redis .get (key )
296
+ if not token_data :
297
+ continue
298
+
299
+ try :
300
+ token_info = TokenInfo .model_validate_json (token_data )
301
+
302
+ # Check if token matches
303
+ if verify_token_hash (token , token_info .token_hash ):
304
+ # Check if token is expired
305
+ if (
306
+ token_info .expires_at
307
+ and datetime .now (UTC ) > token_info .expires_at
308
+ ):
309
+ logger .warning ("Token has expired" )
310
+ raise HTTPException (
311
+ status_code = status .HTTP_401_UNAUTHORIZED ,
312
+ detail = "Token has expired" ,
313
+ )
314
+
315
+ # Return user info for valid token
316
+ return UserInfo (
317
+ sub = "token-user" ,
318
+ aud = "token-auth" ,
319
+ scope = "admin" ,
320
+ roles = ["admin" ],
321
+ exp = int (token_info .expires_at .timestamp ())
322
+ if token_info .expires_at
323
+ else None ,
324
+ iat = int (token_info .created_at .timestamp ()),
325
+ )
326
+
327
+ except HTTPException :
328
+ # Re-raise HTTP exceptions (like token expired)
329
+ raise
330
+ except Exception as e :
331
+ logger .warning ("Error processing token" , error = str (e ))
332
+ continue
333
+
334
+ # If no token matched, authentication failed
335
+ raise HTTPException (
336
+ status_code = status .HTTP_401_UNAUTHORIZED , detail = "Invalid token"
337
+ )
338
+
339
+ except HTTPException :
340
+ raise
341
+ except Exception as e :
342
+ logger .error ("Unexpected error during token verification" , error = str (e ))
343
+ raise HTTPException (
344
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
345
+ detail = "Internal server error during authentication" ,
346
+ ) from e
347
+
348
+
248
349
def get_current_user (
249
350
credentials : HTTPAuthorizationCredentials | None = Depends (oauth2_scheme ),
250
351
) -> UserInfo :
251
- if settings .disable_auth :
352
+ if settings .disable_auth or settings . auth_mode == "disabled" :
252
353
logger .debug ("Authentication disabled, returning default user" )
253
354
return UserInfo (
254
355
sub = "local-dev-user" , aud = "local-dev" , scope = "admin" , roles = ["admin" ]
@@ -268,6 +369,14 @@ def get_current_user(
268
369
headers = {"WWW-Authenticate" : "Bearer" },
269
370
)
270
371
372
+ # Determine authentication mode
373
+ if settings .auth_mode == "token" or settings .token_auth_enabled :
374
+ import asyncio
375
+
376
+ return asyncio .run (verify_token (credentials .credentials ))
377
+ if settings .auth_mode == "oauth2" :
378
+ return verify_jwt (credentials .credentials )
379
+ # Default to OAuth2 for backward compatibility
271
380
return verify_jwt (credentials .credentials )
272
381
273
382
@@ -304,18 +413,42 @@ def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
304
413
305
414
306
415
def verify_auth_config ():
307
- if settings .disable_auth :
416
+ if settings .disable_auth or settings . auth_mode == "disabled" :
308
417
logger .warning ("Authentication is DISABLED - suitable for development only" )
309
418
return
310
419
420
+ if settings .auth_mode == "token" or settings .token_auth_enabled :
421
+ logger .info ("Token authentication configured" )
422
+ return
423
+
424
+ if settings .auth_mode == "oauth2" :
425
+ if not settings .oauth2_issuer_url :
426
+ raise ValueError (
427
+ "OAUTH2_ISSUER_URL must be set when OAuth2 authentication is enabled"
428
+ )
429
+
430
+ if not settings .oauth2_audience :
431
+ logger .warning (
432
+ "OAUTH2_AUDIENCE not set - audience validation will be skipped"
433
+ )
434
+
435
+ logger .info (
436
+ "OAuth2 authentication configured" ,
437
+ issuer = settings .oauth2_issuer_url ,
438
+ audience = settings .oauth2_audience or "not-set" ,
439
+ algorithms = settings .oauth2_algorithms ,
440
+ )
441
+ return
442
+
443
+ # Default to OAuth2 for backward compatibility
311
444
if not settings .oauth2_issuer_url :
312
445
raise ValueError ("OAUTH2_ISSUER_URL must be set when authentication is enabled" )
313
446
314
447
if not settings .oauth2_audience :
315
448
logger .warning ("OAUTH2_AUDIENCE not set - audience validation will be skipped" )
316
449
317
450
logger .info (
318
- "OAuth2 authentication configured" ,
451
+ "OAuth2 authentication configured (default) " ,
319
452
issuer = settings .oauth2_issuer_url ,
320
453
audience = settings .oauth2_audience or "not-set" ,
321
454
algorithms = settings .oauth2_algorithms ,
0 commit comments