Skip to content

Commit 4c9dd7e

Browse files
committed
feat: Normalize the code
1 parent 7275d9e commit 4c9dd7e

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed
Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tos
55
from datetime import datetime
66
import asyncio
7-
from typing import Literal
7+
from typing import Literal, Union
88
from urllib.parse import urlparse
99

1010
logger = get_logger(__name__)
@@ -13,7 +13,7 @@
1313
class TOSHandler:
1414
def __init__(self):
1515
"""Initialize TOS configuration information"""
16-
self.region = getenv("VOLCENGINE_REGION")
16+
self.region = getenv("DATABASE_TOS_REGION")
1717
self.ak = getenv("VOLCENGINE_ACCESS_KEY")
1818
self.sk = getenv("VOLCENGINE_SECRET_KEY")
1919
self.bucket_name = getenv("DATABASE_TOS_BUCKET")
@@ -63,24 +63,26 @@ def get_suffix(self, data_path: str) -> str:
6363
return f".{candidate.lower()}"
6464
return f".{parts[-1].lower()}"
6565

66-
def gen_url(self, user_id, app_name, session_id, data_path):
66+
def gen_url(
67+
self, user_id: str, app_name: str, session_id: str, data_path: str
68+
) -> str:
6769
"""generate TOS URL"""
68-
suffix = self.get_suffix(data_path)
69-
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
70-
url = (
70+
suffix: str = self.get_suffix(data_path)
71+
timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
72+
url: str = (
7173
f"{self.bucket_name}/{app_name}/{user_id}-{session_id}-{timestamp}{suffix}"
7274
)
7375
return url
7476

75-
def parse_url(self, url):
77+
def parse_url(self, url: str) -> tuple[str, str]:
7678
"""Parse the URL to obtain bucket_name and object_key"""
7779
"""bucket_name/object_key"""
7880
parts = url.split("/", 1)
7981
if len(parts) < 2:
8082
raise ValueError("URL format error, it should be: bucket_name/object_key")
8183
return parts
8284

83-
def create_bucket(self, client, bucket_name):
85+
def create_bucket(self, client: tos.TosClientV2, bucket_name: str) -> bool:
8486
"""If the bucket does not exist, create it"""
8587
try:
8688
client.head_bucket(self.bucket_name)
@@ -99,7 +101,9 @@ def create_bucket(self, client, bucket_name):
99101
logger.error(f"Bucket creation failed: {str(e)}")
100102
return False
101103

102-
def upload_to_tos(self, url: str, data, data_type: Literal["file", "bytes"]):
104+
def upload_to_tos(
105+
self, url: str, data: Union[str, bytes], data_type: Literal["file", "bytes"]
106+
):
103107
if data_type not in ("file", "bytes"):
104108
error_msg = f"Upload failed: data_type error. Only 'file' and 'bytes' are supported, got {data_type}"
105109
logger.error(error_msg)
@@ -109,7 +113,7 @@ def upload_to_tos(self, url: str, data, data_type: Literal["file", "bytes"]):
109113
elif data_type == "bytes":
110114
return asyncio.to_thread(self._do_upload_bytes, url, data)
111115

112-
def _do_upload_bytes(self, url, bytes):
116+
def _do_upload_bytes(self, url: str, bytes: bytes) -> bool:
113117
bucket_name, object_key = self.parse_url(url)
114118
client = self._init_tos_client()
115119
try:
@@ -127,7 +131,7 @@ def _do_upload_bytes(self, url, bytes):
127131
if client:
128132
client.close()
129133

130-
def _do_upload_file(self, url, file_path):
134+
def _do_upload_file(self, url: str, file_path: str) -> bool:
131135
bucket_name, object_key = self.parse_url(url)
132136
client = self._init_tos_client()
133137
try:
@@ -147,7 +151,7 @@ def _do_upload_file(self, url, file_path):
147151
if client:
148152
client.close()
149153

150-
def download_from_tos(self, url, save_path):
154+
def download_from_tos(self, url: str, save_path: str) -> bool:
151155
"""download image from TOS"""
152156
try:
153157
bucket_name, object_key = self.parse_url(url)

veadk/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from veadk.types import MediaMessage
3333
from veadk.utils.logger import get_logger
3434
from veadk.utils.misc import read_png_to_bytes
35-
from veadk.database.tos.toshandler import TOSHandler
35+
from veadk.database.tos.tos_handler import TOSHandler
3636

3737
logger = get_logger(__name__)
3838

0 commit comments

Comments
 (0)