22import os
33import uuid
44import threading
5+ import random
6+
57from fastapi import FastAPI , Depends , UploadFile , Form , File
6- from sqlalchemy import or_
7- from sqlalchemy .orm import Session
88from starlette .requests import Request
99from starlette .responses import HTMLResponse
10- import random
11-
1210from starlette .staticfiles import StaticFiles
1311
14- import database
15- from database import engine , SessionLocal , Base
12+ from sqlalchemy import or_ , select , update , delete , create_engine
13+ from sqlalchemy import select , update , delete
14+ from sqlalchemy .ext .asyncio .session import AsyncSession
15+
16+ from database import engine , get_session , Base , Codes
17+
1618
19+ engine = create_engine ('sqlite:///database.db' , connect_args = {"check_same_thread" : False })
1720Base .metadata .create_all (bind = engine )
21+
1822app = FastAPI ()
1923if not os .path .exists ('./static' ):
2024 os .makedirs ('./static' )
@@ -58,17 +62,9 @@ def delete_file(files):
5862 os .remove ('.' + file ['text' ])
5963
6064
61- def get_db ():
62- db = SessionLocal ()
63- try :
64- yield db
65- finally :
66- db .close ()
67-
68-
69- def get_code (db : Session = Depends (get_db )):
65+ async def get_code (s : AsyncSession ):
7066 code = random .randint (10000 , 99999 )
71- while db . query ( database . Codes ). filter ( database . Codes .code == code ). first ():
67+ while ( await s . execute ( select ( Codes . id ). where ( Codes .code == code ))). scalar ():
7268 code = random .randint (10000 , 99999 )
7369 return str (code )
7470
@@ -94,21 +90,23 @@ async def admin():
9490
9591
9692@app .post (f'/{ admin_address } ' )
97- async def admin_post (request : Request , db : Session = Depends (get_db )):
93+ async def admin_post (request : Request , s : AsyncSession = Depends (get_session )):
9894 if request .headers .get ('pwd' ) == admin_password :
99- codes = db .query (database .Codes ).all ()
95+ query = select (Codes )
96+ codes = (await s .execute (query )).scalars ().all ()
10097 return {'code' : 200 , 'msg' : '查询成功' , 'data' : codes }
10198 else :
10299 return {'code' : 404 , 'msg' : '密码错误' }
103100
104101
105102@app .delete (f'/{ admin_address } ' )
106- async def admin_delete (request : Request , code : str , db : Session = Depends (get_db )):
103+ async def admin_delete (request : Request , code : str , s : AsyncSession = Depends (get_session )):
107104 if request .headers .get ('pwd' ) == admin_password :
108- file = db .query (database .Codes ).filter (database .Codes .code == code ).first ()
105+ query = select (Codes ).where (Codes .code == code )
106+ file = (await s .execute (query )).scalars ().first ()
109107 threading .Thread (target = delete_file , args = ([{'type' : file .type , 'text' : file .text }],)).start ()
110- db .delete (file )
111- db .commit ()
108+ await s .delete (file )
109+ await s .commit ()
112110 return {'code' : 200 , 'msg' : '删除成功' }
113111 else :
114112 return {'code' : 404 , 'msg' : '密码错误' }
@@ -138,20 +136,24 @@ def ip_error(ip):
138136
139137
140138@app .post ('/' )
141- async def index (request : Request , code : str , db : Session = Depends (get_db )):
139+ async def index (request : Request , code : str , s : AsyncSession = Depends (get_session )):
142140 ip = request .client .host
143141 if not check_ip (ip ):
144142 return {'code' : 404 , 'msg' : '错误次数过多,请稍后再试' }
145- info = db .query (database .Codes ).filter (database .Codes .code == code ).first ()
143+ query = select (Codes ).where (Codes .code == code )
144+ info = (await s .execute (query )).scalars ().first ()
146145 if not info :
147146 return {'code' : 404 , 'msg' : f'取件码错误,错误{ error_count - ip_error (ip )} 次将被禁止10分钟' }
148147 if info .exp_time < datetime .datetime .now () or info .count == 0 :
149148 threading .Thread (target = delete_file , args = ([{'type' : info .type , 'text' : info .text }],)).start ()
150- db .delete (info )
151- db .commit ()
149+ await s .delete (info )
150+ await s .commit ()
152151 return {'code' : 404 , 'msg' : '取件码已过期,请联系寄件人' }
153- info .count -= 1
154- db .commit ()
152+
153+ count = info .count - 1
154+ query = update (Codes ).where (Codes .id == info .id ).values (count = count )
155+ await s .execute (query )
156+ await s .commit ()
155157 return {
156158 'code' : 200 ,
157159 'msg' : '取件成功,请点击"取"查看' ,
@@ -161,17 +163,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):
161163
162164@app .post ('/share' )
163165async def share (text : str = Form (default = None ), style : str = Form (default = '2' ), value : int = Form (default = 1 ),
164- file : UploadFile = File (default = None ), db : Session = Depends (get_db )):
165- exps = db . query ( database . Codes ).filter (
166- or_ (
167- database . Codes . exp_time < datetime . datetime . now (),
168- database . Codes . count == 0
169- )
170- )
171- threading . Thread ( target = delete_file , args = ([[{ 'type' : old . type , 'text' : old . text }] for old in exps . all ()],)). start ( )
172- exps . delete ()
173- db . commit ()
174- code = get_code (db )
166+ file : UploadFile = File (default = None ), s : AsyncSession = Depends (get_session )):
167+ query = select ( Codes ).where ( or_ ( Codes . exp_time < datetime . datetime . now (), Codes . count == 0 ))
168+ exps = ( await s . execute ( query )). scalars (). all ()
169+ threading . Thread ( target = delete_file , args = ([[{ 'type' : old . type , 'text' : old . text }] for old in exps ],)). start ()
170+
171+ exps_ids = [ exp . id for exp in exps ]
172+ query = delete ( Codes ). where ( Codes . id . in_ ( exps_ids ) )
173+ await s . execute ( query )
174+ await s . commit ()
175+
176+ code = await get_code (s )
175177 if style == '2' :
176178 if value > 7 :
177179 return {'code' : 404 , 'msg' : '最大有效天数为7天' }
@@ -192,7 +194,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
192194 return {'code' : 404 , 'msg' : '文件过大' }
193195 else :
194196 size , _text , _type , name = len (text ), text , 'text' , '文本分享'
195- info = database . Codes (
197+ info = Codes (
196198 code = code ,
197199 text = _text ,
198200 size = size ,
@@ -202,8 +204,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
202204 exp_time = exp_time ,
203205 key = key
204206 )
205- db .add (info )
206- db .commit ()
207+ s .add (info )
208+ await s .commit ()
207209 return {
208210 'code' : 200 ,
209211 'msg' : '分享成功,请点击文件箱查看取件码' ,
0 commit comments