Skip to content

Commit 3b4d6a2

Browse files
committed
refactor(auth): 将数据库查询从同步改为异步实现
将auth_router中的所有数据库查询操作从同步SQLAlchemy改为异步实现 添加async_check_first_run方法以支持异步检查首次运行
1 parent 704954b commit 3b4d6a2

File tree

2 files changed

+86
-34
lines changed

2 files changed

+86
-34
lines changed

server/routers/auth_router.py

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from fastapi import APIRouter, Depends, HTTPException, Request, status, UploadFile, File
44
from fastapi.security import OAuth2PasswordRequestForm
55
from pydantic import BaseModel
6+
from sqlalchemy import select
7+
from sqlalchemy.ext.asyncio import AsyncSession
68
from sqlalchemy.orm import Session
79

810
from src.storage.db.manager import db_manager
@@ -89,16 +91,22 @@ class UserIdGeneration(BaseModel):
8991

9092

9193
@auth.post("/token", response_model=Token)
92-
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
94+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
9395
# 查找用户 - 支持user_id和phone_number登录
9496
login_identifier = form_data.username # OAuth2表单中的username字段作为登录标识符
9597

9698
# 尝试通过user_id查找
97-
user = db.query(User).filter(User.user_id == login_identifier).first()
99+
result = await db.execute(
100+
select(User).filter(User.user_id == login_identifier)
101+
)
102+
user = result.scalar_one_or_none()
98103

99104
# 如果通过user_id没找到,尝试通过phone_number查找
100105
if not user:
101-
user = db.query(User).filter(User.phone_number == login_identifier).first()
106+
result = await db.execute(
107+
select(User).filter(User.phone_number == login_identifier)
108+
)
109+
user = result.scalar_one_or_none()
102110

103111
# 如果用户不存在,为防止用户名枚举攻击,返回通用错误信息
104112
if not user:
@@ -176,15 +184,15 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
176184
# 路由:校验是否需要初始化管理员
177185
@auth.get("/check-first-run")
178186
async def check_first_run():
179-
is_first_run = db_manager.check_first_run()
187+
is_first_run = await db_manager.async_check_first_run()
180188
return {"first_run": is_first_run}
181189

182190

183191
# 路由:初始化管理员账户
184192
@auth.post("/initialize", response_model=Token)
185-
async def initialize_admin(admin_data: InitializeAdmin, db: Session = Depends(get_db)):
193+
async def initialize_admin(admin_data: InitializeAdmin, db: AsyncSession = Depends(get_db)):
186194
# 检查是否是首次运行
187-
if not db_manager.check_first_run():
195+
if not await db_manager.async_check_first_run():
188196
raise HTTPException(
189197
status_code=status.HTTP_403_FORBIDDEN,
190198
detail="系统已经初始化,无法再次创建初始管理员",
@@ -263,7 +271,7 @@ async def update_profile(
263271
profile_data: UserProfileUpdate,
264272
request: Request,
265273
current_user: User = Depends(get_required_user),
266-
db: Session = Depends(get_db),
274+
db: AsyncSession = Depends(get_db),
267275
):
268276
"""更新当前用户的个人资料"""
269277
update_details = []
@@ -279,9 +287,10 @@ async def update_profile(
279287
)
280288

281289
# 检查用户名是否已被其他用户使用
282-
existing_user = (
283-
db.query(User).filter(User.username == profile_data.username, User.id != current_user.id).first()
290+
result = await db.execute(
291+
select(User).filter(User.username == profile_data.username, User.id != current_user.id)
284292
)
293+
existing_user = result.scalar_one_or_none()
285294
if existing_user:
286295
raise HTTPException(
287296
status_code=status.HTTP_400_BAD_REQUEST,
@@ -299,11 +308,11 @@ async def update_profile(
299308

300309
# 检查手机号是否已被其他用户使用
301310
if profile_data.phone_number:
302-
existing_phone = (
303-
db.query(User)
311+
result = await db.execute(
312+
select(User)
304313
.filter(User.phone_number == profile_data.phone_number, User.id != current_user.id)
305-
.first()
306314
)
315+
existing_phone = result.scalar_one_or_none()
307316
if existing_phone:
308317
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号已被其他用户使用")
309318

@@ -327,7 +336,7 @@ async def update_profile(
327336

328337
@auth.post("/users", response_model=UserResponse)
329338
async def create_user(
330-
user_data: UserCreate, request: Request, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)
339+
user_data: UserCreate, request: Request, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)
331340
):
332341
# 验证用户名
333342
is_valid, error_msg = validate_username(user_data.username)
@@ -338,7 +347,10 @@ async def create_user(
338347
)
339348

340349
# 检查用户名是否已存在
341-
existing_user = db.query(User).filter(User.username == user_data.username).first()
350+
result = await db.execute(
351+
select(User).filter(User.username == user_data.username)
352+
)
353+
existing_user = result.scalar_one_or_none()
342354
if existing_user:
343355
raise HTTPException(
344356
status_code=status.HTTP_400_BAD_REQUEST,
@@ -347,15 +359,19 @@ async def create_user(
347359

348360
# 检查手机号是否已存在(如果提供了)
349361
if user_data.phone_number:
350-
existing_phone = db.query(User).filter(User.phone_number == user_data.phone_number).first()
362+
result = await db.execute(
363+
select(User).filter(User.phone_number == user_data.phone_number)
364+
)
365+
existing_phone = result.scalar_one_or_none()
351366
if existing_phone:
352367
raise HTTPException(
353368
status_code=status.HTTP_400_BAD_REQUEST,
354369
detail="手机号已存在",
355370
)
356371

357372
# 生成唯一的user_id
358-
existing_user_ids = [user.user_id for user in db.query(User.user_id).all()]
373+
result = await db.execute(select(User.user_id))
374+
existing_user_ids = [user_id for (user_id,) in result.all()]
359375
user_id = generate_unique_user_id(user_data.username, existing_user_ids)
360376

361377
# 创建新用户
@@ -397,16 +413,25 @@ async def create_user(
397413
# 路由:获取所有用户(管理员权限)
398414
@auth.get("/users", response_model=list[UserResponse])
399415
async def read_users(
400-
skip: int = 0, limit: int = 100, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)
416+
skip: int = 0, limit: int = 100, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)
401417
):
402-
users = db.query(User).filter(User.is_deleted == 0).offset(skip).limit(limit).all()
418+
result = await db.execute(
419+
select(User)
420+
.filter(User.is_deleted == 0)
421+
.offset(skip)
422+
.limit(limit)
423+
)
424+
users = result.scalars().all()
403425
return [user.to_dict() for user in users]
404426

405427

406428
# 路由:获取特定用户信息(管理员权限)
407429
@auth.get("/users/{user_id}", response_model=UserResponse)
408-
async def read_user(user_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)):
409-
user = db.query(User).filter(User.id == user_id, User.is_deleted == 0).first()
430+
async def read_user(user_id: int, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)):
431+
result = await db.execute(
432+
select(User).filter(User.id == user_id, User.is_deleted == 0)
433+
)
434+
user = result.scalar_one_or_none()
410435
if user is None:
411436
raise HTTPException(
412437
status_code=status.HTTP_404_NOT_FOUND,
@@ -422,9 +447,12 @@ async def update_user(
422447
user_data: UserUpdate,
423448
request: Request,
424449
current_user: User = Depends(get_admin_user),
425-
db: Session = Depends(get_db),
450+
db: AsyncSession = Depends(get_db),
426451
):
427-
user = db.query(User).filter(User.id == user_id, User.is_deleted == 0).first()
452+
result = await db.execute(
453+
select(User).filter(User.id == user_id, User.is_deleted == 0)
454+
)
455+
user = result.scalar_one_or_none()
428456
if user is None:
429457
raise HTTPException(
430458
status_code=status.HTTP_404_NOT_FOUND,
@@ -450,7 +478,10 @@ async def update_user(
450478

451479
if user_data.username is not None:
452480
# 检查用户名是否已被其他用户使用
453-
existing_user = db.query(User).filter(User.username == user_data.username, User.id != user_id).first()
481+
result = await db.execute(
482+
select(User).filter(User.username == user_data.username, User.id != user_id)
483+
)
484+
existing_user = result.scalar_one_or_none()
454485
if existing_user:
455486
raise HTTPException(
456487
status_code=status.HTTP_400_BAD_REQUEST,
@@ -478,9 +509,12 @@ async def update_user(
478509
# 路由:删除用户(管理员权限)
479510
@auth.delete("/users/{user_id}", response_model=dict)
480511
async def delete_user(
481-
user_id: int, request: Request, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)
512+
user_id: int, request: Request, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)
482513
):
483-
user = db.query(User).filter(User.id == user_id, User.is_deleted == 0).first()
514+
result = await db.execute(
515+
select(User).filter(User.id == user_id, User.is_deleted == 0)
516+
)
517+
user = result.scalar_one_or_none()
484518
if user is None:
485519
raise HTTPException(
486520
status_code=status.HTTP_404_NOT_FOUND,
@@ -497,7 +531,10 @@ async def delete_user(
497531
)
498532

499533
# 检查是否是最后一个超级管理员
500-
superadmin_count = db.query(User).filter(User.role == "superadmin", User.is_deleted == 0).count()
534+
result = await db.execute(
535+
select(db.func.count(User.id)).filter(User.role == "superadmin", User.is_deleted == 0)
536+
)
537+
superadmin_count = result.scalar()
501538
if superadmin_count <= 1:
502539
raise HTTPException(
503540
status_code=status.HTTP_400_BAD_REQUEST,
@@ -544,7 +581,7 @@ async def delete_user(
544581
# 路由:验证用户名并生成user_id
545582
@auth.post("/validate-username", response_model=UserIdGeneration)
546583
async def validate_username_and_generate_user_id(
547-
validation_data: UsernameValidation, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)
584+
validation_data: UsernameValidation, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)
548585
):
549586
"""验证用户名格式并生成可用的user_id"""
550587
# 验证用户名格式
@@ -556,15 +593,19 @@ async def validate_username_and_generate_user_id(
556593
)
557594

558595
# 检查用户名是否已存在
559-
existing_user = db.query(User).filter(User.username == validation_data.username).first()
596+
result = await db.execute(
597+
select(User).filter(User.username == validation_data.username)
598+
)
599+
existing_user = result.scalar_one_or_none()
560600
if existing_user:
561601
raise HTTPException(
562602
status_code=status.HTTP_400_BAD_REQUEST,
563603
detail="用户名已存在",
564604
)
565605

566606
# 生成唯一的user_id
567-
existing_user_ids = [user.user_id for user in db.query(User.user_id).all()]
607+
result = await db.execute(select(User.user_id))
608+
existing_user_ids = [user_id for (user_id,) in result.all()]
568609
user_id = generate_unique_user_id(validation_data.username, existing_user_ids)
569610

570611
return UserIdGeneration(username=validation_data.username, user_id=user_id, is_available=True)
@@ -573,17 +614,20 @@ async def validate_username_and_generate_user_id(
573614
# 路由:检查user_id是否可用
574615
@auth.get("/check-user-id/{user_id}")
575616
async def check_user_id_availability(
576-
user_id: str, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db)
617+
user_id: str, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)
577618
):
578619
"""检查user_id是否可用"""
579-
existing_user = db.query(User).filter(User.user_id == user_id).first()
620+
result = await db.execute(
621+
select(User).filter(User.user_id == user_id)
622+
)
623+
existing_user = result.scalar_one_or_none()
580624
return {"user_id": user_id, "is_available": existing_user is None}
581625

582626

583627
# 路由:上传用户头像
584628
@auth.post("/upload-avatar")
585629
async def upload_user_avatar(
586-
file: UploadFile = File(...), current_user: User = Depends(get_required_user), db: Session = Depends(get_db)
630+
file: UploadFile = File(...), current_user: User = Depends(get_required_user), db: AsyncSession = Depends(get_db)
587631
):
588632
"""上传用户头像"""
589633
# 检查文件类型

src/storage/db/manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pathlib
44
from contextlib import asynccontextmanager, contextmanager
55

6-
from sqlalchemy import create_engine
6+
from sqlalchemy import create_engine, select, func
77
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
88
from sqlalchemy.orm import sessionmaker
99

@@ -131,14 +131,22 @@ async def get_async_session_context(self):
131131
await session.close()
132132

133133
def check_first_run(self):
134-
"""检查是否首次运行"""
134+
"""检查是否首次运行(同步版本)"""
135135
session = self.get_session()
136136
try:
137137
# 检查是否有任何用户存在
138138
return session.query(User).count() == 0
139139
finally:
140140
session.close()
141141

142+
async def async_check_first_run(self):
143+
"""检查是否首次运行(异步版本)"""
144+
async with self.get_async_session_context() as session:
145+
# 检查是否有任何用户存在
146+
result = await session.execute(select(func.count(User.id)))
147+
count = result.scalar()
148+
return count == 0
149+
142150

143151
# 创建全局数据库管理器实例
144152
db_manager = DBManager()

0 commit comments

Comments
 (0)