1717from veadk .utils .logger import get_logger
1818import tos
1919import asyncio
20- from typing import Literal , Union
20+ from typing import Union
2121from pydantic import BaseModel , Field
22- from typing import Optional , Any
22+ from typing import Any
2323
2424logger = get_logger (__name__ )
2525
@@ -45,34 +45,29 @@ class TOSConfig(BaseModel):
4545
4646class TOSClient (BaseModel ):
4747 config : TOSConfig = Field (default_factory = TOSConfig )
48- client : Optional [Any ] = Field (default = None , description = "TOS client instance" )
4948
50- def __init__ (self , ** data : Any ):
51- super ().__init__ (** data )
52- self .client = self ._init ()
53-
54- def _init (self ):
55- """initialize TOS client"""
49+ def model_post_init (self , __context : Any ) -> None :
5650 try :
57- return tos .TosClientV2 (
51+ self . _client = tos .TosClientV2 (
5852 self .config .ak ,
5953 self .config .sk ,
6054 endpoint = f"tos-{ self .config .region } .volces.com" ,
6155 region = self .config .region ,
6256 )
57+ logger .info ("Connected to TOS successfully." )
6358 except Exception as e :
6459 logger .error (f"Client initialization failed:{ e } " )
6560 return None
6661
6762 def create_bucket (self ) -> bool :
6863 """If the bucket does not exist, create it"""
6964 try :
70- self .client .head_bucket (self .config .bucket_name )
65+ self ._client .head_bucket (self .config .bucket_name )
7166 logger .info (f"Bucket { self .config .bucket_name } already exists" )
7267 return True
7368 except tos .exceptions .TosServerError as e :
7469 if e .status_code == 404 :
75- self .client .create_bucket (
70+ self ._client .create_bucket (
7671 bucket = self .config .bucket_name ,
7772 storage_class = tos .StorageClassType .Storage_Class_Standard ,
7873 acl = tos .ACLType .ACL_Private ,
@@ -87,10 +82,13 @@ def upload(
8782 self ,
8883 object_key : str ,
8984 data : Union [str , bytes ],
90- data_type : Literal ["file" , "bytes" ],
9185 ):
92- if data_type not in ("file" , "bytes" ):
93- error_msg = f"Upload failed: data_type error. Only 'file' and 'bytes' are supported, got { data_type } "
86+ if isinstance (data , str ):
87+ data_type = "file"
88+ elif isinstance (data , bytes ):
89+ data_type = "bytes"
90+ else :
91+ error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got { type (data )} "
9492 logger .error (error_msg )
9593 raise ValueError (error_msg )
9694 if data_type == "file" :
@@ -100,30 +98,30 @@ def upload(
10098
10199 def _do_upload_bytes (self , object_key : str , bytes : bytes ) -> bool :
102100 try :
103- if not self .client :
101+ if not self ._client :
104102 return False
105103 if not self .create_bucket ():
106104 return False
107-
108- self .client .put_object (
105+ self ._client .put_object (
109106 bucket = self .config .bucket_name , key = object_key , content = bytes
110107 )
108+ logger .debug (f"Upload success, object_key: { object_key } " )
111109 return True
112110 except Exception as e :
113111 logger .error (f"Upload failed: { e } " )
114112 return False
115113
116114 def _do_upload_file (self , object_key : str , file_path : str ) -> bool :
117- client = self ._init_tos_client ()
118115 try :
119- if not client :
116+ if not self . _client :
120117 return False
121- if not self .create_bucket (client , self .config .bucket_name ):
118+ if not self .create_bucket (self . _client , self .config .bucket_name ):
122119 return False
123120
124- client .put_object_from_file (
121+ self . _client .put_object_from_file (
125122 bucket = self .config .bucket_name , key = object_key , file_path = file_path
126123 )
124+ logger .debug (f"Upload success, object_key: { object_key } " )
127125 return True
128126 except Exception as e :
129127 logger .error (f"Upload failed: { e } " )
@@ -132,7 +130,7 @@ def _do_upload_file(self, object_key: str, file_path: str) -> bool:
132130 def download (self , object_key : str , save_path : str ) -> bool :
133131 """download image from TOS"""
134132 try :
135- object_stream = self .client .get_object (self .config .bucket_name , object_key )
133+ object_stream = self ._client .get_object (self .config .bucket_name , object_key )
136134
137135 save_dir = os .path .dirname (save_path )
138136 if save_dir and not os .path .exists (save_dir ):
@@ -150,6 +148,6 @@ def download(self, object_key: str, save_path: str) -> bool:
150148
151149 return False
152150
153- def close_client (self ):
154- if self .client :
155- self .client .close ()
151+ def close (self ):
152+ if self ._client :
153+ self ._client .close ()
0 commit comments