Skip to content

Commit c699e9c

Browse files
committed
feat: add tos_config & modify the code
1 parent 58d5130 commit c699e9c

File tree

3 files changed

+179
-194
lines changed

3 files changed

+179
-194
lines changed

veadk/database/tos/tos_client.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from veadk.config import getenv
17+
from veadk.utils.logger import get_logger
18+
import tos
19+
import asyncio
20+
from typing import Literal, Union
21+
from pydantic import BaseModel, Field
22+
from typing import Optional, Any
23+
24+
logger = get_logger(__name__)
25+
26+
27+
class TOSConfig(BaseModel):
28+
region: str = Field(
29+
default_factory=lambda: getenv("DATABASE_TOS_REGION"),
30+
description="TOS region",
31+
)
32+
ak: str = Field(
33+
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"),
34+
description="Volcengine access key",
35+
)
36+
sk: str = Field(
37+
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"),
38+
description="Volcengine secret key",
39+
)
40+
bucket_name: str = Field(
41+
default_factory=lambda: getenv("DATABASE_TOS_BUCKET"),
42+
description="TOS bucket name",
43+
)
44+
45+
46+
class TOSClient(BaseModel):
47+
config: TOSConfig = Field(default_factory=TOSConfig)
48+
client: Optional[Any] = Field(default=None, description="TOS client instance")
49+
50+
def __init__(self, **data: Any):
51+
super().__init__(**data)
52+
self.client = self._init()
53+
54+
def _init(self):
55+
"""initialize TOS client"""
56+
try:
57+
return tos.TosClientV2(
58+
self.config.ak,
59+
self.config.sk,
60+
endpoint=f"tos-{self.config.region}.volces.com",
61+
region=self.config.region,
62+
)
63+
except Exception as e:
64+
logger.error(f"Client initialization failed:{e}")
65+
return None
66+
67+
def create_bucket(self) -> bool:
68+
"""If the bucket does not exist, create it"""
69+
try:
70+
self.client.head_bucket(self.config.bucket_name)
71+
logger.info(f"Bucket {self.config.bucket_name} already exists")
72+
return True
73+
except tos.exceptions.TosServerError as e:
74+
if e.status_code == 404:
75+
self.client.create_bucket(
76+
bucket=self.config.bucket_name,
77+
storage_class=tos.StorageClassType.Storage_Class_Standard,
78+
acl=tos.ACLType.ACL_Private,
79+
)
80+
logger.info(f"Bucket {self.config.bucket_name} created successfully")
81+
return True
82+
except Exception as e:
83+
logger.error(f"Bucket creation failed: {str(e)}")
84+
return False
85+
86+
def upload(
87+
self,
88+
object_key: str,
89+
data: Union[str, bytes],
90+
data_type: Literal["file", "bytes"],
91+
):
92+
if data_type not in ("file", "bytes"):
93+
error_msg = f"Upload failed: data_type error. Only 'file' and 'bytes' are supported, got {data_type}"
94+
logger.error(error_msg)
95+
raise ValueError(error_msg)
96+
if data_type == "file":
97+
return asyncio.to_thread(self._do_upload_file, object_key, data)
98+
elif data_type == "bytes":
99+
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
100+
101+
def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
102+
try:
103+
if not self.client:
104+
return False
105+
if not self.create_bucket():
106+
return False
107+
108+
self.client.put_object(
109+
bucket=self.config.bucket_name, key=object_key, content=bytes
110+
)
111+
return True
112+
except Exception as e:
113+
logger.error(f"Upload failed: {e}")
114+
return False
115+
116+
def _do_upload_file(self, object_key: str, file_path: str) -> bool:
117+
client = self._init_tos_client()
118+
try:
119+
if not client:
120+
return False
121+
if not self.create_bucket(client, self.config.bucket_name):
122+
return False
123+
124+
client.put_object_from_file(
125+
bucket=self.config.bucket_name, key=object_key, file_path=file_path
126+
)
127+
return True
128+
except Exception as e:
129+
logger.error(f"Upload failed: {e}")
130+
return False
131+
132+
def download(self, object_key: str, save_path: str) -> bool:
133+
"""download image from TOS"""
134+
try:
135+
object_stream = self.client.get_object(self.config.bucket_name, object_key)
136+
137+
save_dir = os.path.dirname(save_path)
138+
if save_dir and not os.path.exists(save_dir):
139+
os.makedirs(save_dir, exist_ok=True)
140+
141+
with open(save_path, "wb") as f:
142+
for chunk in object_stream:
143+
f.write(chunk)
144+
145+
logger.debug(f"Image download success, saved to: {save_path}")
146+
return True
147+
148+
except Exception as e:
149+
logger.error(f"Image download failed: {str(e)}")
150+
151+
return False
152+
153+
def close_client(self):
154+
if self.client:
155+
self.client.close()

veadk/database/tos/tos_handler.py

Lines changed: 0 additions & 187 deletions
This file was deleted.

0 commit comments

Comments
 (0)