1111from sqlalchemy .ext .asyncio .session import AsyncSession
1212
1313import settings
14- from utils import delete_expire_files , storage , get_code
14+ from utils import delete_expire_files , storage , get_code , error_ip_limit , upload_ip_limit
1515from database import get_session , Codes , init_models
16- from depends import admin_required , IPRateLimit
16+ from depends import admin_required
1717
18+ # 实例化FastAPI
1819app = FastAPI (debug = settings .DEBUG )
1920
21+ # 数据存储文件夹
2022DATA_ROOT = Path (settings .DATA_ROOT )
2123if not DATA_ROOT .exists ():
2224 DATA_ROOT .mkdir (parents = True )
23-
25+ # 静态文件夹
2426app .mount (settings .STATIC_URL , StaticFiles (directory = DATA_ROOT ), name = "static" )
2527
2628
2729@app .on_event ('startup' )
2830async def startup ():
31+ # 初始化数据库
2932 await init_models ()
33+ # 启动后台任务,不定时删除过期文件
3034 asyncio .create_task (delete_expire_files ())
3135
3236
37+ # 首页页面
3338index_html = open ('templates/index.html' , 'r' , encoding = 'utf-8' ).read () \
3439 .replace ('{{title}}' , settings .TITLE ) \
3540 .replace ('{{description}}' , settings .DESCRIPTION ) \
3641 .replace ('{{keywords}}' , settings .KEYWORDS )
42+ # 管理页面
3743admin_html = open ('templates/admin.html' , 'r' , encoding = 'utf-8' ).read () \
3844 .replace ('{{title}}' , settings .TITLE ) \
3945 .replace ('{{description}}' , settings .DESCRIPTION ) \
4046 .replace ('{{keywords}}' , settings .KEYWORDS )
4147
42- ip_limit = IPRateLimit ()
43-
4448
45- @app .get (f'/{ settings .ADMIN_ADDRESS } ' )
49+ @app .get (f'/{ settings .ADMIN_ADDRESS } ' , description = '管理页面' , response_class = HTMLResponse )
4650async def admin ():
4751 return HTMLResponse (admin_html )
4852
4953
50- @app .post (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
54+ @app .post (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )], description = '查询数据库列表' )
5155async def admin_post (s : AsyncSession = Depends (get_session )):
52- query = select ( Codes )
53- codes = (await s .execute (query )).scalars ().all ()
56+ # 查询数据库列表
57+ codes = (await s .execute (select ( Codes ) )).scalars ().all ()
5458 return {'detail' : '查询成功' , 'data' : codes }
5559
5660
57- @app .delete (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
61+ @app .delete (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )], description = '删除数据库记录' )
5862async def admin_delete (code : str , s : AsyncSession = Depends (get_session )):
63+ # 找到相应记录
5964 query = select (Codes ).where (Codes .code == code )
65+ # 找到第一条记录
6066 file = (await s .execute (query )).scalars ().first ()
61- if file :
62- if file .type != 'text' :
63- await storage .delete_file (file .text )
64- await s .delete (file )
65- await s .commit ()
67+ # 如果记录存在,并且不是文本
68+ if file and file .type != 'text' :
69+ # 删除文件
70+ await storage .delete_file (file .text )
71+ # 删除数据库记录
72+ await s .delete (file )
73+ await s .commit ()
6674 return {'detail' : '删除成功' }
6775
6876
@@ -72,25 +80,30 @@ async def index():
7280
7381
7482@app .get ('/select' )
75- async def get_file (code : str , s : AsyncSession = Depends (get_session )):
83+ async def get_file (code : str , ip : str = Depends (error_ip_limit ), s : AsyncSession = Depends (get_session )):
84+ # 查出数据库记录
7685 query = select (Codes ).where (Codes .code == code )
7786 info = (await s .execute (query )).scalars ().first ()
87+ # 如果记录不存在,IP错误次数+1
7888 if not info :
79- raise HTTPException (status_code = 404 , detail = "口令不存在" )
89+ error_ip_limit .add_ip (ip )
90+ raise HTTPException (status_code = 404 , detail = "口令不存在,次数过多将被禁止访问" )
91+ # 如果是文本,直接返回
8092 if info .type == 'text' :
8193 return {'detail' : '查询成功' , 'data' : info .text }
94+ # 如果是文件,返回文件
8295 else :
8396 filepath = await storage .get_filepath (info .text )
8497 return FileResponse (filepath , filename = info .name )
8598
8699
87100@app .post ('/' )
88- async def index (code : str , ip : str = Depends (ip_limit ), s : AsyncSession = Depends (get_session )):
101+ async def index (code : str , ip : str = Depends (error_ip_limit ), s : AsyncSession = Depends (get_session )):
89102 query = select (Codes ).where (Codes .code == code )
90103 info = (await s .execute (query )).scalars ().first ()
91104 if not info :
92- error_count = settings .ERROR_COUNT - ip_limit .add_ip (ip )
93- raise HTTPException (status_code = 404 , detail = f"取件码错误,错误 { error_count } 次将被禁止10分钟 " )
105+ error_count = settings .ERROR_COUNT - error_ip_limit .add_ip (ip )
106+ raise HTTPException (status_code = 404 , detail = f"取件码错误,{ error_count } 次后将被禁止 { settings . ERROR_MINUTE } 分钟 " )
94107 if info .exp_time < datetime .datetime .now () or info .count == 0 :
95108 if info .type != "text" :
96109 await storage .delete_file (info .text )
@@ -109,7 +122,7 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
109122
110123@app .post ('/share' )
111124async def share (background_tasks : BackgroundTasks , text : str = Form (default = None ), style : str = Form (default = '2' ),
112- value : int = Form (default = 1 ), file : UploadFile = File (default = None ),
125+ value : int = Form (default = 1 ), file : UploadFile = File (default = None ), ip : str = Depends ( upload_ip_limit ),
113126 s : AsyncSession = Depends (get_session )):
114127 code = await get_code (s )
115128 if style == '2' :
@@ -137,6 +150,7 @@ async def share(background_tasks: BackgroundTasks, text: str = Form(default=None
137150 info = Codes (code = code , text = _text , size = size , type = _type , name = name , count = exp_count , exp_time = exp_time , key = key )
138151 s .add (info )
139152 await s .commit ()
153+ upload_ip_limit .add_ip (ip )
140154 return {
141155 'detail' : '分享成功,请点击取件码按钮查看上传列表' ,
142156 'data' : {'code' : code , 'key' : key , 'name' : name , 'text' : _text }
0 commit comments