Skip to content

Commit c20ee49

Browse files
doraemonlovewangjiaju
andauthored
feat(runner): update more image formats (#234)
* Update _convert_message * Update _convert_message --------- Co-authored-by: wangjiaju <[email protected]>
1 parent c1e5984 commit c20ee49

File tree

4 files changed

+26
-22
lines changed

4 files changed

+26
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"psycopg2-binary>=2.9.10", # For PostgreSQL database (short term memory)
3636
"pymysql>=1.1.1", # For MySQL database (short term memory)
3737
"opensearch-py==2.8.0",
38+
"filetype>=1.2.0",
3839
]
3940

4041
[project.scripts]

veadk/runner.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from veadk.memory.short_term_memory import ShortTermMemory
3535
from veadk.types import MediaMessage
3636
from veadk.utils.logger import get_logger
37-
from veadk.utils.misc import formatted_timestamp, read_png_to_bytes
37+
from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
3838

3939
logger = get_logger(__name__)
4040

@@ -50,11 +50,7 @@
5050
async def pre_run_process(self, process_func, new_message, user_id, session_id):
5151
if new_message.parts:
5252
for part in new_message.parts:
53-
if (
54-
part.inline_data
55-
and part.inline_data.mime_type == "image/png"
56-
and self.upload_inline_data_to_tos
57-
):
53+
if part.inline_data and self.upload_inline_data_to_tos:
5854
await process_func(
5955
part,
6056
self.app_name,
@@ -105,9 +101,20 @@ def _convert_messages(
105101
if isinstance(messages, str):
106102
_messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
107103
elif isinstance(messages, MediaMessage):
108-
assert messages.media.endswith(".png"), (
109-
"The MediaMessage only supports PNG format file for now."
104+
import filetype
105+
106+
file_data = read_file_to_bytes(messages.media)
107+
108+
kind = filetype.guess(file_data)
109+
if kind is None:
110+
raise ValueError("Unsupported or unknown file type.")
111+
112+
mime_type = kind.mime
113+
114+
assert mime_type.startswith(("image/", "video/")), (
115+
f"Unsupported media type: {mime_type}"
110116
)
117+
111118
_messages = [
112119
types.Content(
113120
role="user",
@@ -116,8 +123,8 @@ def _convert_messages(
116123
types.Part(
117124
inline_data=Blob(
118125
display_name=messages.media,
119-
data=read_png_to_bytes(messages.media),
120-
mime_type="image/png",
126+
data=file_data,
127+
mime_type=mime_type,
121128
)
122129
),
123130
],

veadk/tools/builtin_tools/generate_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from veadk.config import getenv
2929
from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
3030
from veadk.utils.logger import get_logger
31-
from veadk.utils.misc import formatted_timestamp, read_png_to_bytes
31+
from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
3232
from veadk.version import VERSION
3333

3434
logger = get_logger(__name__)
@@ -299,7 +299,7 @@ async def image_generate(
299299
artifact=Part(
300300
inline_data=Blob(
301301
display_name=filename,
302-
data=read_png_to_bytes(image_tos_url),
302+
data=read_file_to_bytes(image_tos_url),
303303
mime_type=mimetypes.guess_type(image_tos_url)[0],
304304
)
305305
),

veadk/utils/misc.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,14 @@ def formatted_timestamp() -> str:
3636
return time.strftime("%Y%m%d%H%M%S", time.localtime())
3737

3838

39-
def read_png_to_bytes(png_path: str) -> bytes:
40-
# Determine whether it is a local file or a network file
41-
if png_path.startswith(("http://", "https://")):
42-
# Network file: Download via URL and return bytes
43-
response = requests.get(png_path)
44-
response.raise_for_status() # Check if the HTTP request is successful
39+
def read_file_to_bytes(file_path: str) -> bytes:
40+
if file_path.startswith(("http://", "https://")):
41+
response = requests.get(file_path)
42+
response.raise_for_status()
4543
return response.content
4644
else:
47-
# Local file
48-
with open(png_path, "rb") as f:
49-
data = f.read()
50-
return data
45+
with open(file_path, "rb") as f:
46+
return f.read()
5147

5248

5349
def load_module_from_file(module_name: str, file_path: str) -> types.ModuleType:

0 commit comments

Comments
 (0)