Skip to content

Commit 8eb63bd

Browse files
authored
fix(integrations): tos judgment (#125)
* add judgment for tos * fix judgment for tos * fix judgment for tos
1 parent 91c3329 commit 8eb63bd

File tree

4 files changed

+130
-31
lines changed

4 files changed

+130
-31
lines changed

tests/test_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def _test_convert_messages(runner):
2828
role="user",
2929
)
3030
]
31-
actual_message = runner._convert_messages(message, session_id="test_session_id")
31+
actual_message = runner._convert_messages(
32+
message, session_id="test_session_id", upload_inline_data_to_tos=True
33+
)
3234
assert actual_message == expected_message
3335

3436
message = ["test message 1", "test message 2"]
@@ -42,7 +44,9 @@ def _test_convert_messages(runner):
4244
role="user",
4345
),
4446
]
45-
actual_message = runner._convert_messages(message, session_id="test_session_id")
47+
actual_message = runner._convert_messages(
48+
message, session_id="test_session_id", upload_inline_data_to_tos=True
49+
)
4650
assert actual_message == expected_message
4751

4852

tests/test_tos.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,29 @@
1414

1515
import pytest
1616
from unittest import mock
17-
import veadk.integrations.ve_tos.ve_tos as tos_mod
17+
18+
# Check if tos module is available
19+
import importlib
20+
21+
TOS_AVAILABLE = False
22+
try:
23+
importlib.import_module("veadk.integrations.ve_tos.ve_tos")
24+
TOS_AVAILABLE = True
25+
except ImportError:
26+
pass
27+
28+
# Skip tests that require tos module if it's not available
29+
require_tos = pytest.mark.skipif(not TOS_AVAILABLE, reason="tos module not available")
1830

1931
# 使用 pytest-asyncio
2032
pytest_plugins = ("pytest_asyncio",)
2133

2234

2335
@pytest.fixture
36+
@require_tos
2437
def mock_client(monkeypatch):
38+
import veadk.integrations.ve_tos.ve_tos as tos_mod
39+
2540
fake_client = mock.Mock()
2641

2742
monkeypatch.setenv("DATABASE_TOS_REGION", "test-region")
@@ -33,9 +48,17 @@ def mock_client(monkeypatch):
3348

3449
class FakeExceptions:
3550
class TosServerError(Exception):
36-
def __init__(self, msg):
51+
def __init__(
52+
self,
53+
msg: str,
54+
code: int = 0,
55+
host_id: str = "",
56+
resource: str = "",
57+
request_id: str = "",
58+
header=None,
59+
):
3760
super().__init__(msg)
38-
self.status_code = None
61+
self.status_code = code
3962

4063
monkeypatch.setattr(tos_mod.tos, "exceptions", FakeExceptions)
4164
monkeypatch.setattr(
@@ -51,27 +74,34 @@ def __init__(self, msg):
5174

5275

5376
@pytest.fixture
77+
@require_tos
5478
def tos_client(mock_client):
79+
import veadk.integrations.ve_tos.ve_tos as tos_mod
80+
5581
return tos_mod.VeTOS()
5682

5783

84+
@require_tos
5885
def test_create_bucket_exists(tos_client, mock_client):
5986
mock_client.head_bucket.return_value = None # head_bucket 正常返回表示存在
6087
result = tos_client.create_bucket()
6188
assert result is True
6289
mock_client.create_bucket.assert_not_called()
6390

6491

92+
@require_tos
6593
def test_create_bucket_not_exists(tos_client, mock_client):
66-
exc = tos_mod.tos.exceptions.TosServerError("not found")
67-
exc.status_code = 404
94+
import veadk.integrations.ve_tos.ve_tos as tos_mod
95+
96+
exc = tos_mod.tos.exceptions.TosServerError(msg="not found", code=404)
6897
mock_client.head_bucket.side_effect = exc
6998

7099
result = tos_client.create_bucket()
71100
assert result is True
72101
mock_client.create_bucket.assert_called_once()
73102

74103

104+
@require_tos
75105
@pytest.mark.asyncio
76106
async def test_upload_bytes_success(tos_client, mock_client):
77107
mock_client.head_bucket.return_value = True
@@ -83,6 +113,7 @@ async def test_upload_bytes_success(tos_client, mock_client):
83113
mock_client.close.assert_called_once()
84114

85115

116+
@require_tos
86117
@pytest.mark.asyncio
87118
async def test_upload_file_success(tmp_path, tos_client, mock_client):
88119
mock_client.head_bucket.return_value = True
@@ -95,6 +126,7 @@ async def test_upload_file_success(tmp_path, tos_client, mock_client):
95126
mock_client.close.assert_called_once()
96127

97128

129+
@require_tos
98130
def test_download_success(tmp_path, tos_client, mock_client):
99131
save_path = tmp_path / "out.txt"
100132
mock_client.get_object.return_value = [b"abc", b"def"]
@@ -104,7 +136,32 @@ def test_download_success(tmp_path, tos_client, mock_client):
104136
assert save_path.read_bytes() == b"abcdef"
105137

106138

139+
@require_tos
107140
def test_download_fail(tos_client, mock_client):
108141
mock_client.get_object.side_effect = Exception("boom")
109142
result = tos_client.download("obj-key", "somewhere.txt")
110143
assert result is False
144+
145+
146+
@require_tos
147+
@pytest.mark.skipif(TOS_AVAILABLE, reason="tos module is available")
148+
def test_tos_import_error():
149+
"""Test VeTOS behavior when tos module is not installed"""
150+
# Remove tos from sys.modules to simulate it's not installed
151+
import sys
152+
153+
original_tos = sys.modules.get("tos")
154+
if "tos" in sys.modules:
155+
del sys.modules["tos"]
156+
157+
try:
158+
# Try to import ve_tos module, which should raise ImportError
159+
with pytest.raises(ImportError) as exc_info:
160+
pass
161+
162+
# Check that the error message contains installation instructions
163+
assert "pip install tos" in str(exc_info.value)
164+
finally:
165+
# Restore original state
166+
if original_tos is not None:
167+
sys.modules["tos"] = original_tos

veadk/integrations/ve_tos/ve_tos.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,27 @@
1515
import os
1616
from veadk.config import getenv
1717
from veadk.utils.logger import get_logger
18-
import tos
1918
import asyncio
2019
from typing import Union
2120
from pydantic import BaseModel, Field
2221
from typing import Any
2322
from urllib.parse import urlparse
2423
from datetime import datetime
2524

25+
# Initialize logger before using it
2626
logger = get_logger(__name__)
2727

28+
# Try to import tos module, and provide helpful error message if it fails
29+
try:
30+
import tos
31+
except ImportError as e:
32+
logger.error(
33+
"Failed to import 'tos' module. Please install it using: pip install tos\n"
34+
)
35+
raise ImportError(
36+
"Missing 'tos' module. Please install it using: pip install tos\n"
37+
) from e
38+
2839

2940
class TOSConfig(BaseModel):
3041
region: str = Field(
@@ -59,10 +70,13 @@ def model_post_init(self, __context: Any) -> None:
5970
logger.info("Connected to TOS successfully.")
6071
except Exception as e:
6172
logger.error(f"Client initialization failed:{e}")
62-
return None
73+
self._client = None
6374

6475
def create_bucket(self) -> bool:
6576
"""If the bucket does not exist, create it"""
77+
if not self._client:
78+
logger.error("TOS client is not initialized")
79+
return False
6680
try:
6781
self._client.head_bucket(self.config.bucket_name)
6882
logger.info(f"Bucket {self.config.bucket_name} already exists")
@@ -76,6 +90,9 @@ def create_bucket(self) -> bool:
7690
)
7791
logger.info(f"Bucket {self.config.bucket_name} created successfully")
7892
return True
93+
else:
94+
logger.error(f"Bucket creation failed: {str(e)}")
95+
return False
7996
except Exception as e:
8097
logger.error(f"Bucket creation failed: {str(e)}")
8198
return False
@@ -103,26 +120,24 @@ def upload(
103120
data: Union[str, bytes],
104121
):
105122
if isinstance(data, str):
106-
data_type = "file"
123+
# data is a file path
124+
return asyncio.to_thread(self._do_upload_file, object_key, data)
107125
elif isinstance(data, bytes):
108-
data_type = "bytes"
126+
# data is bytes content
127+
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
109128
else:
110129
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
111130
logger.error(error_msg)
112131
raise ValueError(error_msg)
113-
if data_type == "file":
114-
return asyncio.to_thread(self._do_upload_file, object_key, data)
115-
elif data_type == "bytes":
116-
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
117132

118-
def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
133+
def _do_upload_bytes(self, object_key: str, data: bytes) -> bool:
119134
try:
120135
if not self._client:
121136
return False
122137
if not self.create_bucket():
123138
return False
124139
self._client.put_object(
125-
bucket=self.config.bucket_name, key=object_key, content=bytes
140+
bucket=self.config.bucket_name, key=object_key, content=data
126141
)
127142
logger.debug(f"Upload success, object_key: {object_key}")
128143
self._close()
@@ -152,6 +167,9 @@ def _do_upload_file(self, object_key: str, file_path: str) -> bool:
152167

153168
def download(self, object_key: str, save_path: str) -> bool:
154169
"""download image from TOS"""
170+
if not self._client:
171+
logger.error("TOS client is not initialized")
172+
return False
155173
try:
156174
object_stream = self._client.get_object(self.config.bucket_name, object_key)
157175

veadk/runner.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from veadk.agents.sequential_agent import SequentialAgent
3030
from veadk.config import getenv
3131
from veadk.evaluation import EvalSetRecorder
32-
from veadk.integrations.ve_tos.ve_tos import VeTOS
3332
from veadk.memory.short_term_memory import ShortTermMemory
3433
from veadk.types import MediaMessage
3534
from veadk.utils.logger import get_logger
@@ -87,24 +86,36 @@ def __init__(
8786
plugins=plugins,
8887
)
8988

90-
def _convert_messages(self, messages, session_id) -> list:
89+
def _convert_messages(
90+
self, messages, session_id, upload_inline_data_to_tos
91+
) -> list:
9192
if isinstance(messages, str):
9293
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
9394
elif isinstance(messages, MediaMessage):
9495
assert messages.media.endswith(".png"), (
9596
"The MediaMessage only supports PNG format file for now."
9697
)
9798
data = read_png_to_bytes(messages.media)
98-
99-
ve_tos = VeTOS()
100-
object_key, tos_url = ve_tos.build_tos_url(
101-
self.user_id, self.app_name, session_id, messages.media
102-
)
103-
try:
104-
asyncio.create_task(ve_tos.upload(object_key, data))
105-
except Exception as e:
106-
logger.error(f"Upload to TOS failed: {e}")
107-
tos_url = None
99+
tos_url = "<tos_url>"
100+
if upload_inline_data_to_tos:
101+
try:
102+
from veadk.integrations.ve_tos.ve_tos import VeTOS
103+
104+
ve_tos = VeTOS()
105+
object_key, tos_url = ve_tos.build_tos_url(
106+
self.user_id, self.app_name, session_id, messages.media
107+
)
108+
upload_task = ve_tos.upload(object_key, data)
109+
if upload_task is not None:
110+
asyncio.create_task(upload_task)
111+
except Exception as e:
112+
logger.error(f"Upload to TOS failed: {e}")
113+
tos_url = None
114+
115+
else:
116+
logger.warning(
117+
"Loss of multimodal data may occur in the tracing process."
118+
)
108119

109120
messages = [
110121
types.Content(
@@ -124,7 +135,11 @@ def _convert_messages(self, messages, session_id) -> list:
124135
elif isinstance(messages, list):
125136
converted_messages = []
126137
for message in messages:
127-
converted_messages.extend(self._convert_messages(message, session_id))
138+
converted_messages.extend(
139+
self._convert_messages(
140+
message, session_id, upload_inline_data_to_tos
141+
)
142+
)
128143
messages = converted_messages
129144
else:
130145
raise ValueError(f"Unknown message type: {type(messages)}")
@@ -179,6 +194,7 @@ async def event_generator():
179194
print() # end with a new line
180195
except LlmCallsLimitExceededError as e:
181196
logger.warning(f"Max number of llm calls limit exceeded: {e}")
197+
final_output = ""
182198

183199
return final_output
184200

@@ -189,8 +205,11 @@ async def run(
189205
stream: bool = False,
190206
run_config: RunConfig | None = None,
191207
save_tracing_data: bool = False,
208+
upload_inline_data_to_tos: bool = False,
192209
):
193-
converted_messages: list = self._convert_messages(messages, session_id)
210+
converted_messages: list = self._convert_messages(
211+
messages, session_id, upload_inline_data_to_tos
212+
)
194213

195214
await self.short_term_memory.create_session(
196215
app_name=self.app_name, user_id=self.user_id, session_id=session_id
@@ -276,6 +295,7 @@ async def event_generator():
276295
final_output += chunk
277296
except LlmCallsLimitExceededError as e:
278297
logger.warning(f"Max number of llm calls limit exceeded: {e}")
298+
final_output = ""
279299

280300
return final_output
281301

0 commit comments

Comments
 (0)