Skip to content

Commit c23b334

Browse files
committed
feat: add License Header
1 parent 4c9dd7e commit c23b334

File tree

3 files changed

+30
-22
lines changed

3 files changed

+30
-22
lines changed

tests/test_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _test_convert_messages(runner):
2828
role="user",
2929
)
3030
]
31-
actual_message = runner._convert_messages(message)
31+
actual_message = runner._convert_messages(message, session_id="test_session_id")
3232
assert actual_message == expected_message
3333

3434
message = ["test message 1", "test message 2"]
@@ -42,7 +42,7 @@ def _test_convert_messages(runner):
4242
role="user",
4343
),
4444
]
45-
actual_message = runner._convert_messages(message)
45+
actual_message = runner._convert_messages(message, session_id="test_session_id")
4646
assert actual_message == expected_message
4747

4848

veadk/database/tos/tos_handler.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
import os
216
from veadk.config import getenv
317
from veadk.utils.logger import get_logger
@@ -17,6 +31,7 @@ def __init__(self):
1731
self.ak = getenv("VOLCENGINE_ACCESS_KEY")
1832
self.sk = getenv("VOLCENGINE_SECRET_KEY")
1933
self.bucket_name = getenv("DATABASE_TOS_BUCKET")
34+
self.client = self._init_tos_client()
2035

2136
def _init_tos_client(self):
2237
"""initialize TOS client"""
@@ -82,15 +97,15 @@ def parse_url(self, url: str) -> tuple[str, str]:
8297
raise ValueError("URL format error, it should be: bucket_name/object_key")
8398
return parts
8499

85-
def create_bucket(self, client: tos.TosClientV2, bucket_name: str) -> bool:
100+
def create_bucket(self, bucket_name: str) -> bool:
86101
"""If the bucket does not exist, create it"""
87102
try:
88-
client.head_bucket(self.bucket_name)
103+
self.client.head_bucket(self.bucket_name)
89104
logger.debug(f"Bucket {bucket_name} already exists")
90105
return True
91106
except tos.exceptions.TosServerError as e:
92107
if e.status_code == 404:
93-
client.create_bucket(
108+
self.client.create_bucket(
94109
bucket=bucket_name,
95110
storage_class=tos.StorageClassType.Storage_Class_Standard,
96111
acl=tos.ACLType.ACL_Private,
@@ -115,21 +130,17 @@ def upload_to_tos(
115130

116131
def _do_upload_bytes(self, url: str, bytes: bytes) -> bool:
117132
bucket_name, object_key = self.parse_url(url)
118-
client = self._init_tos_client()
119133
try:
120-
if not client:
134+
if not self.client:
121135
return False
122-
if not self.create_bucket(client, bucket_name):
136+
if not self.create_bucket(bucket_name):
123137
return False
124138

125-
client.put_object(bucket=bucket_name, key=object_key, content=bytes)
139+
self.client.put_object(bucket=bucket_name, key=object_key, content=bytes)
126140
return True
127141
except Exception as e:
128142
logger.error(f"Upload failed: {e}")
129143
return False
130-
finally:
131-
if client:
132-
client.close()
133144

134145
def _do_upload_file(self, url: str, file_path: str) -> bool:
135146
bucket_name, object_key = self.parse_url(url)
@@ -147,19 +158,13 @@ def _do_upload_file(self, url: str, file_path: str) -> bool:
147158
except Exception as e:
148159
logger.error(f"Upload failed: {e}")
149160
return False
150-
finally:
151-
if client:
152-
client.close()
153161

154162
def download_from_tos(self, url: str, save_path: str) -> bool:
155163
"""download image from TOS"""
156164
try:
157165
bucket_name, object_key = self.parse_url(url)
158-
client = self._init_tos_client()
159-
if not client:
160-
return False
161166

162-
object_stream = client.get_object(bucket_name, object_key)
167+
object_stream = self.client.get_object(bucket_name, object_key)
163168

164169
save_dir = os.path.dirname(save_path)
165170
if save_dir and not os.path.exists(save_dir):
@@ -170,11 +175,13 @@ def download_from_tos(self, url: str, save_path: str) -> bool:
170175
f.write(chunk)
171176

172177
logger.debug(f"Image download success, saved to: {save_path}")
173-
client.close()
174178
return True
175179

176180
except Exception as e:
177181
logger.error(f"Image download failed: {str(e)}")
178-
if "client" in locals():
179-
client.close()
182+
180183
return False
184+
185+
def close_client(self):
186+
if self.client:
187+
self.client.close()

veadk/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _convert_messages(self, messages, session_id) -> list:
106106
self.user_id, self.app_name, session_id, messages.media
107107
)
108108
asyncio.create_task(tos_handler.upload_to_tos(url, data, "bytes"))
109+
tos_handler.close_client()
109110

110111
messages = [
111112
types.Content(

0 commit comments

Comments
 (0)