Skip to content

Commit d634d83

Browse files
committed
first auth support
1 parent e072a12 commit d634d83

File tree

4 files changed

+160
-60
lines changed

4 files changed

+160
-60
lines changed

api.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
batched_prompts,
5151
recipes,
5252
remote,
53+
auth2,
5354
)
5455
import torch
5556

@@ -81,15 +82,7 @@
8182

8283
from dotenv import load_dotenv
8384

84-
from transformerlab.models.users import (
85-
fastapi_users,
86-
auth_backend,
87-
current_active_user,
88-
UserRead,
89-
UserCreate,
90-
UserUpdate,
91-
)
92-
from transformerlab.routers.test_users import router as users_router
85+
9386
from transformerlab.shared.models.user_model import create_db_and_tables, User
9487

9588
load_dotenv()
@@ -241,25 +234,7 @@ async def validation_exception_handler(request, exc):
241234
app.include_router(batched_prompts.router)
242235
app.include_router(remote.router)
243236
app.include_router(fastchat_openai_api.router)
244-
245-
# Include Auth and Registration Routers
246-
app.include_router(
247-
fastapi_users.get_auth_router(auth_backend),
248-
prefix="/auth/jwt",
249-
tags=["auth"],
250-
)
251-
app.include_router(
252-
fastapi_users.get_register_router(UserRead, UserCreate),
253-
prefix="/auth",
254-
tags=["auth"],
255-
)
256-
# Include User Management Router (allows authenticated users to view/update their profile)
257-
app.include_router(
258-
fastapi_users.get_users_router(UserRead, UserUpdate),
259-
prefix="/users",
260-
tags=["users"],
261-
)
262-
app.include_router(users_router)
237+
app.include_router(auth2.router)
263238

264239
# Authentication and session management routes
265240
if os.getenv("TFL_MULTITENANT") == "true":

transformerlab/models/users.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# users.py
22
import uuid
3-
from typing import Optional, AsyncGenerator
3+
from typing import Optional
44
from fastapi import Depends, Request
55
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas
66
from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy
77
from fastapi_users.db import SQLAlchemyUserDatabase
88
from transformerlab.shared.models.user_model import User, get_async_session
99
from sqlalchemy.ext.asyncio import AsyncSession
10+
from jose import jwt as _jose_jwt
11+
from datetime import datetime, timedelta
1012

1113

1214
# --- Pydantic Schemas for API interactions ---
@@ -25,6 +27,8 @@ class UserUpdate(schemas.BaseUserUpdate):
2527

2628
# --- User Manager (Handles registration, password reset, etc.) ---
2729
SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !!
30+
REFRESH_SECRET = "YOUR_REFRESH_TOKEN_SECRET" # !! USE A DIFFERENT SECRET !!
31+
REFRESH_LIFETIME = 60 * 60 * 24 * 7 # 7 days
2832

2933

3034
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
@@ -68,3 +72,67 @@ def get_jwt_strategy() -> JWTStrategy:
6872
# --- Dependency for Protected Routes ---
6973
# This is what you'll use in your route decorators
7074
current_active_user = fastapi_users.current_user(active=True)
75+
76+
77+
def get_refresh_strategy() -> JWTStrategy:
78+
return JWTStrategy(secret=REFRESH_SECRET, lifetime_seconds=REFRESH_LIFETIME)
79+
80+
81+
# --- Small helper to create access + refresh tokens for manual flows (e.g. refresh endpoint) ---
82+
83+
84+
class _JWTAuthenticationHelper:
85+
"""Minimal helper that mirrors a login response (access + refresh token).
86+
87+
We keep this small and explicit so callers (like the `refresh` endpoint in
88+
`routers/auth.py`) can create new access tokens when given a valid
89+
refresh token.
90+
"""
91+
92+
def __init__(
93+
self,
94+
access_secret: str,
95+
refresh_secret: str,
96+
access_lifetime: int = 3600,
97+
refresh_lifetime: int = REFRESH_LIFETIME,
98+
):
99+
self.access_secret = access_secret
100+
self.refresh_secret = refresh_secret
101+
self.access_lifetime = access_lifetime
102+
self.refresh_lifetime = refresh_lifetime
103+
104+
def _create_token(self, user, secret: str, lifetime_seconds: int) -> str:
105+
now = datetime.utcnow()
106+
exp = now + timedelta(seconds=lifetime_seconds)
107+
payload = {
108+
"sub": str(user.id),
109+
"email": getattr(user, "email", None),
110+
"exp": int(exp.timestamp()),
111+
}
112+
return _jose_jwt.encode(payload, secret, algorithm="HS256")
113+
114+
def get_login_response(self, user) -> dict:
115+
"""Return a dict similar to what FastAPI-Users returns on login.
116+
117+
Keys:
118+
- access_token: short-lived JWT
119+
- refresh_token: long-lived JWT (can be validated with refresh strategy)
120+
- token_type: 'bearer'
121+
- expires_in: seconds until access token expiry
122+
"""
123+
access = self._create_token(user, self.access_secret, self.access_lifetime)
124+
refresh = self._create_token(user, self.refresh_secret, self.refresh_lifetime)
125+
return {
126+
"access_token": access,
127+
"refresh_token": refresh,
128+
"token_type": "bearer",
129+
"expires_in": self.access_lifetime,
130+
}
131+
132+
133+
# Module-level helpers for imports elsewhere
134+
jwt_authentication = _JWTAuthenticationHelper(
135+
SECRET, REFRESH_SECRET, access_lifetime=3600, refresh_lifetime=REFRESH_LIFETIME
136+
)
137+
access_strategy = get_jwt_strategy()
138+
refresh_strategy = get_refresh_strategy()

transformerlab/routers/auth2.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from fastapi import APIRouter, Depends, HTTPException
2+
from transformerlab.shared.models.user_model import User
3+
from transformerlab.models.users import (
4+
fastapi_users,
5+
auth_backend,
6+
current_active_user,
7+
UserRead,
8+
UserCreate,
9+
UserUpdate,
10+
get_user_manager,
11+
get_refresh_strategy,
12+
jwt_authentication,
13+
)
14+
15+
from jose import jwt, JWTError
16+
17+
router = APIRouter(tags=["users"])
18+
19+
20+
# Include Auth and Registration Routers
21+
router.include_router(
22+
fastapi_users.get_auth_router(auth_backend),
23+
prefix="/auth/jwt",
24+
tags=["auth"],
25+
)
26+
router.include_router(
27+
fastapi_users.get_register_router(UserRead, UserCreate),
28+
prefix="/auth",
29+
tags=["auth"],
30+
)
31+
# Include User Management Router (allows authenticated users to view/update their profile)
32+
router.include_router(
33+
fastapi_users.get_users_router(UserRead, UserUpdate),
34+
prefix="/users",
35+
tags=["users"],
36+
)
37+
38+
39+
@router.get("/test-users/authenticated-route")
40+
async def authenticated_route(user: User = Depends(current_active_user)):
41+
return {"message": f"Hello, {user.email}! You are authenticated."}
42+
43+
44+
# To test this, register a new user via /auth/register
45+
# curl -X POST 'http://127.0.0.1:8338/auth/register' \
46+
# -H 'Content-Type: application/json' \
47+
# -d '{
48+
# "email": "test@example.com",
49+
# "password": "password123"
50+
# }'
51+
52+
# Then
53+
# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \
54+
# -H 'Content-Type: application/x-www-form-urlencoded' \
55+
# -d 'username=test@example.com&password=password123'
56+
57+
# This will return a token, which you can use to access the authenticated route:
58+
# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \
59+
# -H 'Authorization: Bearer <YOUR_ACCESS_TOKEN>'
60+
61+
62+
@router.post("/auth/refresh")
63+
async def refresh_access_token(
64+
refresh_token: str, # Sent by the client in the request body
65+
user_manager=Depends(get_user_manager),
66+
):
67+
try:
68+
# 1. Decode and Validate the Refresh Token
69+
# Get a fresh refresh strategy instance and use its secret to decode
70+
refresh_strategy = get_refresh_strategy()
71+
payload = jwt.decode(refresh_token, str(refresh_strategy.secret), algorithms=["HS256"])
72+
user_id = payload.get("sub")
73+
74+
if user_id is None:
75+
raise HTTPException(status_code=401, detail="Invalid refresh token payload")
76+
77+
# 2. Get the user object from the database
78+
user = await user_manager.get(user_id)
79+
if user is None or not user.is_active:
80+
raise HTTPException(status_code=401, detail="User inactive or not found")
81+
82+
# 3. Create a NEW Access Token (using the short-lived strategy from the main JWT)
83+
new_access_token = jwt_authentication.get_login_response(user) # Needs custom helper
84+
85+
return {"access_token": new_access_token["access_token"], "token_type": "bearer"}
86+
87+
except JWTError:
88+
raise HTTPException(status_code=401, detail="Expired or invalid refresh token")

transformerlab/routers/test_users.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

0 commit comments

Comments
 (0)