33import uuid
44import threading
55import random
6+ import asyncio
7+ from pathlib import Path
8+
69from fastapi import FastAPI , Depends , UploadFile , Form , File
710from starlette .requests import Request
811from starlette .responses import HTMLResponse , FileResponse
1114from sqlalchemy import or_ , select , update , delete
1215from sqlalchemy .ext .asyncio .session import AsyncSession
1316
14- from database import get_session , Codes , init_models
17+ import settings
18+ from database import get_session , Codes , init_models , engine
19+
20+ app = FastAPI (debug = settings .DEBUG )
21+
22+ DATA_ROOT = Path (settings .DATA_ROOT )
23+ if not DATA_ROOT .exists ():
24+ DATA_ROOT .mkdir (parents = True )
1525
16- app = FastAPI ()
17- if not os .path .exists ('./static' ):
18- os .makedirs ('./static' )
19- app .mount ("/static" , StaticFiles (directory = "static" ), name = "static" )
26+ STATIC_URL = settings .STATIC_URL
27+ app .mount (STATIC_URL , StaticFiles (directory = DATA_ROOT ), name = "static" )
2028
2129
2230@app .on_event ('startup' )
2331async def startup ():
2432 await init_models ()
2533
26-
27- ############################################
28- # 需要修改的参数
29- # 允许错误次数
30- error_count = 5
31- # 禁止分钟数
32- error_minute = 10
33- # 后台地址
34- admin_address = 'admin'
35- # 管理密码
36- admin_password = 'admin'
37- # 文件大小限制 10M
38- file_size_limit = 1024 * 1024 * 10
39- # 系统标题
40- title = '文件快递柜'
41- # 系统描述
42- description = 'FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件'
43- # 系统关键字
44- keywords = 'FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件'
45- ############################################
34+ asyncio .create_task (delete_expire_files ())
4635
4736index_html = open ('templates/index.html' , 'r' , encoding = 'utf-8' ).read () \
48- .replace ('{{title}}' , title ) \
49- .replace ('{{description}}' , description ) \
50- .replace ('{{keywords}}' , keywords )
37+ .replace ('{{title}}' , settings . TITLE ) \
38+ .replace ('{{description}}' , settings . DESCRIPTION ) \
39+ .replace ('{{keywords}}' , settings . KEYWORDS )
5140admin_html = open ('templates/admin.html' , 'r' , encoding = 'utf-8' ).read () \
52- .replace ('{{title}}' , title ) \
53- .replace ('{{description}}' , description ) \
54- .replace ('{{keywords}}' , keywords )
41+ .replace ('{{title}}' , settings . TITLE ) \
42+ .replace ('{{description}}' , settings . DESCRIPTION ) \
43+ .replace ('{{keywords}}' , settings . KEYWORDS )
5544
5645error_ip_count = {}
5746
5847
5948def delete_file (files ):
6049 for file in files :
6150 if file ['type' ] != 'text' :
62- os .remove ('.' + file ['text' ])
51+ os .remove (DATA_ROOT / file ['text' ].lstrip (STATIC_URL + '/' ))
52+
53+
54+ async def delete_expire_files ():
55+ while True :
56+ async with AsyncSession (engine , expire_on_commit = False ) as s :
57+ query = select (Codes ).where (or_ (Codes .exp_time < datetime .datetime .now (), Codes .count == 0 ))
58+ exps = (await s .execute (query )).scalars ().all ()
59+ await asyncio .to_thread (delete_file , [{'type' : old .type , 'text' : old .text } for old in exps ])
60+
61+ exps_ids = [exp .id for exp in exps ]
62+ query = delete (Codes ).where (Codes .id .in_ (exps_ids ))
63+ await s .execute (query )
64+ await s .commit ()
65+
66+ await asyncio .sleep (random .randint (60 , 300 ))
6367
6468
6569async def get_code (s : AsyncSession ):
@@ -73,38 +77,39 @@ def get_file_name(key, ext, file):
7377 now = datetime .datetime .now ()
7478 file_bytes = file .file .read ()
7579 size = len (file_bytes )
76- if size > file_size_limit :
80+ if size > settings . FILE_SIZE_LIMIT :
7781 return size , '' , '' , ''
78- path = f'./static/ upload/{ now .year } /{ now .month } /{ now .day } /'
82+ path = DATA_ROOT / f" upload/{ now .year } /{ now .month } /{ now .day } /"
7983 name = f'{ key } .{ ext } '
80- if not os .path .exists (path ):
81- os .makedirs (path )
82- with open (f'{ os .path .join (path , name )} ' , 'wb' ) as f :
84+ if not path .exists ():
85+ path .mkdir (parents = True )
86+ filepath = path / name
87+ with open (filepath , 'wb' ) as f :
8388 f .write (file_bytes )
84- return size , path [ 1 :] + name , file .content_type , file .filename
89+ return size , f" { STATIC_URL } / { filepath . relative_to ( DATA_ROOT ) } " , file .content_type , file .filename
8590
8691
87- @app .get (f'/{ admin_address } ' )
92+ @app .get (f'/{ settings . ADMIN_ADDRESS } ' )
8893async def admin ():
8994 return HTMLResponse (admin_html )
9095
9196
92- @app .post (f'/{ admin_address } ' )
97+ @app .post (f'/{ settings . ADMIN_ADDRESS } ' )
9398async def admin_post (request : Request , s : AsyncSession = Depends (get_session )):
94- if request .headers .get ('pwd' ) == admin_password :
99+ if request .headers .get ('pwd' ) == settings . ADMIN_PASSWORD :
95100 query = select (Codes )
96101 codes = (await s .execute (query )).scalars ().all ()
97102 return {'code' : 200 , 'msg' : '查询成功' , 'data' : codes }
98103 else :
99104 return {'code' : 404 , 'msg' : '密码错误' }
100105
101106
102- @app .delete (f'/{ admin_address } ' )
107+ @app .delete (f'/{ settings . ADMIN_ADDRESS } ' )
103108async def admin_delete (request : Request , code : str , s : AsyncSession = Depends (get_session )):
104- if request .headers .get ('pwd' ) == admin_password :
109+ if request .headers .get ('pwd' ) == settings . ADMIN_PASSWORD :
105110 query = select (Codes ).where (Codes .code == code )
106111 file = (await s .execute (query )).scalars ().first ()
107- threading . Thread ( target = delete_file , args = ( [{'type' : file .type , 'text' : file .text }],)). start ( )
112+ await asyncio . to_thread ( delete_file , [{'type' : file .type , 'text' : file .text }])
108113 await s .delete (file )
109114 await s .commit ()
110115 return {'code' : 200 , 'msg' : '删除成功' }
@@ -120,8 +125,8 @@ async def index():
120125def check_ip (ip ):
121126 # 检查ip是否被禁止
122127 if ip in error_ip_count :
123- if error_ip_count [ip ]['count' ] >= error_count :
124- if error_ip_count [ip ]['time' ] + datetime .timedelta (minutes = error_minute ) > datetime .datetime .now ():
128+ if error_ip_count [ip ]['count' ] >= settings . ERROR_COUNT :
129+ if error_ip_count [ip ]['time' ] + datetime .timedelta (minutes = settings . ERROR_MINUTE ) > datetime .datetime .now ():
125130 return False
126131 else :
127132 error_ip_count .pop (ip )
@@ -143,7 +148,7 @@ async def get_file(code: str, s: AsyncSession = Depends(get_session)):
143148 if info .type == 'text' :
144149 return {'code' : code , 'msg' : '查询成功' , 'data' : info .text }
145150 else :
146- return FileResponse ('.' + info .text , filename = info .name )
151+ return FileResponse (DATA_ROOT / info .text . lstrip ( STATIC_URL + '/' ) , filename = info .name )
147152 else :
148153 return {'code' : 404 , 'msg' : '口令不存在' }
149154
@@ -156,7 +161,7 @@ async def index(request: Request, code: str, s: AsyncSession = Depends(get_sessi
156161 query = select (Codes ).where (Codes .code == code )
157162 info = (await s .execute (query )).scalars ().first ()
158163 if not info :
159- return {'code' : 404 , 'msg' : f'取件码错误,错误{ error_count - ip_error (ip )} 次将被禁止10分钟' }
164+ return {'code' : 404 , 'msg' : f'取件码错误,错误{ settings . ERROR_COUNT - ip_error (ip )} 次将被禁止10分钟' }
160165 if info .exp_time < datetime .datetime .now () or info .count == 0 :
161166 threading .Thread (target = delete_file , args = ([{'type' : info .type , 'text' : info .text }],)).start ()
162167 await s .delete (info )
@@ -179,15 +184,6 @@ async def index(request: Request, code: str, s: AsyncSession = Depends(get_sessi
179184@app .post ('/share' )
180185async def share (text : str = Form (default = None ), style : str = Form (default = '2' ), value : int = Form (default = 1 ),
181186 file : UploadFile = File (default = None ), s : AsyncSession = Depends (get_session )):
182- query = select (Codes ).where (or_ (Codes .exp_time < datetime .datetime .now (), Codes .count == 0 ))
183- exps = (await s .execute (query )).scalars ().all ()
184- threading .Thread (target = delete_file , args = ([[{'type' : old .type , 'text' : old .text }] for old in exps ],)).start ()
185-
186- exps_ids = [exp .id for exp in exps ]
187- query = delete (Codes ).where (Codes .id .in_ (exps_ids ))
188- await s .execute (query )
189- await s .commit ()
190-
191187 code = await get_code (s )
192188 if style == '2' :
193189 if value > 7 :
@@ -205,7 +201,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
205201 key = uuid .uuid4 ().hex
206202 if file :
207203 size , _text , _type , name = get_file_name (key , file .filename .split ('.' )[- 1 ], file )
208- if size > file_size_limit :
204+ if size > settings . FILE_SIZE_LIMIT :
209205 return {'code' : 404 , 'msg' : '文件过大' }
210206 else :
211207 size , _text , _type , name = len (text ), text , 'text' , '文本分享'
0 commit comments