Skip to content

Commit f90171f

Browse files
doraemonlovewangjiaju
andauthored
feat(tools): support upload multimedia data to cozeloop in tools
* Add seedream-4-0 * Add default model name * upload base64 to tos --------- Co-authored-by: wangjiaju <[email protected]>
1 parent 89bd449 commit f90171f

File tree

5 files changed

+508
-72
lines changed

5 files changed

+508
-72
lines changed

veadk/consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,8 @@
6060
DEFAULT_TOS_BUCKET_NAME = "ark-tutorial"
6161

6262
DEFAULT_COZELOOP_SPACE_NAME = "VeADK Space"
63+
64+
DEFAULT_TEXT_TO_IMAGE_MODEL_NAME = "doubao-seedream-3-0-t2i-250415"
65+
DEFAULT_IMAGE_EDIT_MODEL_NAME = "doubao-seededit-3-0-i2i-250628"
66+
DEFAULT_VIDEO_MODEL_NAME = "doubao-seedance-1-0-pro-250528"
67+
DEFAULT_IMAGE_GENERATE_MODEL_NAME = "doubao-seedream-4-0-250828"
Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict
16+
17+
from google.adk.tools import ToolContext
18+
from veadk.config import getenv
19+
from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
20+
21+
import base64
22+
from volcenginesdkarkruntime import Ark
23+
from opentelemetry import trace
24+
import traceback
25+
from veadk.version import VERSION
26+
from opentelemetry.trace import Span
27+
from veadk.utils.logger import get_logger
28+
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
29+
import json
30+
31+
32+
logger = get_logger(__name__)
33+
34+
client = Ark(
35+
api_key=getenv("MODEL_AGENT_API_KEY"),
36+
base_url=DEFAULT_MODEL_AGENT_API_BASE,
37+
)
38+
39+
40+
async def image_generate(
41+
tasks: list,
42+
tool_context: ToolContext,
43+
) -> Dict:
44+
"""
45+
Seedream 4.0: batch image generation via tasks.
46+
Args:
47+
tasks (list[dict]):
48+
A list of image-generation tasks. Each task is a dict.
49+
Per-task schema
50+
---------------
51+
Required:
52+
- task_type (str):
53+
One of:
54+
* "multi_image_to_group" # 多图生组图
55+
* "single_image_to_group" # 单图生组图
56+
* "text_to_group" # 文生组图
57+
* "multi_image_to_single" # 多图生单图
58+
* "single_image_to_single" # 单图生单图
59+
* "text_to_single" # 文生单图
60+
- prompt (str)
61+
Text description of the desired image(s). 中文/English 均可。
62+
若要指定生成图片的数量,请在prompt中添加"生成N张图片",其中N为具体的数字。
63+
Optional:
64+
- size (str)
65+
指定生成图像的大小,有两种用法(二选一,不可混用):
66+
方式 1:分辨率级别
67+
可选值: "1K", "2K", "4K"
68+
模型会结合 prompt 中的语义推断合适的宽高比、长宽。
69+
方式 2:具体宽高值
70+
格式: "<宽度>x<高度>",如 "2048x2048", "2384x1728"
71+
约束:
72+
* 总像素数范围: [1024x1024, 4096x4096]
73+
* 宽高比范围: [1/16, 16]
74+
推荐值:
75+
- 1:1 → 2048x2048
76+
- 4:3 → 2384x1728
77+
- 3:4 → 1728x2304
78+
- 16:9 → 2560x1440
79+
- 9:16 → 1440x2560
80+
- 3:2 → 2496x1664
81+
- 2:3 → 1664x2496
82+
- 21:9 → 3024x1296
83+
默认值: "2048x2048"
84+
- response_format (str)
85+
Return format: "url" (default, URL 24h 过期) | "b64_json".
86+
- watermark (bool)
87+
Add watermark. Default: true.
88+
- image (str | list[str]) # 仅“非文生图”需要。文生图请不要提供 image
89+
Reference image(s) as URL or Base64.
90+
* 生成“单图”的任务:传入 string(exactly 1 image)。
91+
* 生成“组图”的任务:传入 array(2–10 images)。
92+
- sequential_image_generation (str)
93+
控制是否生成“组图”。Default: "disabled".
94+
* 若要生成组图:必须设为 "auto"。
95+
- max_images (int)
96+
仅当生成组图时生效。控制模型能生成的最多张数,范围 [1, 15], 不设置默认为15。
97+
注意这个参数不等于生成的图片数量,而是模型最多能生成的图片数量。
98+
在单图组图场景最多 14;多图组图场景需满足 (len(images)+max_images ≤ 15)。
99+
Model 行为说明(如何由参数推断模式)
100+
---------------------------------
101+
1) 文生单图: 不提供 image 且 (S 未设置或 S="disabled") → 1 张图。
102+
2) 文生组图: 不提供 image 且 S="auto" → 组图,数量由 max_images 控制。
103+
3) 单图生单图: image=string 且 (S 未设置或 S="disabled") → 1 张图。
104+
4) 单图生组图: image=string 且 S="auto" → 组图,数量 ≤14。
105+
5) 多图生单图: image=array (2–10) 且 (S 未设置或 S="disabled") → 1 张图。
106+
6) 多图生组图: image=array (2–10) 且 S="auto" → 组图,需满足总数 ≤15。
107+
返回结果
108+
--------
109+
Dict with generation summary.
110+
Example:
111+
{
112+
"status": "success",
113+
"success_list": [
114+
{"image_name": "url"}
115+
],
116+
"error_list": ["image_name"]
117+
}
118+
Notes:
119+
- 组图任务必须 sequential_image_generation="auto"。
120+
- 如果想要指定生成组图的数量,请在prompt里添加数量说明,例如:"生成3张图片"。
121+
- size 推荐使用 2048x2048 或表格里的标准比例,确保生成质量。
122+
"""
123+
124+
success_list = []
125+
error_list = []
126+
127+
for idx, item in enumerate(tasks):
128+
input_part = {"role": "user"}
129+
output_part = {"message.role": "model"}
130+
total_tokens = 0
131+
output_tokens = 0
132+
tracer = trace.get_tracer("gcp.vertex.agent")
133+
with tracer.start_as_current_span("call_llm") as span:
134+
task_type = item.get("task_type", "text_to_single")
135+
prompt = item.get("prompt", "")
136+
response_format = item.get("response_format", None)
137+
size = item.get("size", None)
138+
watermark = item.get("watermark", None)
139+
image = item.get("image", None)
140+
sequential_image_generation = item.get("sequential_image_generation", None)
141+
max_images = item.get("max_images", None)
142+
143+
input_part["parts.0.type"] = "text"
144+
input_part["parts.0.text"] = json.dumps(item, ensure_ascii=False)
145+
inputs = {
146+
"prompt": prompt,
147+
}
148+
149+
if size:
150+
inputs["size"] = size
151+
if response_format:
152+
inputs["response_format"] = response_format
153+
if watermark:
154+
inputs["watermark"] = watermark
155+
if image:
156+
if task_type.startswith("single"):
157+
assert isinstance(image, str), (
158+
f"single_* task_type image must be str, got {type(image)}"
159+
)
160+
input_part["parts.1.type"] = "image_url"
161+
input_part["parts.1.image_url.name"] = "origin_image"
162+
input_part["parts.1.image_url.url"] = image
163+
elif task_type.startswith("multi"):
164+
assert isinstance(image, list), (
165+
f"multi_* task_type image must be list, got {type(image)}"
166+
)
167+
assert len(image) <= 10, (
168+
f"multi_* task_type image list length must be <= 10, got {len(image)}"
169+
)
170+
for i, image_url in enumerate(image):
171+
input_part[f"parts.{i + 1}.type"] = "image_url"
172+
input_part[f"parts.{i + 1}.image_url.name"] = (
173+
f"origin_image_{i}"
174+
)
175+
input_part[f"parts.{i + 1}.image_url.url"] = image_url
176+
177+
if sequential_image_generation:
178+
inputs["sequential_image_generation"] = sequential_image_generation
179+
180+
try:
181+
if (
182+
sequential_image_generation
183+
and sequential_image_generation == "auto"
184+
and max_images
185+
):
186+
response = client.images.generate(
187+
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
188+
**inputs,
189+
sequential_image_generation_options=SequentialImageGenerationOptions(
190+
max_images=max_images
191+
),
192+
)
193+
else:
194+
response = client.images.generate(
195+
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME, **inputs
196+
)
197+
if not response.error:
198+
for i, image_data in enumerate(response.data):
199+
image_name = f"task_{idx}_image_{i}"
200+
if "error" in image_data:
201+
error_details = (
202+
f"Image {image_name} error: {image_data.error}"
203+
)
204+
logger.error(error_details)
205+
error_list.append(image_name)
206+
continue
207+
if image_data.url:
208+
image = image_data.url
209+
tool_context.state[f"{image_name}_url"] = image
210+
211+
output_part[f"message.parts.{i}.type"] = "image_url"
212+
output_part[f"message.parts.{i}.image_url.name"] = (
213+
image_name
214+
)
215+
output_part[f"message.parts.{i}.image_url.url"] = image
216+
217+
else:
218+
image = image_data.b64_json
219+
image_bytes = base64.b64decode(image)
220+
221+
tos_url = _upload_image_to_tos(
222+
image_bytes=image_bytes, object_key=f"{image_name}.png"
223+
)
224+
if tos_url:
225+
tool_context.state[f"{image_name}_url"] = tos_url
226+
image = tos_url
227+
output_part[f"message.parts.{i}.type"] = "image_url"
228+
output_part[f"message.parts.{i}.image_url.name"] = (
229+
image_name
230+
)
231+
output_part[f"message.parts.{i}.image_url.url"] = image
232+
else:
233+
logger.error(
234+
f"Upload image to TOS failed: {image_name}"
235+
)
236+
error_list.append(image_name)
237+
continue
238+
239+
logger.debug(f"Image saved as ADK artifact: {image_name}")
240+
241+
total_tokens += response.usage.total_tokens
242+
output_tokens += response.usage.output_tokens
243+
success_list.append({image_name: image})
244+
else:
245+
error_details = (
246+
f"No images returned by Doubao model: {response.error}"
247+
)
248+
logger.error(error_details)
249+
error_list.append(f"task_{idx}")
250+
251+
except Exception as e:
252+
error_details = f"Error: {e}"
253+
logger.error(error_details)
254+
traceback.print_exc()
255+
error_list.append(f"task_{idx}")
256+
257+
add_span_attributes(
258+
span,
259+
tool_context,
260+
input_part=input_part,
261+
output_part=output_part,
262+
output_tokens=output_tokens,
263+
total_tokens=total_tokens,
264+
request_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
265+
response_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
266+
)
267+
if len(success_list) == 0:
268+
return {
269+
"status": "error",
270+
"success_list": success_list,
271+
"error_list": error_list,
272+
}
273+
else:
274+
return {
275+
"status": "success",
276+
"success_list": success_list,
277+
"error_list": error_list,
278+
}
279+
280+
281+
def add_span_attributes(
282+
span: Span,
283+
tool_context: ToolContext,
284+
input_part: dict = None,
285+
output_part: dict = None,
286+
input_tokens: int = None,
287+
output_tokens: int = None,
288+
total_tokens: int = None,
289+
request_model: str = None,
290+
response_model: str = None,
291+
):
292+
try:
293+
# common attributes
294+
app_name = tool_context._invocation_context.app_name
295+
user_id = tool_context._invocation_context.user_id
296+
agent_name = tool_context.agent_name
297+
session_id = tool_context._invocation_context.session.id
298+
span.set_attribute("gen_ai.agent.name", agent_name)
299+
span.set_attribute("openinference.instrumentation.veadk", VERSION)
300+
span.set_attribute("gen_ai.app.name", app_name)
301+
span.set_attribute("gen_ai.user.id", user_id)
302+
span.set_attribute("gen_ai.session.id", session_id)
303+
span.set_attribute("agent_name", agent_name)
304+
span.set_attribute("agent.name", agent_name)
305+
span.set_attribute("app_name", app_name)
306+
span.set_attribute("app.name", app_name)
307+
span.set_attribute("user.id", user_id)
308+
span.set_attribute("session.id", session_id)
309+
span.set_attribute("cozeloop.report.source", "veadk")
310+
311+
# llm attributes
312+
span.set_attribute("gen_ai.system", "openai")
313+
span.set_attribute("gen_ai.operation.name", "chat")
314+
if request_model:
315+
span.set_attribute("gen_ai.request.model", request_model)
316+
if response_model:
317+
span.set_attribute("gen_ai.response.model", response_model)
318+
if total_tokens:
319+
span.set_attribute("gen_ai.usage.total_tokens", total_tokens)
320+
if output_tokens:
321+
span.set_attribute("gen_ai.usage.output_tokens", output_tokens)
322+
if input_tokens:
323+
span.set_attribute("gen_ai.usage.input_tokens", input_tokens)
324+
if input_part:
325+
span.add_event("gen_ai.user.message", input_part)
326+
if output_part:
327+
span.add_event("gen_ai.choice", output_part)
328+
329+
except Exception:
330+
traceback.print_exc()
331+
332+
333+
def _upload_image_to_tos(image_bytes: bytes, object_key: str) -> None:
334+
try:
335+
from veadk.integrations.ve_tos.ve_tos import VeTOS
336+
import os
337+
from datetime import datetime
338+
339+
timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
340+
object_key = f"{timestamp}-{object_key}"
341+
bucket_name = os.getenv("DATABASE_TOS_BUCKET")
342+
ve_tos = VeTOS()
343+
344+
tos_url = ve_tos.build_tos_signed_url(
345+
object_key=object_key, bucket_name=bucket_name
346+
)
347+
348+
ve_tos.upload_bytes(
349+
data=image_bytes, object_key=object_key, bucket_name=bucket_name
350+
)
351+
352+
return tos_url
353+
except Exception as e:
354+
logger.error(f"Upload to TOS failed: {e}")
355+
return None

0 commit comments

Comments
 (0)