|
3 | 3 | import uuid |
4 | 4 | import threading |
5 | 5 | import random |
6 | | - |
7 | 6 | from fastapi import FastAPI, Depends, UploadFile, Form, File |
8 | 7 | from starlette.requests import Request |
9 | 8 | from starlette.responses import HTMLResponse, FileResponse |
10 | | -import random |
11 | 9 | from starlette.staticfiles import StaticFiles |
12 | 10 |
|
13 | | -from sqlalchemy import or_, select, update, delete, create_engine |
14 | | -from sqlalchemy import select, update, delete |
| 11 | +from sqlalchemy import or_, select, update, delete |
15 | 12 | from sqlalchemy.ext.asyncio.session import AsyncSession |
16 | 13 |
|
17 | | -from database import engine, get_session, Base, Codes |
18 | | - |
19 | | - |
20 | | -engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False}) |
21 | | -Base.metadata.create_all(bind=engine) |
| 14 | +from database import get_session, Codes |
22 | 15 |
|
23 | 16 | app = FastAPI() |
24 | 17 | if not os.path.exists('./static'): |
@@ -137,13 +130,14 @@ def ip_error(ip): |
137 | 130 |
|
138 | 131 |
|
139 | 132 | @app.get('/select') |
140 | | -async def get_file(code: str, db: Session = Depends(get_db)): |
141 | | - file = db.query(database.Codes).filter(database.Codes.code == code).first() |
142 | | - if file: |
143 | | - if file.type == 'text': |
144 | | - return {'code': code, 'msg': '查询成功', 'data': file.text} |
| 133 | +async def get_file(code: str, s: AsyncSession = Depends(get_session)): |
| 134 | + query = select(Codes).where(Codes.code == code) |
| 135 | + info = (await s.execute(query)).scalars().first() |
| 136 | + if info: |
| 137 | + if info.type == 'text': |
| 138 | + return {'code': code, 'msg': '查询成功', 'data': info.text} |
145 | 139 | else: |
146 | | - return FileResponse('.' + file.text, filename=file.name) |
| 140 | + return FileResponse('.' + info.text, filename=info.name) |
147 | 141 | else: |
148 | 142 | return {'code': 404, 'msg': '口令不存在'} |
149 | 143 |
|
@@ -182,7 +176,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'), |
182 | 176 | query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0)) |
183 | 177 | exps = (await s.execute(query)).scalars().all() |
184 | 178 | threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start() |
185 | | - |
| 179 | + |
186 | 180 | exps_ids = [exp.id for exp in exps] |
187 | 181 | query = delete(Codes).where(Codes.id.in_(exps_ids)) |
188 | 182 | await s.execute(query) |
|
0 commit comments