44import asyncio
55from pathlib import Path
66
7- from fastapi import FastAPI , Depends , UploadFile , Form , File
8- from starlette .requests import Request
7+ from fastapi import FastAPI , Depends , UploadFile , Form , File , HTTPException
98from starlette .responses import HTMLResponse , FileResponse
109from starlette .staticfiles import StaticFiles
1110
1514import settings
1615from database import get_session , Codes , init_models , engine
1716from storage import STORAGE_ENGINE
17+ from depends import admin_required , IPRateLimit
1818
1919app = FastAPI (debug = settings .DEBUG )
2020
@@ -43,7 +43,7 @@ async def startup():
4343 .replace ('{{description}}' , settings .DESCRIPTION ) \
4444 .replace ('{{keywords}}' , settings .KEYWORDS )
4545
46- error_ip_count = {}
46+ ip_limit = IPRateLimit ()
4747
4848
4949async def delete_expire_files ():
@@ -72,88 +72,60 @@ async def admin():
7272 return HTMLResponse (admin_html )
7373
7474
75- @app .post (f'/{ settings .ADMIN_ADDRESS } ' )
76- async def admin_post (request : Request , s : AsyncSession = Depends (get_session )):
77- if request .headers .get ('pwd' ) == settings .ADMIN_PASSWORD :
78- query = select (Codes )
79- codes = (await s .execute (query )).scalars ().all ()
80- return {'code' : 200 , 'msg' : '查询成功' , 'data' : codes }
81- else :
82- return {'code' : 404 , 'msg' : '密码错误' }
75+ @app .post (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
76+ async def admin_post (s : AsyncSession = Depends (get_session )):
77+ query = select (Codes )
78+ codes = (await s .execute (query )).scalars ().all ()
79+ return {'msg' : '查询成功' , 'data' : codes }
8380
8481
85- @app .delete (f'/{ settings .ADMIN_ADDRESS } ' )
86- async def admin_delete (request : Request , code : str , s : AsyncSession = Depends (get_session )):
87- if request .headers .get ('pwd' ) == settings .ADMIN_PASSWORD :
88- query = select (Codes ).where (Codes .code == code )
89- file = (await s .execute (query )).scalars ().first ()
90- await storage .delete_file ({'type' : file .type , 'text' : file .text })
91- await s .delete (file )
92- await s .commit ()
93- return {'code' : 200 , 'msg' : '删除成功' }
94- else :
95- return {'code' : 404 , 'msg' : '密码错误' }
82+ @app .delete (f'/{ settings .ADMIN_ADDRESS } ' , dependencies = [Depends (admin_required )])
83+ async def admin_delete (code : str , s : AsyncSession = Depends (get_session )):
84+ query = select (Codes ).where (Codes .code == code )
85+ file = (await s .execute (query )).scalars ().first ()
86+ await storage .delete_file ({'type' : file .type , 'text' : file .text })
87+ await s .delete (file )
88+ await s .commit ()
89+ return {'msg' : '删除成功' }
9690
9791
9892@app .get ('/' )
9993async def index ():
10094 return HTMLResponse (index_html )
10195
10296
103- def check_ip (ip ):
104- # 检查ip是否被禁止
105- if ip in error_ip_count :
106- if error_ip_count [ip ]['count' ] >= settings .ERROR_COUNT :
107- if error_ip_count [ip ]['time' ] + datetime .timedelta (minutes = settings .ERROR_MINUTE ) > datetime .datetime .now ():
108- return False
109- else :
110- error_ip_count .pop (ip )
111- return True
112-
113-
114- def ip_error (ip ):
115- ip_info = error_ip_count .get (ip , {'count' : 0 , 'time' : datetime .datetime .now ()})
116- ip_info ['count' ] += 1
117- error_ip_count [ip ] = ip_info
118- return ip_info ['count' ]
119-
120-
12197@app .get ('/select' )
12298async def get_file (code : str , s : AsyncSession = Depends (get_session )):
12399 query = select (Codes ).where (Codes .code == code )
124100 info = (await s .execute (query )).scalars ().first ()
125- if info :
126- if info .type == 'text' :
127- return {'code' : code , 'msg' : '查询成功' , 'data' : info .text }
128- else :
129- filepath = await storage .get_filepath (info .text )
130- return FileResponse (filepath , filename = info .name )
101+ if not info :
102+ raise HTTPException (status_code = 404 , detail = "口令不存在" )
103+ if info .type == 'text' :
104+ return {'msg' : '查询成功' , 'data' : info .text }
131105 else :
132- return {'code' : 404 , 'msg' : '口令不存在' }
106+ filepath = await storage .get_filepath (info .text )
107+ return FileResponse (filepath , filename = info .name )
133108
134109
135110@app .post ('/' )
136- async def index (request : Request , code : str , s : AsyncSession = Depends (get_session )):
137- ip = request .client .host
138- if not check_ip (ip ):
139- return {'code' : 404 , 'msg' : '错误次数过多,请稍后再试' }
111+ async def index (code : str , ip : str = Depends (ip_limit ), s : AsyncSession = Depends (get_session )):
140112 query = select (Codes ).where (Codes .code == code )
141113 info = (await s .execute (query )).scalars ().first ()
142114 if not info :
143- return {'code' : 404 , 'msg' : f'取件码错误,错误{ settings .ERROR_COUNT - ip_error (ip )} 次将被禁止10分钟' }
115+ error_count = ip_limit .add_ip (ip )
116+ raise HTTPException (status_code = 404 , detail = f"取件码错误,错误{ settings .ERROR_COUNT - error_count } 次将被禁止10分钟" )
144117 if info .exp_time < datetime .datetime .now () or info .count == 0 :
145118 await storage .delete_file ({'type' : info .type , 'text' : info .text })
146119 await s .delete (info )
147120 await s .commit ()
148- return { 'code' : 404 , 'msg' : ' 取件码已过期,请联系寄件人' }
121+ raise HTTPException ( status_code = 404 , detail = " 取件码已过期,请联系寄件人" )
149122 count = info .count - 1
150123 query = update (Codes ).where (Codes .id == info .id ).values (count = count )
151124 await s .execute (query )
152125 await s .commit ()
153126 if info .type != 'text' :
154127 info .text = f'/select?code={ code } '
155128 return {
156- 'code' : 200 ,
157129 'msg' : '取件成功,请点击"取"查看' ,
158130 'data' : {'type' : info .type , 'text' : info .text , 'name' : info .name , 'code' : info .code }
159131 }
@@ -165,12 +137,12 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
165137 code = await get_code (s )
166138 if style == '2' :
167139 if value > 7 :
168- return { 'code' : 404 , 'msg' : ' 最大有效天数为7天' }
140+ raise HTTPException ( status_code = 400 , detail = " 最大有效天数为7天" )
169141 exp_time = datetime .datetime .now () + datetime .timedelta (days = value )
170142 exp_count = - 1
171143 elif style == '1' :
172144 if value < 1 :
173- return { 'code' : 404 , 'msg' : ' 最小有效次数为1次' }
145+ raise HTTPException ( status_code = 400 , detail = " 最小有效次数为1次" )
174146 exp_time = datetime .datetime .now () + datetime .timedelta (days = 1 )
175147 exp_count = value
176148 else :
@@ -181,7 +153,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
181153 file_bytes = await file .read ()
182154 size = len (file_bytes )
183155 if size > settings .FILE_SIZE_LIMIT :
184- return { 'code' : 404 , 'msg' : ' 文件过大' }
156+ raise HTTPException ( status_code = 400 , detail = " 文件过大" )
185157 _text , _type , name = await storage .save_file (file , file_bytes , key ), file .content_type , file .filename
186158 else :
187159 size , _text , _type , name = len (text ), text , 'text' , '文本分享'
@@ -198,7 +170,6 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
198170 s .add (info )
199171 await s .commit ()
200172 return {
201- 'code' : 200 ,
202173 'msg' : '分享成功,请点击文件箱查看取件码' ,
203174 'data' : {'code' : code , 'key' : key , 'name' : name , 'text' : _text }
204175 }
0 commit comments