Skip to content

Commit 2c44659

Browse files
committed
feat: modify the code
1 parent c699e9c commit 2c44659

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

veadk/database/tos/tos_client.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from veadk.utils.logger import get_logger
1818
import tos
1919
import asyncio
20-
from typing import Literal, Union
20+
from typing import Union
2121
from pydantic import BaseModel, Field
22-
from typing import Optional, Any
22+
from typing import Any
2323

2424
logger = get_logger(__name__)
2525

@@ -45,34 +45,29 @@ class TOSConfig(BaseModel):
4545

4646
class TOSClient(BaseModel):
4747
config: TOSConfig = Field(default_factory=TOSConfig)
48-
client: Optional[Any] = Field(default=None, description="TOS client instance")
4948

50-
def __init__(self, **data: Any):
51-
super().__init__(**data)
52-
self.client = self._init()
53-
54-
def _init(self):
55-
"""initialize TOS client"""
49+
def model_post_init(self, __context: Any) -> None:
5650
try:
57-
return tos.TosClientV2(
51+
self._client = tos.TosClientV2(
5852
self.config.ak,
5953
self.config.sk,
6054
endpoint=f"tos-{self.config.region}.volces.com",
6155
region=self.config.region,
6256
)
57+
logger.info("Connected to TOS successfully.")
6358
except Exception as e:
6459
logger.error(f"Client initialization failed:{e}")
6560
return None
6661

6762
def create_bucket(self) -> bool:
6863
"""If the bucket does not exist, create it"""
6964
try:
70-
self.client.head_bucket(self.config.bucket_name)
65+
self._client.head_bucket(self.config.bucket_name)
7166
logger.info(f"Bucket {self.config.bucket_name} already exists")
7267
return True
7368
except tos.exceptions.TosServerError as e:
7469
if e.status_code == 404:
75-
self.client.create_bucket(
70+
self._client.create_bucket(
7671
bucket=self.config.bucket_name,
7772
storage_class=tos.StorageClassType.Storage_Class_Standard,
7873
acl=tos.ACLType.ACL_Private,
@@ -87,10 +82,13 @@ def upload(
8782
self,
8883
object_key: str,
8984
data: Union[str, bytes],
90-
data_type: Literal["file", "bytes"],
9185
):
92-
if data_type not in ("file", "bytes"):
93-
error_msg = f"Upload failed: data_type error. Only 'file' and 'bytes' are supported, got {data_type}"
86+
if isinstance(data, str):
87+
data_type = "file"
88+
elif isinstance(data, bytes):
89+
data_type = "bytes"
90+
else:
91+
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
9492
logger.error(error_msg)
9593
raise ValueError(error_msg)
9694
if data_type == "file":
@@ -100,30 +98,30 @@ def upload(
10098

10199
def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
102100
try:
103-
if not self.client:
101+
if not self._client:
104102
return False
105103
if not self.create_bucket():
106104
return False
107-
108-
self.client.put_object(
105+
self._client.put_object(
109106
bucket=self.config.bucket_name, key=object_key, content=bytes
110107
)
108+
logger.debug(f"Upload success, object_key: {object_key}")
111109
return True
112110
except Exception as e:
113111
logger.error(f"Upload failed: {e}")
114112
return False
115113

116114
def _do_upload_file(self, object_key: str, file_path: str) -> bool:
117-
client = self._init_tos_client()
118115
try:
119-
if not client:
116+
if not self._client:
120117
return False
121-
if not self.create_bucket(client, self.config.bucket_name):
118+
if not self.create_bucket(self._client, self.config.bucket_name):
122119
return False
123120

124-
client.put_object_from_file(
121+
self._client.put_object_from_file(
125122
bucket=self.config.bucket_name, key=object_key, file_path=file_path
126123
)
124+
logger.debug(f"Upload success, object_key: {object_key}")
127125
return True
128126
except Exception as e:
129127
logger.error(f"Upload failed: {e}")
@@ -132,7 +130,7 @@ def _do_upload_file(self, object_key: str, file_path: str) -> bool:
132130
def download(self, object_key: str, save_path: str) -> bool:
133131
"""download image from TOS"""
134132
try:
135-
object_stream = self.client.get_object(self.config.bucket_name, object_key)
133+
object_stream = self._client.get_object(self.config.bucket_name, object_key)
136134

137135
save_dir = os.path.dirname(save_path)
138136
if save_dir and not os.path.exists(save_dir):
@@ -150,6 +148,6 @@ def download(self, object_key: str, save_path: str) -> bool:
150148

151149
return False
152150

153-
def close_client(self):
154-
if self.client:
155-
self.client.close()
151+
def close(self):
152+
if self._client:
153+
self._client.close()

veadk/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def _convert_messages(self, messages, session_id) -> list:
123123
object_key = self._build_object_key(
124124
self.user_id, self.app_name, session_id, messages.media
125125
)
126-
asyncio.create_task(tos_client.upload(object_key, data, "bytes"))
126+
asyncio.create_task(tos_client.upload(object_key, data))
127+
tos_client.close()
127128

128129
messages = [
129130
types.Content(

0 commit comments

Comments
 (0)