Skip to content

Commit f4e6dfe

Browse files
committed
数据库操作改为异步
1 parent 195c6f9 commit f4e6dfe

File tree

2 files changed

+56
-48
lines changed

2 files changed

+56
-48
lines changed

database.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
import datetime
22

3-
from sqlalchemy import create_engine, DateTime
4-
from sqlalchemy.orm import sessionmaker
3+
from sqlalchemy import Boolean, Column, Integer, String, DateTime
54
from sqlalchemy.ext.declarative import declarative_base
6-
from sqlalchemy import Boolean, Column, Integer, String
5+
from sqlalchemy.ext.asyncio import create_async_engine
6+
from sqlalchemy.ext.asyncio.session import AsyncSession
7+
8+
9+
engine = create_async_engine("sqlite+aiosqlite:///database.db")
710

8-
engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
911
Base = declarative_base()
10-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
12+
13+
14+
async def get_session():
15+
async with AsyncSession(engine, expire_on_commit=False) as s:
16+
yield s
1117

1218

1319
class Codes(Base):
14-
__tablename__ = 'codes'
20+
__tablename__ = "codes"
1521
id = Column(Integer, primary_key=True, index=True)
1622
code = Column(String(10), unique=True, index=True)
1723
key = Column(String(30), unique=True, index=True)

main.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,23 @@
22
import os
33
import uuid
44
import threading
5+
import random
6+
57
from fastapi import FastAPI, Depends, UploadFile, Form, File
6-
from sqlalchemy import or_
7-
from sqlalchemy.orm import Session
88
from starlette.requests import Request
99
from starlette.responses import HTMLResponse
10-
import random
11-
1210
from starlette.staticfiles import StaticFiles
1311

14-
import database
15-
from database import engine, SessionLocal, Base
12+
from sqlalchemy import or_, select, update, delete, create_engine
13+
from sqlalchemy import select, update, delete
14+
from sqlalchemy.ext.asyncio.session import AsyncSession
15+
16+
from database import engine, get_session, Base, Codes
17+
1618

19+
engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
1720
Base.metadata.create_all(bind=engine)
21+
1822
app = FastAPI()
1923
if not os.path.exists('./static'):
2024
os.makedirs('./static')
@@ -58,17 +62,9 @@ def delete_file(files):
5862
os.remove('.' + file['text'])
5963

6064

61-
def get_db():
62-
db = SessionLocal()
63-
try:
64-
yield db
65-
finally:
66-
db.close()
67-
68-
69-
def get_code(db: Session = Depends(get_db)):
65+
async def get_code(s: AsyncSession):
7066
code = random.randint(10000, 99999)
71-
while db.query(database.Codes).filter(database.Codes.code == code).first():
67+
while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar():
7268
code = random.randint(10000, 99999)
7369
return str(code)
7470

@@ -94,21 +90,23 @@ async def admin():
9490

9591

9692
@app.post(f'/{admin_address}')
97-
async def admin_post(request: Request, db: Session = Depends(get_db)):
93+
async def admin_post(request: Request, s: AsyncSession = Depends(get_session)):
9894
if request.headers.get('pwd') == admin_password:
99-
codes = db.query(database.Codes).all()
95+
query = select(Codes)
96+
codes = (await s.execute(query)).scalars().all()
10097
return {'code': 200, 'msg': '查询成功', 'data': codes}
10198
else:
10299
return {'code': 404, 'msg': '密码错误'}
103100

104101

105102
@app.delete(f'/{admin_address}')
106-
async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)):
103+
async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)):
107104
if request.headers.get('pwd') == admin_password:
108-
file = db.query(database.Codes).filter(database.Codes.code == code).first()
105+
query = select(Codes).where(Codes.code == code)
106+
file = (await s.execute(query)).scalars().first()
109107
threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start()
110-
db.delete(file)
111-
db.commit()
108+
await s.delete(file)
109+
await s.commit()
112110
return {'code': 200, 'msg': '删除成功'}
113111
else:
114112
return {'code': 404, 'msg': '密码错误'}
@@ -138,20 +136,24 @@ def ip_error(ip):
138136

139137

140138
@app.post('/')
141-
async def index(request: Request, code: str, db: Session = Depends(get_db)):
139+
async def index(request: Request, code: str, s: AsyncSession = Depends(get_session)):
142140
ip = request.client.host
143141
if not check_ip(ip):
144142
return {'code': 404, 'msg': '错误次数过多,请稍后再试'}
145-
info = db.query(database.Codes).filter(database.Codes.code == code).first()
143+
query = select(Codes).where(Codes.code == code)
144+
info = (await s.execute(query)).scalars().first()
146145
if not info:
147146
return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'}
148147
if info.exp_time < datetime.datetime.now() or info.count == 0:
149148
threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start()
150-
db.delete(info)
151-
db.commit()
149+
await s.delete(info)
150+
await s.commit()
152151
return {'code': 404, 'msg': '取件码已过期,请联系寄件人'}
153-
info.count -= 1
154-
db.commit()
152+
153+
count = info.count - 1
154+
query = update(Codes).where(Codes.id == info.id).values(count=count)
155+
await s.execute(query)
156+
await s.commit()
155157
return {
156158
'code': 200,
157159
'msg': '取件成功,请点击"取"查看',
@@ -161,17 +163,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):
161163

162164
@app.post('/share')
163165
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
164-
file: UploadFile = File(default=None), db: Session = Depends(get_db)):
165-
exps = db.query(database.Codes).filter(
166-
or_(
167-
database.Codes.exp_time < datetime.datetime.now(),
168-
database.Codes.count == 0
169-
)
170-
)
171-
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps.all()],)).start()
172-
exps.delete()
173-
db.commit()
174-
code = get_code(db)
166+
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
167+
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
168+
exps = (await s.execute(query)).scalars().all()
169+
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()
170+
171+
exps_ids = [exp.id for exp in exps]
172+
query = delete(Codes).where(Codes.id.in_(exps_ids))
173+
await s.execute(query)
174+
await s.commit()
175+
176+
code = await get_code(s)
175177
if style == '2':
176178
if value > 7:
177179
return {'code': 404, 'msg': '最大有效天数为7天'}
@@ -192,7 +194,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
192194
return {'code': 404, 'msg': '文件过大'}
193195
else:
194196
size, _text, _type, name = len(text), text, 'text', '文本分享'
195-
info = database.Codes(
197+
info = Codes(
196198
code=code,
197199
text=_text,
198200
size=size,
@@ -202,8 +204,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
202204
exp_time=exp_time,
203205
key=key
204206
)
205-
db.add(info)
206-
db.commit()
207+
s.add(info)
208+
await s.commit()
207209
return {
208210
'code': 200,
209211
'msg': '分享成功,请点击文件箱查看取件码',

0 commit comments

Comments
 (0)