11# users.py
22import uuid
3- from typing import Optional , AsyncGenerator
3+ from typing import Optional
44from fastapi import Depends , Request
55from fastapi_users import BaseUserManager , FastAPIUsers , UUIDIDMixin , schemas
66from fastapi_users .authentication import AuthenticationBackend , BearerTransport , JWTStrategy
77from fastapi_users .db import SQLAlchemyUserDatabase
88from transformerlab .shared .models .user_model import User , get_async_session
99from sqlalchemy .ext .asyncio import AsyncSession
10+ from jose import jwt as _jose_jwt
11+ from datetime import datetime , timedelta
1012
1113
1214# --- Pydantic Schemas for API interactions ---
@@ -25,6 +27,8 @@ class UserUpdate(schemas.BaseUserUpdate):
2527
2628# --- User Manager (Handles registration, password reset, etc.) ---
2729SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !!
30+ REFRESH_SECRET = "YOUR_REFRESH_TOKEN_SECRET" # !! USE A DIFFERENT SECRET !!
31+ REFRESH_LIFETIME = 60 * 60 * 24 * 7 # 7 days
2832
2933
3034class UserManager (UUIDIDMixin , BaseUserManager [User , uuid .UUID ]):
@@ -68,3 +72,67 @@ def get_jwt_strategy() -> JWTStrategy:
6872# --- Dependency for Protected Routes ---
6973# This is what you'll use in your route decorators
7074current_active_user = fastapi_users .current_user (active = True )
75+
76+
77+ def get_refresh_strategy () -> JWTStrategy :
78+ return JWTStrategy (secret = REFRESH_SECRET , lifetime_seconds = REFRESH_LIFETIME )
79+
80+
81+ # --- Small helper to create access + refresh tokens for manual flows (e.g. refresh endpoint) ---
82+
83+
84+ class _JWTAuthenticationHelper :
85+ """Minimal helper that mirrors a login response (access + refresh token).
86+
87+ We keep this small and explicit so callers (like the `refresh` endpoint in
88+ `routers/auth.py`) can create new access tokens when given a valid
89+ refresh token.
90+ """
91+
92+ def __init__ (
93+ self ,
94+ access_secret : str ,
95+ refresh_secret : str ,
96+ access_lifetime : int = 3600 ,
97+ refresh_lifetime : int = REFRESH_LIFETIME ,
98+ ):
99+ self .access_secret = access_secret
100+ self .refresh_secret = refresh_secret
101+ self .access_lifetime = access_lifetime
102+ self .refresh_lifetime = refresh_lifetime
103+
104+ def _create_token (self , user , secret : str , lifetime_seconds : int ) -> str :
105+ now = datetime .utcnow ()
106+ exp = now + timedelta (seconds = lifetime_seconds )
107+ payload = {
108+ "sub" : str (user .id ),
109+ "email" : getattr (user , "email" , None ),
110+ "exp" : int (exp .timestamp ()),
111+ }
112+ return _jose_jwt .encode (payload , secret , algorithm = "HS256" )
113+
114+ def get_login_response (self , user ) -> dict :
115+ """Return a dict similar to what FastAPI-Users returns on login.
116+
117+ Keys:
118+ - access_token: short-lived JWT
119+ - refresh_token: long-lived JWT (can be validated with refresh strategy)
120+ - token_type: 'bearer'
121+ - expires_in: seconds until access token expiry
122+ """
123+ access = self ._create_token (user , self .access_secret , self .access_lifetime )
124+ refresh = self ._create_token (user , self .refresh_secret , self .refresh_lifetime )
125+ return {
126+ "access_token" : access ,
127+ "refresh_token" : refresh ,
128+ "token_type" : "bearer" ,
129+ "expires_in" : self .access_lifetime ,
130+ }
131+
132+
133+ # Module-level helpers for imports elsewhere
134+ jwt_authentication = _JWTAuthenticationHelper (
135+ SECRET , REFRESH_SECRET , access_lifetime = 3600 , refresh_lifetime = REFRESH_LIFETIME
136+ )
137+ access_strategy = get_jwt_strategy ()
138+ refresh_strategy = get_refresh_strategy ()
0 commit comments