Skip to content

Commit c2f4d98

Browse files
committed
add judgment for tos
1 parent 91c3329 commit c2f4d98

File tree

3 files changed

+66
-25
lines changed

3 files changed

+66
-25
lines changed

tests/test_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def _test_convert_messages(runner):
2828
role="user",
2929
)
3030
]
31-
actual_message = runner._convert_messages(message, session_id="test_session_id")
31+
actual_message = runner._convert_messages(
32+
message, session_id="test_session_id", upload_inline_data_to_tos=True
33+
)
3234
assert actual_message == expected_message
3335

3436
message = ["test message 1", "test message 2"]
@@ -42,7 +44,9 @@ def _test_convert_messages(runner):
4244
role="user",
4345
),
4446
]
45-
actual_message = runner._convert_messages(message, session_id="test_session_id")
47+
actual_message = runner._convert_messages(
48+
message, session_id="test_session_id", upload_inline_data_to_tos=True
49+
)
4650
assert actual_message == expected_message
4751

4852

veadk/integrations/ve_tos/ve_tos.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,27 @@
1515
import os
1616
from veadk.config import getenv
1717
from veadk.utils.logger import get_logger
18-
import tos
1918
import asyncio
2019
from typing import Union
2120
from pydantic import BaseModel, Field
2221
from typing import Any
2322
from urllib.parse import urlparse
2423
from datetime import datetime
2524

25+
# Initialize logger before using it
2626
logger = get_logger(__name__)
2727

28+
# Try to import tos module, and provide helpful error message if it fails
29+
try:
30+
import tos
31+
except ImportError as e:
32+
logger.error(
33+
"Failed to import 'tos' module. Please install it using: pip install tos\n"
34+
)
35+
raise ImportError(
36+
"Missing 'tos' module. Please install it using: pip install tos\n"
37+
) from e
38+
2839

2940
class TOSConfig(BaseModel):
3041
region: str = Field(
@@ -59,10 +70,13 @@ def model_post_init(self, __context: Any) -> None:
5970
logger.info("Connected to TOS successfully.")
6071
except Exception as e:
6172
logger.error(f"Client initialization failed:{e}")
62-
return None
73+
self._client = None
6374

6475
def create_bucket(self) -> bool:
6576
"""If the bucket does not exist, create it"""
77+
if not self._client:
78+
logger.error("TOS client is not initialized")
79+
return False
6680
try:
6781
self._client.head_bucket(self.config.bucket_name)
6882
logger.info(f"Bucket {self.config.bucket_name} already exists")
@@ -76,6 +90,9 @@ def create_bucket(self) -> bool:
7690
)
7791
logger.info(f"Bucket {self.config.bucket_name} created successfully")
7892
return True
93+
else:
94+
logger.error(f"Bucket creation failed: {str(e)}")
95+
return False
7996
except Exception as e:
8097
logger.error(f"Bucket creation failed: {str(e)}")
8198
return False
@@ -103,26 +120,24 @@ def upload(
103120
data: Union[str, bytes],
104121
):
105122
if isinstance(data, str):
106-
data_type = "file"
123+
# data is a file path
124+
return asyncio.to_thread(self._do_upload_file, object_key, data)
107125
elif isinstance(data, bytes):
108-
data_type = "bytes"
126+
# data is bytes content
127+
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
109128
else:
110129
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
111130
logger.error(error_msg)
112131
raise ValueError(error_msg)
113-
if data_type == "file":
114-
return asyncio.to_thread(self._do_upload_file, object_key, data)
115-
elif data_type == "bytes":
116-
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
117132

118-
def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
133+
def _do_upload_bytes(self, object_key: str, data: bytes) -> bool:
119134
try:
120135
if not self._client:
121136
return False
122137
if not self.create_bucket():
123138
return False
124139
self._client.put_object(
125-
bucket=self.config.bucket_name, key=object_key, content=bytes
140+
bucket=self.config.bucket_name, key=object_key, content=data
126141
)
127142
logger.debug(f"Upload success, object_key: {object_key}")
128143
self._close()
@@ -152,6 +167,9 @@ def _do_upload_file(self, object_key: str, file_path: str) -> bool:
152167

153168
def download(self, object_key: str, save_path: str) -> bool:
154169
"""download image from TOS"""
170+
if not self._client:
171+
logger.error("TOS client is not initialized")
172+
return False
155173
try:
156174
object_stream = self._client.get_object(self.config.bucket_name, object_key)
157175

veadk/runner.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,34 @@ def __init__(
8787
plugins=plugins,
8888
)
8989

90-
def _convert_messages(self, messages, session_id) -> list:
90+
def _convert_messages(
91+
self, messages, session_id, upload_inline_data_to_tos
92+
) -> list:
9193
if isinstance(messages, str):
9294
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
9395
elif isinstance(messages, MediaMessage):
9496
assert messages.media.endswith(".png"), (
9597
"The MediaMessage only supports PNG format file for now."
9698
)
9799
data = read_png_to_bytes(messages.media)
98-
99-
ve_tos = VeTOS()
100-
object_key, tos_url = ve_tos.build_tos_url(
101-
self.user_id, self.app_name, session_id, messages.media
102-
)
103-
try:
104-
asyncio.create_task(ve_tos.upload(object_key, data))
105-
except Exception as e:
106-
logger.error(f"Upload to TOS failed: {e}")
107-
tos_url = None
100+
tos_url = "<tos_url>"
101+
if upload_inline_data_to_tos:
102+
try:
103+
ve_tos = VeTOS()
104+
object_key, tos_url = ve_tos.build_tos_url(
105+
self.user_id, self.app_name, session_id, messages.media
106+
)
107+
upload_task = ve_tos.upload(object_key, data)
108+
if upload_task is not None:
109+
asyncio.create_task(upload_task)
110+
except Exception as e:
111+
logger.error(f"Upload to TOS failed: {e}")
112+
tos_url = None
113+
114+
else:
115+
logger.warning(
116+
"Loss of multimodal data may occur in the tracing process."
117+
)
108118

109119
messages = [
110120
types.Content(
@@ -124,7 +134,11 @@ def _convert_messages(self, messages, session_id) -> list:
124134
elif isinstance(messages, list):
125135
converted_messages = []
126136
for message in messages:
127-
converted_messages.extend(self._convert_messages(message, session_id))
137+
converted_messages.extend(
138+
self._convert_messages(
139+
message, session_id, upload_inline_data_to_tos
140+
)
141+
)
128142
messages = converted_messages
129143
else:
130144
raise ValueError(f"Unknown message type: {type(messages)}")
@@ -179,6 +193,7 @@ async def event_generator():
179193
print() # end with a new line
180194
except LlmCallsLimitExceededError as e:
181195
logger.warning(f"Max number of llm calls limit exceeded: {e}")
196+
final_output = ""
182197

183198
return final_output
184199

@@ -189,8 +204,11 @@ async def run(
189204
stream: bool = False,
190205
run_config: RunConfig | None = None,
191206
save_tracing_data: bool = False,
207+
upload_inline_data_to_tos: bool = False,
192208
):
193-
converted_messages: list = self._convert_messages(messages, session_id)
209+
converted_messages: list = self._convert_messages(
210+
messages, session_id, upload_inline_data_to_tos
211+
)
194212

195213
await self.short_term_memory.create_session(
196214
app_name=self.app_name, user_id=self.user_id, session_id=session_id
@@ -276,6 +294,7 @@ async def event_generator():
276294
final_output += chunk
277295
except LlmCallsLimitExceededError as e:
278296
logger.warning(f"Max number of llm calls limit exceeded: {e}")
297+
final_output = ""
279298

280299
return final_output
281300

0 commit comments

Comments
 (0)