22import os
33import uuid
44import threading
5+ import random
6+
57from fastapi import FastAPI , Depends , UploadFile , Form , File
6- from sqlalchemy import or_
7- from sqlalchemy .orm import Session
88from starlette .requests import Request
99from starlette .responses import HTMLResponse , FileResponse
1010import random
11-
1211from starlette .staticfiles import StaticFiles
1312
14- import database
15- from database import engine , SessionLocal , Base
13+ from sqlalchemy import or_ , select , update , delete , create_engine
14+ from sqlalchemy import select , update , delete
15+ from sqlalchemy .ext .asyncio .session import AsyncSession
16+
17+ from database import engine , get_session , Base , Codes
18+
1619
20+ engine = create_engine ('sqlite:///database.db' , connect_args = {"check_same_thread" : False })
1721Base .metadata .create_all (bind = engine )
22+
1823app = FastAPI ()
1924if not os .path .exists ('./static' ):
2025 os .makedirs ('./static' )
@@ -58,17 +63,9 @@ def delete_file(files):
5863 os .remove ('.' + file ['text' ])
5964
6065
61- def get_db ():
62- db = SessionLocal ()
63- try :
64- yield db
65- finally :
66- db .close ()
67-
68-
69- def get_code (db : Session = Depends (get_db )):
66+ async def get_code (s : AsyncSession ):
7067 code = random .randint (10000 , 99999 )
71- while db . query ( database . Codes ). filter ( database . Codes .code == code ). first ():
68+ while ( await s . execute ( select ( Codes . id ). where ( Codes .code == code ))). scalar ():
7269 code = random .randint (10000 , 99999 )
7370 return str (code )
7471
@@ -94,21 +91,23 @@ async def admin():
9491
9592
9693@app .post (f'/{ admin_address } ' )
97- async def admin_post (request : Request , db : Session = Depends (get_db )):
94+ async def admin_post (request : Request , s : AsyncSession = Depends (get_session )):
9895 if request .headers .get ('pwd' ) == admin_password :
99- codes = db .query (database .Codes ).all ()
96+ query = select (Codes )
97+ codes = (await s .execute (query )).scalars ().all ()
10098 return {'code' : 200 , 'msg' : '查询成功' , 'data' : codes }
10199 else :
102100 return {'code' : 404 , 'msg' : '密码错误' }
103101
104102
105103@app .delete (f'/{ admin_address } ' )
106- async def admin_delete (request : Request , code : str , db : Session = Depends (get_db )):
104+ async def admin_delete (request : Request , code : str , s : AsyncSession = Depends (get_session )):
107105 if request .headers .get ('pwd' ) == admin_password :
108- file = db .query (database .Codes ).filter (database .Codes .code == code ).first ()
106+ query = select (Codes ).where (Codes .code == code )
107+ file = (await s .execute (query )).scalars ().first ()
109108 threading .Thread (target = delete_file , args = ([{'type' : file .type , 'text' : file .text }],)).start ()
110- db .delete (file )
111- db .commit ()
109+ await s .delete (file )
110+ await s .commit ()
112111 return {'code' : 200 , 'msg' : '删除成功' }
113112 else :
114113 return {'code' : 404 , 'msg' : '密码错误' }
@@ -150,22 +149,26 @@ async def get_file(code: str, db: Session = Depends(get_db)):
150149
151150
152151@app .post ('/' )
153- async def index (request : Request , code : str , db : Session = Depends (get_db )):
152+ async def index (request : Request , code : str , s : AsyncSession = Depends (get_session )):
154153 ip = request .client .host
155154 if not check_ip (ip ):
156155 return {'code' : 404 , 'msg' : '错误次数过多,请稍后再试' }
157- info = db .query (database .Codes ).filter (database .Codes .code == code ).first ()
156+ query = select (Codes ).where (Codes .code == code )
157+ info = (await s .execute (query )).scalars ().first ()
158158 if not info :
159159 return {'code' : 404 , 'msg' : f'取件码错误,错误{ error_count - ip_error (ip )} 次将被禁止10分钟' }
160160 if info .exp_time < datetime .datetime .now () or info .count == 0 :
161161 threading .Thread (target = delete_file , args = ([{'type' : info .type , 'text' : info .text }],)).start ()
162- db .delete (info )
163- db .commit ()
162+ await s .delete (info )
163+ await s .commit ()
164164 return {'code' : 404 , 'msg' : '取件码已过期,请联系寄件人' }
165- info .count -= 1
166- db .commit ()
165+ count = info .count - 1
166+ query = update (Codes ).where (Codes .id == info .id ).values (count = count )
167+ await s .execute (query )
168+ await s .commit ()
167169 if info .type != 'text' :
168170 info .text = f'/select?code={ code } '
171+
169172 return {
170173 'code' : 200 ,
171174 'msg' : '取件成功,请点击"取"查看' ,
@@ -175,17 +178,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):
175178
176179@app .post ('/share' )
177180async def share (text : str = Form (default = None ), style : str = Form (default = '2' ), value : int = Form (default = 1 ),
178- file : UploadFile = File (default = None ), db : Session = Depends (get_db )):
179- exps = db . query ( database . Codes ).filter (
180- or_ (
181- database . Codes . exp_time < datetime . datetime . now (),
182- database . Codes . count == 0
183- )
184- )
185- threading . Thread ( target = delete_file , args = ([[{ 'type' : old . type , 'text' : old . text }] for old in exps . all ()],)). start ( )
186- exps . delete ()
187- db . commit ()
188- code = get_code (db )
181+ file : UploadFile = File (default = None ), s : AsyncSession = Depends (get_session )):
182+ query = select ( Codes ).where ( or_ ( Codes . exp_time < datetime . datetime . now (), Codes . count == 0 ))
183+ exps = ( await s . execute ( query )). scalars (). all ()
184+ threading . Thread ( target = delete_file , args = ([[{ 'type' : old . type , 'text' : old . text }] for old in exps ],)). start ()
185+
186+ exps_ids = [ exp . id for exp in exps ]
187+ query = delete ( Codes ). where ( Codes . id . in_ ( exps_ids ) )
188+ await s . execute ( query )
189+ await s . commit ()
190+
191+ code = await get_code (s )
189192 if style == '2' :
190193 if value > 7 :
191194 return {'code' : 404 , 'msg' : '最大有效天数为7天' }
@@ -206,7 +209,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
206209 return {'code' : 404 , 'msg' : '文件过大' }
207210 else :
208211 size , _text , _type , name = len (text ), text , 'text' , '文本分享'
209- info = database . Codes (
212+ info = Codes (
210213 code = code ,
211214 text = _text ,
212215 size = size ,
@@ -216,8 +219,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
216219 exp_time = exp_time ,
217220 key = key
218221 )
219- db .add (info )
220- db .commit ()
222+ s .add (info )
223+ await s .commit ()
221224 return {
222225 'code' : 200 ,
223226 'msg' : '分享成功,请点击文件箱查看取件码' ,
0 commit comments