Skip to content

Commit d25a921

Browse files
authored
Merge pull request #1 from veoco/master
使用 aiosqlite 驱动异步化数据库操作
2 parents ea64751 + b7315bb commit d25a921

File tree

3 files changed

+58
-50
lines changed

3 files changed

+58
-50
lines changed

database.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
import datetime
22

3-
from sqlalchemy import create_engine, DateTime
4-
from sqlalchemy.orm import sessionmaker
3+
from sqlalchemy import Boolean, Column, Integer, String, DateTime
54
from sqlalchemy.ext.declarative import declarative_base
6-
from sqlalchemy import Boolean, Column, Integer, String
5+
from sqlalchemy.ext.asyncio import create_async_engine
6+
from sqlalchemy.ext.asyncio.session import AsyncSession
7+
8+
9+
engine = create_async_engine("sqlite+aiosqlite:///database.db")
710

8-
engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
911
Base = declarative_base()
10-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
12+
13+
14+
async def get_session():
15+
async with AsyncSession(engine, expire_on_commit=False) as s:
16+
yield s
1117

1218

1319
class Codes(Base):
14-
__tablename__ = 'codes'
20+
__tablename__ = "codes"
1521
id = Column(Integer, primary_key=True, index=True)
1622
code = Column(String(10), unique=True, index=True)
1723
key = Column(String(30), unique=True, index=True)

main.py

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,24 @@
22
import os
33
import uuid
44
import threading
5+
import random
6+
57
from fastapi import FastAPI, Depends, UploadFile, Form, File
6-
from sqlalchemy import or_
7-
from sqlalchemy.orm import Session
88
from starlette.requests import Request
99
from starlette.responses import HTMLResponse, FileResponse
1010
import random
11-
1211
from starlette.staticfiles import StaticFiles
1312

14-
import database
15-
from database import engine, SessionLocal, Base
13+
from sqlalchemy import or_, select, update, delete, create_engine
14+
from sqlalchemy import select, update, delete
15+
from sqlalchemy.ext.asyncio.session import AsyncSession
16+
17+
from database import engine, get_session, Base, Codes
18+
1619

20+
engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
1721
Base.metadata.create_all(bind=engine)
22+
1823
app = FastAPI()
1924
if not os.path.exists('./static'):
2025
os.makedirs('./static')
@@ -58,17 +63,9 @@ def delete_file(files):
5863
os.remove('.' + file['text'])
5964

6065

61-
def get_db():
62-
db = SessionLocal()
63-
try:
64-
yield db
65-
finally:
66-
db.close()
67-
68-
69-
def get_code(db: Session = Depends(get_db)):
66+
async def get_code(s: AsyncSession):
7067
code = random.randint(10000, 99999)
71-
while db.query(database.Codes).filter(database.Codes.code == code).first():
68+
while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar():
7269
code = random.randint(10000, 99999)
7370
return str(code)
7471

@@ -94,21 +91,23 @@ async def admin():
9491

9592

9693
@app.post(f'/{admin_address}')
97-
async def admin_post(request: Request, db: Session = Depends(get_db)):
94+
async def admin_post(request: Request, s: AsyncSession = Depends(get_session)):
9895
if request.headers.get('pwd') == admin_password:
99-
codes = db.query(database.Codes).all()
96+
query = select(Codes)
97+
codes = (await s.execute(query)).scalars().all()
10098
return {'code': 200, 'msg': '查询成功', 'data': codes}
10199
else:
102100
return {'code': 404, 'msg': '密码错误'}
103101

104102

105103
@app.delete(f'/{admin_address}')
106-
async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)):
104+
async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)):
107105
if request.headers.get('pwd') == admin_password:
108-
file = db.query(database.Codes).filter(database.Codes.code == code).first()
106+
query = select(Codes).where(Codes.code == code)
107+
file = (await s.execute(query)).scalars().first()
109108
threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start()
110-
db.delete(file)
111-
db.commit()
109+
await s.delete(file)
110+
await s.commit()
112111
return {'code': 200, 'msg': '删除成功'}
113112
else:
114113
return {'code': 404, 'msg': '密码错误'}
@@ -150,22 +149,26 @@ async def get_file(code: str, db: Session = Depends(get_db)):
150149

151150

152151
@app.post('/')
153-
async def index(request: Request, code: str, db: Session = Depends(get_db)):
152+
async def index(request: Request, code: str, s: AsyncSession = Depends(get_session)):
154153
ip = request.client.host
155154
if not check_ip(ip):
156155
return {'code': 404, 'msg': '错误次数过多,请稍后再试'}
157-
info = db.query(database.Codes).filter(database.Codes.code == code).first()
156+
query = select(Codes).where(Codes.code == code)
157+
info = (await s.execute(query)).scalars().first()
158158
if not info:
159159
return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'}
160160
if info.exp_time < datetime.datetime.now() or info.count == 0:
161161
threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start()
162-
db.delete(info)
163-
db.commit()
162+
await s.delete(info)
163+
await s.commit()
164164
return {'code': 404, 'msg': '取件码已过期,请联系寄件人'}
165-
info.count -= 1
166-
db.commit()
165+
count = info.count - 1
166+
query = update(Codes).where(Codes.id == info.id).values(count=count)
167+
await s.execute(query)
168+
await s.commit()
167169
if info.type != 'text':
168170
info.text = f'/select?code={code}'
171+
169172
return {
170173
'code': 200,
171174
'msg': '取件成功,请点击"取"查看',
@@ -175,17 +178,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):
175178

176179
@app.post('/share')
177180
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
178-
file: UploadFile = File(default=None), db: Session = Depends(get_db)):
179-
exps = db.query(database.Codes).filter(
180-
or_(
181-
database.Codes.exp_time < datetime.datetime.now(),
182-
database.Codes.count == 0
183-
)
184-
)
185-
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps.all()],)).start()
186-
exps.delete()
187-
db.commit()
188-
code = get_code(db)
181+
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
182+
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
183+
exps = (await s.execute(query)).scalars().all()
184+
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()
185+
186+
exps_ids = [exp.id for exp in exps]
187+
query = delete(Codes).where(Codes.id.in_(exps_ids))
188+
await s.execute(query)
189+
await s.commit()
190+
191+
code = await get_code(s)
189192
if style == '2':
190193
if value > 7:
191194
return {'code': 404, 'msg': '最大有效天数为7天'}
@@ -206,7 +209,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
206209
return {'code': 404, 'msg': '文件过大'}
207210
else:
208211
size, _text, _type, name = len(text), text, 'text', '文本分享'
209-
info = database.Codes(
212+
info = Codes(
210213
code=code,
211214
text=_text,
212215
size=size,
@@ -216,8 +219,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
216219
exp_time=exp_time,
217220
key=key
218221
)
219-
db.add(info)
220-
db.commit()
222+
s.add(info)
223+
await s.commit()
221224
return {
222225
'code': 200,
223226
'msg': '分享成功,请点击文件箱查看取件码',

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
fastapi==0.88.0
2-
python-multipart==0.0.5
1+
fastapi[all]==0.88.0
2+
aiosqlite==0.17.0
33
SQLAlchemy==1.4.44
4-
uvicorn==0.20.0

0 commit comments

Comments
 (0)