Skip to content

Commit 7275d9e

Browse files
committed
feat: add toshandler
1 parent 9dad26c commit 7275d9e

File tree

2 files changed

+191
-4
lines changed

2 files changed

+191
-4
lines changed

veadk/database/tos/toshandler.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import os
2+
from veadk.config import getenv
3+
from veadk.utils.logger import get_logger
4+
import tos
5+
from datetime import datetime
6+
import asyncio
7+
from typing import Literal
8+
from urllib.parse import urlparse
9+
10+
# Module-level logger used by every TOSHandler method for debug/error reporting.
logger = get_logger(__name__)
11+
12+
13+
class TOSHandler:
    """Helper for uploading to and downloading from a TOS bucket.

    Object locations are expressed as ``"<bucket_name>/<object_key>"`` strings
    (see :meth:`gen_url` and :meth:`parse_url`). A fresh TOS client is created
    for every upload/download and closed afterwards.
    """

    # Two-part archive suffixes that must be kept together (stored lowercase so
    # matching is case-insensitive).
    _COMPOUND_SUFFIXES = frozenset(
        {
            "tar.gz",
            "tar.bz2",
            "tar.xz",
            "tar.z",
            "tar.lz",
            "tar.lzma",
            "tar.lzo",
        }
    )

    def __init__(self):
        """Initialize TOS configuration information.

        Region, access key, secret key and bucket name are read through
        ``getenv`` from ``VOLCENGINE_REGION``, ``VOLCENGINE_ACCESS_KEY``,
        ``VOLCENGINE_SECRET_KEY`` and ``DATABASE_TOS_BUCKET``.
        """
        self.region = getenv("VOLCENGINE_REGION")
        self.ak = getenv("VOLCENGINE_ACCESS_KEY")
        self.sk = getenv("VOLCENGINE_SECRET_KEY")
        self.bucket_name = getenv("DATABASE_TOS_BUCKET")

    def _init_tos_client(self):
        """Create a ``tos.TosClientV2`` for the configured region.

        Returns:
            The client, or ``None`` when construction fails (the error is
            logged rather than raised).
        """
        try:
            return tos.TosClientV2(
                self.ak,
                self.sk,
                endpoint=f"tos-{self.region}.volces.com",
                region=self.region,
            )
        except Exception as e:
            logger.error(f"Client initialization failed:{e}")
            return None

    def get_suffix(self, data_path: str) -> str:
        """Extract the lowercase file suffix, including the leading dot.

        Compound archive suffixes such as ``.tar.gz`` are returned whole, and
        matching is case-insensitive (``X.TAR.GZ`` -> ``.tar.gz``).

        Args:
            data_path: A local path or an http(s) URL. Query strings and
                fragments are ignored.

        Returns:
            The suffix (e.g. ``".png"``), or ``""`` when the file name has no
            extension or ends with a bare dot.
        """
        parsed = urlparse(data_path)
        # Only trust urlparse's path for real web URLs; a Windows drive letter
        # such as "C:" would otherwise be mistaken for a scheme.
        path = parsed.path if parsed.scheme in ("http", "https") else data_path

        filename = os.path.basename(path).split("?")[0].split("#")[0]

        parts = filename.lower().split(".")
        if len(parts) < 2 or not parts[-1]:
            return ""

        compound = ".".join(parts[-2:])
        if compound in self._COMPOUND_SUFFIXES:
            return f".{compound}"
        return f".{parts[-1]}"

    def gen_url(self, user_id, app_name, session_id, data_path):
        """Generate a TOS location of the form ``bucket_name/object_key``.

        The object key is ``<app_name>/<user_id>-<session_id>-<timestamp><suffix>``
        where the timestamp is the local time with millisecond precision and
        the suffix is taken from ``data_path`` via :meth:`get_suffix`.
        """
        suffix = self.get_suffix(data_path)
        # %f yields microseconds (6 digits); dropping the last 3 leaves milliseconds.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
        return (
            f"{self.bucket_name}/{app_name}/{user_id}-{session_id}-{timestamp}{suffix}"
        )

    def parse_url(self, url: str) -> list[str]:
        """Split ``"bucket_name/object_key"`` into ``[bucket_name, object_key]``.

        Only the first ``/`` separates the bucket from the key, so the key may
        itself contain slashes.

        Raises:
            ValueError: If ``url`` has no ``/`` or either component is empty.
        """
        parts = url.split("/", 1)
        if len(parts) < 2 or not all(parts):
            raise ValueError("URL format error, it should be: bucket_name/object_key")
        return parts

    def create_bucket(self, client, bucket_name):
        """Ensure ``bucket_name`` exists, creating it if necessary.

        Args:
            client: A TOS client as returned by :meth:`_init_tos_client`.
            bucket_name: Name of the bucket to check/create.

        Returns:
            ``True`` if the bucket already exists or was created, ``False`` on
            any failure (which is logged).
        """
        try:
            # Check the bucket that was asked for, not the configured default.
            client.head_bucket(bucket_name)
            logger.debug(f"Bucket {bucket_name} already exists")
            return True
        except tos.exceptions.TosServerError as e:
            # Only a 404 means "missing, go create it"; anything else
            # (e.g. permission problems) is a real failure.
            if e.status_code != 404:
                logger.error(f"Bucket check failed: {str(e)}")
                return False
        except Exception as e:
            logger.error(f"Bucket creation failed: {str(e)}")
            return False

        try:
            client.create_bucket(
                bucket=bucket_name,
                storage_class=tos.StorageClassType.Storage_Class_Standard,
                acl=tos.ACLType.ACL_Private,
            )
        except Exception as e:
            logger.error(f"Bucket creation failed: {str(e)}")
            return False
        logger.debug(f"Bucket {bucket_name} created successfully")
        return True

    def upload_to_tos(self, url: str, data, data_type: Literal["file", "bytes"]):
        """Schedule an upload on a worker thread.

        Args:
            url: Destination in ``bucket_name/object_key`` form.
            data: A local file path when ``data_type`` is ``"file"``, or the
                raw content when it is ``"bytes"``.
            data_type: Selects how ``data`` is interpreted.

        Returns:
            The coroutine produced by ``asyncio.to_thread``; it must be awaited
            (or wrapped in a task) for the upload to run, and resolves to
            ``True`` on success and ``False`` on failure.

        Raises:
            ValueError: If ``data_type`` is neither ``"file"`` nor ``"bytes"``.
        """
        if data_type == "file":
            worker = self._do_upload_file
        elif data_type == "bytes":
            worker = self._do_upload_bytes
        else:
            error_msg = f"Upload failed: data_type error. Only 'file' and 'bytes' are supported, got {data_type}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        return asyncio.to_thread(worker, url, data)

    def _upload(self, url, put):
        """Shared upload flow: parse URL, open client, ensure bucket, put, close.

        Args:
            url: Destination in ``bucket_name/object_key`` form.
            put: Callable ``(client, bucket_name, object_key)`` performing the
                actual transfer.

        Returns:
            ``True`` on success, ``False`` if the client could not be created,
            the bucket could not be ensured, or the transfer raised.

        Raises:
            ValueError: If ``url`` is malformed (from :meth:`parse_url`).
        """
        bucket_name, object_key = self.parse_url(url)
        client = self._init_tos_client()
        if not client:
            return False
        try:
            if not self.create_bucket(client, bucket_name):
                return False
            put(client, bucket_name, object_key)
            return True
        except Exception as e:
            logger.error(f"Upload failed: {e}")
            return False
        finally:
            client.close()

    def _do_upload_bytes(self, url, data):
        """Upload in-memory content to ``url``; returns ``True`` on success."""
        return self._upload(
            url,
            lambda client, bucket, key: client.put_object(
                bucket=bucket, key=key, content=data
            ),
        )

    def _do_upload_file(self, url, file_path):
        """Upload the local file at ``file_path`` to ``url``; returns ``True`` on success."""
        return self._upload(
            url,
            lambda client, bucket, key: client.put_object_from_file(
                bucket=bucket, key=key, file_path=file_path
            ),
        )

    def download_from_tos(self, url, save_path):
        """Download the object at ``url`` and write it to ``save_path``.

        Missing parent directories of ``save_path`` are created.

        Args:
            url: Source in ``bucket_name/object_key`` form.
            save_path: Local destination file path.

        Returns:
            ``True`` on success, ``False`` on any failure (which is logged).
        """
        client = None
        try:
            bucket_name, object_key = self.parse_url(url)
            client = self._init_tos_client()
            if not client:
                return False

            object_stream = client.get_object(bucket_name, object_key)

            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            with open(save_path, "wb") as f:
                for chunk in object_stream:
                    f.write(chunk)

            logger.debug(f"Image download success, saved to: {save_path}")
            return True
        except Exception as e:
            logger.error(f"Image download failed: {str(e)}")
            return False
        finally:
            # Single cleanup point for both the success and failure paths.
            if client:
                client.close()

veadk/runner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
from typing import Union
1516

1617
from google.adk.agents import RunConfig
@@ -31,6 +32,7 @@
3132
from veadk.types import MediaMessage
3233
from veadk.utils.logger import get_logger
3334
from veadk.utils.misc import read_png_to_bytes
35+
from veadk.database.tos.toshandler import TOSHandler
3436

3537
logger = get_logger(__name__)
3638

@@ -89,22 +91,31 @@ def __init__(
8991
plugins=plugins,
9092
)
9193

92-
def _convert_messages(self, messages) -> list:
94+
def _convert_messages(self, messages, session_id) -> list:
9395
if isinstance(messages, str):
9496
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
9597
elif isinstance(messages, MediaMessage):
9698
assert messages.media.endswith(".png"), (
9799
"The MediaMessage only supports PNG format file for now."
98100
)
101+
data = read_png_to_bytes(messages.media)
102+
url = messages.media
103+
if self.agent.tracers:
104+
tos_handler = TOSHandler()
105+
url = tos_handler.gen_url(
106+
self.user_id, self.app_name, session_id, messages.media
107+
)
108+
asyncio.create_task(tos_handler.upload_to_tos(url, data, "bytes"))
109+
99110
messages = [
100111
types.Content(
101112
role="user",
102113
parts=[
103114
types.Part(text=messages.text),
104115
types.Part(
105116
inline_data=Blob(
106-
display_name=messages.media,
107-
data=read_png_to_bytes(messages.media),
117+
display_name=url,
118+
data=data,
108119
mime_type="image/png",
109120
)
110121
),
@@ -164,7 +175,7 @@ async def run(
164175
stream: bool = False,
165176
save_tracing_data: bool = False,
166177
):
167-
converted_messages: list = self._convert_messages(messages)
178+
converted_messages: list = self._convert_messages(messages, session_id)
168179

169180
await self.short_term_memory.create_session(
170181
app_name=self.app_name, user_id=self.user_id, session_id=session_id

0 commit comments

Comments
 (0)