|
| 1 | +# users.py |
| 2 | +import uuid |
| 3 | +from typing import Optional, AsyncGenerator |
| 4 | +from fastapi import Depends, Request |
| 5 | +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas |
| 6 | +from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy |
| 7 | +from fastapi_users.db import SQLAlchemyUserDatabase |
| 8 | +from transformerlab.shared.models.user_model import User, get_async_session |
| 9 | +from sqlalchemy.ext.asyncio import AsyncSession |
| 10 | + |
| 11 | + |
| 12 | +# --- Pydantic Schemas for API interactions --- |
| 13 | +class UserRead(schemas.BaseUser[uuid.UUID]): |
| 14 | + # Add your custom fields here if you added them to the User model |
| 15 | + pass |
| 16 | + |
| 17 | + |
| 18 | +class UserCreate(schemas.BaseUserCreate): |
| 19 | + pass |
| 20 | + |
| 21 | + |
| 22 | +class UserUpdate(schemas.BaseUserUpdate): |
| 23 | + pass |
| 24 | + |
| 25 | + |
| 26 | +# --- User Manager (Handles registration, password reset, etc.) --- |
| 27 | +SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !! |
| 28 | + |
| 29 | + |
| 30 | +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): |
| 31 | + reset_password_token_secret = SECRET |
| 32 | + verification_token_secret = SECRET |
| 33 | + |
| 34 | + # Optional: Define custom logic after registration |
| 35 | + async def on_after_register(self, user: User, request: Optional[Request] = None): |
| 36 | + print(f"User {user.id} has registered.") |
| 37 | + |
| 38 | + |
| 39 | +async def get_user_db(session: AsyncSession = Depends(get_async_session)): |
| 40 | + yield SQLAlchemyUserDatabase(session, User) |
| 41 | + |
| 42 | + |
| 43 | +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): |
| 44 | + yield UserManager(user_db) |
| 45 | + |
| 46 | + |
| 47 | +# --- Authentication Backend (JWT/Bearer Token) --- |
| 48 | +bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") |
| 49 | + |
| 50 | + |
| 51 | +def get_jwt_strategy() -> JWTStrategy: |
| 52 | + # Token lasts for 3600 seconds (1 hour) |
| 53 | + return JWTStrategy(secret=SECRET, lifetime_seconds=3600) |
| 54 | + |
| 55 | + |
| 56 | +auth_backend = AuthenticationBackend( |
| 57 | + name="jwt", |
| 58 | + transport=bearer_transport, |
| 59 | + get_strategy=get_jwt_strategy, |
| 60 | +) |
| 61 | + |
| 62 | +# --- FastAPIUsers Instance (The main utility) --- |
| 63 | +fastapi_users = FastAPIUsers[User, uuid.UUID]( |
| 64 | + get_user_manager, |
| 65 | + [auth_backend], # Add more backends (like Google OAuth) here |
| 66 | +) |
| 67 | + |
| 68 | +# --- Dependency for Protected Routes --- |
| 69 | +# This is what you'll use in your route decorators |
| 70 | +current_active_user = fastapi_users.current_user(active=True) |
0 commit comments