|
3 | 3 | import uuid |
4 | 4 | import threading |
5 | 5 | import random |
| 6 | +import asyncio |
| 7 | + |
6 | 8 | from fastapi import FastAPI, Depends, UploadFile, Form, File |
7 | 9 | from starlette.requests import Request |
8 | 10 | from starlette.responses import HTMLResponse, FileResponse |
|
11 | 13 | from sqlalchemy import or_, select, update, delete |
12 | 14 | from sqlalchemy.ext.asyncio.session import AsyncSession |
13 | 15 |
|
14 | | -from database import get_session, Codes, init_models |
| 16 | +from database import get_session, Codes, init_models, engine |
15 | 17 |
|
16 | 18 | app = FastAPI() |
17 | 19 | if not os.path.exists('./static'): |
|
23 | 25 | async def startup(): |
24 | 26 | await init_models() |
25 | 27 |
|
| 28 | + asyncio.create_task(delete_expire_files()) |
| 29 | + |
26 | 30 |
|
27 | 31 | ############################################ |
28 | 32 | # 需要修改的参数 |
@@ -62,6 +66,21 @@ def delete_file(files): |
62 | 66 | os.remove('.' + file['text']) |
63 | 67 |
|
64 | 68 |
|
| 69 | +async def delete_expire_files(): |
| 70 | + while True: |
| 71 | + async with AsyncSession(engine, expire_on_commit=False) as s: |
| 72 | + query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0)) |
| 73 | + exps = (await s.execute(query)).scalars().all() |
| 74 | + await asyncio.to_thread(delete_file, [{'type': old.type, 'text': old.text} for old in exps]) |
| 75 | + |
| 76 | + exps_ids = [exp.id for exp in exps] |
| 77 | + query = delete(Codes).where(Codes.id.in_(exps_ids)) |
| 78 | + await s.execute(query) |
| 79 | + await s.commit() |
| 80 | + |
| 81 | + await asyncio.sleep(random.randint(60, 300)) |
| 82 | + |
| 83 | + |
65 | 84 | async def get_code(s: AsyncSession): |
66 | 85 | code = random.randint(10000, 99999) |
67 | 86 | while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar(): |
@@ -104,7 +123,7 @@ async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(ge |
104 | 123 | if request.headers.get('pwd') == admin_password: |
105 | 124 | query = select(Codes).where(Codes.code == code) |
106 | 125 | file = (await s.execute(query)).scalars().first() |
107 | | - threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start() |
| 126 | + await asyncio.to_thread(delete_file, [{'type': file.type, 'text': file.text}]) |
108 | 127 | await s.delete(file) |
109 | 128 | await s.commit() |
110 | 129 | return {'code': 200, 'msg': '删除成功'} |
@@ -179,15 +198,6 @@ async def index(request: Request, code: str, s: AsyncSession = Depends(get_sessi |
179 | 198 | @app.post('/share') |
180 | 199 | async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1), |
181 | 200 | file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)): |
182 | | - query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0)) |
183 | | - exps = (await s.execute(query)).scalars().all() |
184 | | - threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start() |
185 | | - |
186 | | - exps_ids = [exp.id for exp in exps] |
187 | | - query = delete(Codes).where(Codes.id.in_(exps_ids)) |
188 | | - await s.execute(query) |
189 | | - await s.commit() |
190 | | - |
191 | 201 | code = await get_code(s) |
192 | 202 | if style == '2': |
193 | 203 | if value > 7: |
|
0 commit comments