Skip to content

Commit 9ecd2c2

Browse files
authored
Merge pull request #11 from veoco/master
调整文件的大小计算和写入方法
2 parents e6fa174 + 66d9e43 commit 9ecd2c2

File tree

2 files changed

+54
-34
lines changed

2 files changed

+54
-34
lines changed

main.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
from pathlib import Path
66

7-
from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException
7+
from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException, BackgroundTasks
88
from starlette.responses import HTMLResponse, FileResponse
99
from starlette.staticfiles import StaticFiles
1010

@@ -51,9 +51,13 @@ async def delete_expire_files():
5151
async with AsyncSession(engine, expire_on_commit=False) as s:
5252
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
5353
exps = (await s.execute(query)).scalars().all()
54-
files = [{'type': old.type, 'text': old.text} for old in exps]
54+
files = []
55+
exps_ids = []
56+
for exp in exps:
57+
if exp.type != "text":
58+
files.append(exp.text)
59+
exps_ids.append(exp.id)
5560
await storage.delete_files(files)
56-
exps_ids = [exp.id for exp in exps]
5761
query = delete(Codes).where(Codes.id.in_(exps_ids))
5862
await s.execute(query)
5963
await s.commit()
@@ -83,9 +87,11 @@ async def admin_post(s: AsyncSession = Depends(get_session)):
8387
async def admin_delete(code: str, s: AsyncSession = Depends(get_session)):
8488
query = select(Codes).where(Codes.code == code)
8589
file = (await s.execute(query)).scalars().first()
86-
await storage.delete_file({'type': file.type, 'text': file.text})
87-
await s.delete(file)
88-
await s.commit()
90+
if file:
91+
if file.type != 'text':
92+
await storage.delete_file(file.text)
93+
await s.delete(file)
94+
await s.commit()
8995
return {'detail': '删除成功'}
9096

9197

@@ -115,7 +121,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
115121
error_count = settings.ERROR_COUNT - ip_limit.add_ip(ip)
116122
raise HTTPException(status_code=404, detail=f"取件码错误,错误{error_count}次将被禁止10分钟")
117123
if info.exp_time < datetime.datetime.now() or info.count == 0:
118-
await storage.delete_file({'type': info.type, 'text': info.text})
124+
if info.type != "text":
125+
await storage.delete_file(info.text)
119126
await s.delete(info)
120127
await s.commit()
121128
raise HTTPException(status_code=404, detail="取件码已过期,请联系寄件人")
@@ -130,8 +137,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
130137

131138

132139
@app.post('/share')
133-
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
134-
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
140+
async def share(background_tasks: BackgroundTasks, text: str = Form(default=None), style: str = Form(default='2'),
141+
value: int = Form(default=1), file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
135142
code = await get_code(s)
136143
if style == '2':
137144
if value > 7:
@@ -148,11 +155,11 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
148155
exp_count = -1
149156
key = uuid.uuid4().hex
150157
if file:
151-
file_bytes = await file.read()
152-
size = len(file_bytes)
158+
size = await storage.get_size(file)
153159
if size > settings.FILE_SIZE_LIMIT:
154160
raise HTTPException(status_code=400, detail="文件过大")
155-
_text, _type, name = await storage.save_file(file, file_bytes, key), file.content_type, file.filename
161+
_text, _type, name = await storage.get_text(file, key), file.content_type, file.filename
162+
background_tasks.add_task(storage.save_file, file, _text)
156163
else:
157164
size, _text, _type, name = len(text), text, 'text', '文本分享'
158165
info = Codes(

storage.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
22
import asyncio
3-
import datetime
3+
from datetime import datetime
44
from pathlib import Path
5+
from typing import BinaryIO
6+
7+
from fastapi import UploadFile
58

69
import settings
710

@@ -11,35 +14,45 @@ class FileSystemStorage:
1114
STATIC_URL = settings.STATIC_URL
1215
NAME = "filesystem"
1316

14-
async def get_filepath(self, path):
15-
return self.DATA_ROOT / path.lstrip(self.STATIC_URL + '/')
16-
17-
def _save(self, filepath, file_bytes):
18-
with open(filepath, 'wb') as f:
19-
f.write(file_bytes)
17+
async def get_filepath(self, text: str):
18+
return self.DATA_ROOT / text.lstrip(self.STATIC_URL + '/')
2019

21-
async def save_file(self, file, file_bytes, key):
22-
now = datetime.datetime.now()
23-
path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/"
20+
async def get_text(self, file: UploadFile, key: str):
2421
ext = file.filename.split('.')[-1]
25-
name = f'{key}.{ext}'
22+
now = datetime.now()
23+
path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/"
2624
if not path.exists():
2725
path.mkdir(parents=True)
28-
filepath = path / name
29-
await asyncio.to_thread(self._save, filepath, file_bytes)
26+
filepath = path / f'{key}.{ext}'
3027
text = f"{self.STATIC_URL}/{filepath.relative_to(self.DATA_ROOT)}"
3128
return text
3229

33-
async def delete_file(self, file):
34-
# 是文件就删除
35-
if file['type'] != 'text':
36-
filepath = self.DATA_ROOT / file['text'].lstrip(self.STATIC_URL + '/')
37-
await asyncio.to_thread(os.remove, filepath)
30+
async def get_size(self, file: UploadFile):
31+
f = file.file
32+
f.seek(0, os.SEEK_END)
33+
size = f.tell()
34+
f.seek(0, os.SEEK_SET)
35+
return size
3836

39-
async def delete_files(self, files):
40-
for file in files:
41-
if file['type'] != 'text':
42-
await self.delete_file(file)
37+
def _save(self, filepath, file: BinaryIO):
38+
with open(filepath, 'wb') as f:
39+
chunk_size = 256 * 1024
40+
chunk = file.read(chunk_size)
41+
while chunk:
42+
f.write(chunk)
43+
chunk = file.read(chunk_size)
44+
45+
async def save_file(self, file: UploadFile, text: str):
46+
filepath = await self.get_filepath(text)
47+
await asyncio.to_thread(self._save, filepath, file.file)
48+
49+
async def delete_file(self, text: str):
50+
filepath = await self.get_filepath(text)
51+
await asyncio.to_thread(os.remove, filepath)
52+
53+
async def delete_files(self, texts):
54+
tasks = [self.delete_file(text) for text in texts]
55+
await asyncio.gather(*tasks)
4356

4457

4558
STORAGE_ENGINE = {

0 commit comments

Comments
 (0)