Skip to content

Commit e072a12

Browse files
committed
This is a first example of using fastapi-auth. Test instructions are in the users router
1 parent 93826e0 commit e072a12

File tree

4 files changed

+165
-0
lines changed

4 files changed

+165
-0
lines changed

api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@
8181

8282
from dotenv import load_dotenv
8383

84+
from transformerlab.models.users import (
85+
fastapi_users,
86+
auth_backend,
87+
current_active_user,
88+
UserRead,
89+
UserCreate,
90+
UserUpdate,
91+
)
92+
from transformerlab.routers.test_users import router as users_router
93+
from transformerlab.shared.models.user_model import create_db_and_tables, User
94+
8495
load_dotenv()
8596

8697
# The following environment variable can be used by other scripts
@@ -109,6 +120,7 @@ async def lifespan(app: FastAPI):
109120
galleries.update_gallery_cache()
110121
spawn_fastchat_controller_subprocess()
111122
await db.init()
123+
await create_db_and_tables()
112124
print("✅ SEED DATA")
113125
# Initialize experiments and cancel any running jobs
114126
seed_default_experiments()
@@ -230,6 +242,25 @@ async def validation_exception_handler(request, exc):
230242
app.include_router(remote.router)
231243
app.include_router(fastchat_openai_api.router)
232244

245+
# Include Auth and Registration Routers
246+
app.include_router(
247+
fastapi_users.get_auth_router(auth_backend),
248+
prefix="/auth/jwt",
249+
tags=["auth"],
250+
)
251+
app.include_router(
252+
fastapi_users.get_register_router(UserRead, UserCreate),
253+
prefix="/auth",
254+
tags=["auth"],
255+
)
256+
# Include User Management Router (allows authenticated users to view/update their profile)
257+
app.include_router(
258+
fastapi_users.get_users_router(UserRead, UserUpdate),
259+
prefix="/users",
260+
tags=["users"],
261+
)
262+
app.include_router(users_router)
263+
233264
# Authentication and session management routes
234265
if os.getenv("TFL_MULTITENANT") == "true":
235266
from transformerlab.routers import auth # noqa: E402

transformerlab/models/users.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# users.py
2+
import uuid
3+
from typing import Optional, AsyncGenerator
4+
from fastapi import Depends, Request
5+
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas
6+
from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy
7+
from fastapi_users.db import SQLAlchemyUserDatabase
8+
from transformerlab.shared.models.user_model import User, get_async_session
9+
from sqlalchemy.ext.asyncio import AsyncSession
10+
11+
12+
# --- Pydantic Schemas for API interactions ---
13+
class UserRead(schemas.BaseUser[uuid.UUID]):
14+
# Add your custom fields here if you added them to the User model
15+
pass
16+
17+
18+
class UserCreate(schemas.BaseUserCreate):
19+
pass
20+
21+
22+
class UserUpdate(schemas.BaseUserUpdate):
23+
pass
24+
25+
26+
# --- User Manager (Handles registration, password reset, etc.) ---
27+
SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !!
28+
29+
30+
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
31+
reset_password_token_secret = SECRET
32+
verification_token_secret = SECRET
33+
34+
# Optional: Define custom logic after registration
35+
async def on_after_register(self, user: User, request: Optional[Request] = None):
36+
print(f"User {user.id} has registered.")
37+
38+
39+
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
40+
yield SQLAlchemyUserDatabase(session, User)
41+
42+
43+
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
44+
yield UserManager(user_db)
45+
46+
47+
# --- Authentication Backend (JWT/Bearer Token) ---
48+
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
49+
50+
51+
def get_jwt_strategy() -> JWTStrategy:
52+
# Token lasts for 3600 seconds (1 hour)
53+
return JWTStrategy(secret=SECRET, lifetime_seconds=3600)
54+
55+
56+
auth_backend = AuthenticationBackend(
57+
name="jwt",
58+
transport=bearer_transport,
59+
get_strategy=get_jwt_strategy,
60+
)
61+
62+
# --- FastAPIUsers Instance (The main utility) ---
63+
fastapi_users = FastAPIUsers[User, uuid.UUID](
64+
get_user_manager,
65+
[auth_backend], # Add more backends (like Google OAuth) here
66+
)
67+
68+
# --- Dependency for Protected Routes ---
69+
# This is what you'll use in your route decorators
70+
current_active_user = fastapi_users.current_user(active=True)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from fastapi import APIRouter, Depends
2+
from transformerlab.shared.models.user_model import User
3+
from transformerlab.models.users import (
4+
current_active_user,
5+
)
6+
7+
8+
router = APIRouter(prefix="/test_users", tags=["users"])
9+
10+
11+
@router.get("/authenticated-route")
12+
async def authenticated_route(user: User = Depends(current_active_user)):
13+
return {"message": f"Hello, {user.email}! You are authenticated."}
14+
15+
16+
# To test this, register a new user via /auth/register
17+
# curl -X POST 'http://127.0.0.1:8338/auth/register' \
18+
# -H 'Content-Type: application/json' \
19+
# -d '{
20+
# "email": "test@example.com",
21+
# "password": "password123"
22+
# }'
23+
24+
# Then
25+
# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \
26+
# -H 'Content-Type: application/x-www-form-urlencoded' \
27+
# -d 'username=test@example.com&password=password123'
28+
29+
# This will return a token, which you can use to access the authenticated route:
30+
# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \
31+
# -H 'Authorization: Bearer <YOUR_ACCESS_TOKEN>'
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# database.py
2+
from typing import AsyncGenerator
3+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
4+
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
5+
from sqlalchemy.orm import sessionmaker
6+
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
7+
8+
# Replace with your actual database URL (e.g., PostgreSQL, SQLite)
9+
from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL
10+
11+
Base: DeclarativeMeta = declarative_base()
12+
13+
14+
# 1. Define your User Model (inherits from a FastAPI Users base class)
15+
class User(SQLAlchemyBaseUserTableUUID, Base):
16+
pass # You can add custom fields here later, like 'first_name: str'
17+
18+
19+
# 2. Setup the Async Engine and Session
20+
engine = create_async_engine(DATABASE_URL)
21+
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
22+
23+
24+
# 3. Utility to create tables (run this on app startup)
25+
async def create_db_and_tables():
26+
async with engine.begin() as conn:
27+
await conn.run_sync(Base.metadata.create_all)
28+
29+
30+
# 4. Database session dependency
31+
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
32+
async with AsyncSessionLocal() as session:
33+
yield session

0 commit comments

Comments
 (0)