Skip to content

Commit 9a270d1

Browse files
author
wangjiaju
committed
Add seedream-4-0
1 parent 69ddc2d commit 9a270d1

File tree

4 files changed

+316
-4
lines changed

4 files changed

+316
-4
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict
16+
17+
from google.genai import types
18+
from google.adk.tools import ToolContext
19+
from veadk.config import getenv
20+
import base64
21+
from volcenginesdkarkruntime import Ark
22+
from opentelemetry import trace
23+
import traceback
24+
from veadk.version import VERSION
25+
from opentelemetry.trace import Span
26+
from veadk.utils.logger import get_logger
27+
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
28+
import json
29+
30+
31+
# Module-level logger for this tool module.
logger = get_logger(__name__)

# Shared Ark client for the Volcano Engine image-generation API.
# Credentials and endpoint are read once at import time from environment
# configuration (MODEL_IMAGE_API_KEY / MODEL_IMAGE_API_BASE).
client = Ark(
    api_key=getenv("MODEL_IMAGE_API_KEY"),
    base_url=getenv("MODEL_IMAGE_API_BASE"),
)
37+
38+
39+
async def image_generate(
    tasks: list,
    tool_context: ToolContext,
) -> Dict:
    """
    Seedream 4.0: batch image generation via tasks.
    Args:
        tasks (list[dict]):
            A list of image-generation tasks. Each task is a dict.
            Per-task schema
            ---------------
            Required:
              - task_type (str):
                  One of:
                    * "multi_image_to_group"   # multiple reference images -> image group
                    * "single_image_to_group"  # single reference image -> image group
                    * "text_to_group"          # text -> image group
                    * "multi_image_to_single"  # multiple reference images -> single image
                    * "single_image_to_single" # single reference image -> single image
                    * "text_to_single"         # text -> single image
              - prompt (str)
                  Text description of the desired image(s); Chinese or English.
                  To request a specific number of images, state it in the prompt,
                  e.g. "generate N images" where N is a concrete number.
            Optional:
              - size (str)
                  Output image size; two mutually exclusive forms:
                  Form 1: resolution level
                      Allowed values: "1K", "2K", "4K"
                      The model infers a suitable aspect ratio from the prompt.
                  Form 2: explicit dimensions
                      Format: "<width>x<height>", e.g. "2048x2048", "2384x1728"
                      Constraints:
                        * total pixels in [1024x1024, 4096x4096]
                        * aspect ratio in [1/16, 16]
                      Recommended values:
                        - 1:1  -> 2048x2048
                        - 4:3  -> 2384x1728
                        - 3:4  -> 1728x2304
                        - 16:9 -> 2560x1440
                        - 9:16 -> 1440x2560
                        - 3:2  -> 2496x1664
                        - 2:3  -> 1664x2496
                        - 21:9 -> 3024x1296
                  Default: "2048x2048"
              - response_format (str)
                  Return format: "url" (default, URL expires in 24h) | "b64_json".
              - watermark (bool)
                  Add watermark. Default: true.
              - image (str | list[str])  # only for image-to-image tasks; omit for text-to-image
                  Reference image(s) as URL or Base64.
                    * single-image tasks: pass a string (exactly 1 image).
                    * group tasks: pass a list (2-10 images).
              - sequential_image_generation (str)
                  Controls whether an image *group* is generated. Default: "disabled".
                    * Must be set to "auto" to generate a group.
              - max_images (int)
                  Only effective for group generation. Upper bound on the number
                  of generated images, range [1, 15]; defaults to 15 if unset.
                  Note this is a cap, not the exact count.
                  Single-image-to-group allows at most 14; multi-image-to-group
                  requires (len(images) + max_images <= 15).
        Model behavior (how the mode is inferred from parameters)
        ---------------------------------------------------------
        1) text -> single:  no image and (S unset or S="disabled") -> 1 image.
        2) text -> group:   no image and S="auto" -> group, count capped by max_images.
        3) single -> single: image=string and (S unset or S="disabled") -> 1 image.
        4) single -> group:  image=string and S="auto" -> group, count <= 14.
        5) multi -> single:  image=list (2-10) and (S unset or S="disabled") -> 1 image.
        6) multi -> group:   image=list (2-10) and S="auto" -> group, total <= 15.
        Returns
        -------
        Dict with generation summary.
        Example:
            {
                "status": "success",
                "success_list": [
                    {"image_name": "url"}
                ],
                "error_list": ["image_name"]
            }
        Notes:
          - Group tasks require sequential_image_generation="auto".
          - To request an exact group size, state the count in the prompt,
            e.g. "generate 3 images".
          - size is recommended to be 2048x2048 or one of the standard ratios
            in the table above for best quality.
    """

    success_list = []
    error_list = []

    tracer = trace.get_tracer("gcp.vertex.agent")
    with tracer.start_as_current_span("call_llm") as span:
        input_part = {"role": "user"}
        output_part = {"message.role": "model"}
        total_tokens = 0
        output_tokens = 0
        for idx, item in enumerate(tasks):
            task_type = item.get("task_type", "text_to_single")
            prompt = item.get("prompt", "")
            response_format = item.get("response_format")
            size = item.get("size")
            watermark = item.get("watermark")
            image = item.get("image")
            sequential_image_generation = item.get("sequential_image_generation")
            max_images = item.get("max_images")

            input_part[f"parts.{idx}.type"] = "text"
            input_part[f"parts.{idx}.text"] = json.dumps(item, ensure_ascii=False)

            inputs = {
                "prompt": prompt,
            }
            if size:
                inputs["size"] = size
            if response_format:
                inputs["response_format"] = response_format
            # BUG FIX: explicit None check so watermark=False is still forwarded;
            # the previous truthiness test silently dropped it (API default: True).
            if watermark is not None:
                inputs["watermark"] = watermark
            if sequential_image_generation:
                inputs["sequential_image_generation"] = sequential_image_generation

            try:
                # Validation now raises inside the per-task try block so a single
                # malformed task is recorded in error_list instead of aborting the
                # whole batch (previously: bare asserts outside the try, which are
                # also stripped under python -O).
                if image:
                    if task_type.startswith("single"):
                        if not isinstance(image, str):
                            raise ValueError(
                                f"single_* task_type image must be str, got {type(image)}"
                            )
                    elif task_type.startswith("multi"):
                        if not isinstance(image, list):
                            raise ValueError(
                                f"multi_* task_type image must be list, got {type(image)}"
                            )
                        if len(image) > 10:
                            raise ValueError(
                                f"multi_* task_type image list length must be <= 10, got {len(image)}"
                            )
                    # BUG FIX: forward the reference image(s) to the API; previously
                    # they were validated but never added to the request, so every
                    # image-to-image task silently degraded to text-to-image.
                    inputs["image"] = image

                if sequential_image_generation == "auto" and max_images:
                    logger.debug(f"generate multi image, max_images: {max_images}")
                    response = client.images.generate(
                        model=getenv("MODEL_IMAGE_NAME"),
                        **inputs,
                        sequential_image_generation_options=SequentialImageGenerationOptions(
                            max_images=max_images
                        ),
                    )
                else:
                    logger.debug("generate single image")
                    response = client.images.generate(
                        model=getenv("MODEL_IMAGE_NAME"), **inputs
                    )

                if response.error:
                    error_details = (
                        f"No images returned by Doubao model: {response.error}"
                    )
                    logger.error(error_details)
                    error_list.append(f"task_{idx}")
                    continue

                # BUG FIX: usage is reported once per API call, so accumulate it
                # per response (not per returned image, which double-counted
                # tokens for group generation).
                total_tokens += response.usage.total_tokens
                output_tokens += response.usage.output_tokens

                for i, image_data in enumerate(response.data):
                    image_name = f"task_{idx}_image_{i}"
                    # BUG FIX: image_data is an SDK object, not a dict, so the
                    # old `"error" in image_data` membership test was wrong;
                    # check the attribute instead.
                    if getattr(image_data, "error", None):
                        error_details = (
                            f"Image {image_name} error: {image_data.error}"
                        )
                        logger.error(error_details)
                        error_list.append(image_name)
                        continue
                    if image_data.url:
                        # URL response (default format): store the URL in state
                        # and echo it into the trace output event.
                        result = image_data.url
                        tool_context.state[f"{image_name}_url"] = result

                        output_part[f"message.parts.{i}.type"] = "text"
                        output_part[f"message.parts.{i}.text"] = (
                            f"{image_name}: {result}"
                        )
                    else:
                        # b64_json response: store a data URI in state and save
                        # the decoded bytes as an ADK artifact.
                        result = image_data.b64_json
                        image_bytes = base64.b64decode(result)

                        tool_context.state[f"{image_name}_url"] = (
                            f"data:image/png;base64,{result}"
                        )

                        report_artifact = types.Part.from_bytes(
                            data=image_bytes, mime_type="image/png"
                        )
                        await tool_context.save_artifact(
                            image_name, report_artifact
                        )
                        logger.debug(f"Image saved as ADK artifact: {image_name}")

                    success_list.append({image_name: result})

            except Exception as e:
                error_details = f"Error: {e}"
                logger.error(error_details)
                traceback.print_exc()
                error_list.append(f"task_{idx}")

        add_span_attributes(
            span,
            tool_context,
            input_part=input_part,
            output_part=output_part,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            request_model=getenv("MODEL_IMAGE_NAME"),
            response_model=getenv("MODEL_IMAGE_NAME"),
        )

    return {
        "status": "error" if not success_list else "success",
        "success_list": success_list,
        "error_list": error_list,
    }
262+
263+
264+
def add_span_attributes(
    span: Span,
    tool_context: ToolContext,
    input_part: dict = None,
    output_part: dict = None,
    input_tokens: int = None,
    output_tokens: int = None,
    total_tokens: int = None,
    request_model: str = None,
    response_model: str = None,
):
    """Attach VeADK telemetry attributes and events to *span*.

    Best-effort by design: any failure is printed and swallowed so that a
    telemetry problem can never break the calling tool.
    """
    try:
        # Identity fields shared by every attribute below.
        invocation = tool_context._invocation_context
        app_name = invocation.app_name
        user_id = invocation.user_id
        agent_name = tool_context.agent_name
        session_id = invocation.session.id

        # Unconditional attributes, emitted in a fixed order (several keys are
        # duplicated under legacy and gen_ai.* names for backend compatibility).
        always = (
            ("gen_ai.agent.name", agent_name),
            ("openinference.instrumentation.veadk", VERSION),
            ("gen_ai.app.name", app_name),
            ("gen_ai.user.id", user_id),
            ("gen_ai.session.id", session_id),
            ("agent_name", agent_name),
            ("agent.name", agent_name),
            ("app_name", app_name),
            ("app.name", app_name),
            ("user.id", user_id),
            ("session.id", session_id),
            ("cozeloop.report.source", "veadk"),
            # llm attributes
            ("gen_ai.system", "openai"),
            ("gen_ai.operation.name", "chat"),
        )
        for key, value in always:
            span.set_attribute(key, value)

        # Optional attributes: only set when a truthy value was supplied.
        optional = (
            ("gen_ai.request.model", request_model),
            ("gen_ai.response.model", response_model),
            ("gen_ai.usage.total_tokens", total_tokens),
            ("gen_ai.usage.output_tokens", output_tokens),
            ("gen_ai.usage.input_tokens", input_tokens),
        )
        for key, value in optional:
            if value:
                span.set_attribute(key, value)

        # Message payloads are recorded as span events rather than attributes.
        if input_part:
            span.add_event("gen_ai.user.message", input_part)
        if output_part:
            span.add_event("gen_ai.choice", output_part)

    except Exception:
        traceback.print_exc()

veadk/tools/builtin_tools/image_edit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def image_edit(
136136
image_bytes = base64.b64decode(image)
137137

138138
tool_context.state[f"{image_name}_url"] = (
139-
f"data:image/jpeg;base64,{image}"
139+
f"data:image/png;base64,{image}"
140140
)
141141

142142
report_artifact = types.Part.from_bytes(

veadk/tools/builtin_tools/image_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def image_generate(
137137
image_bytes = base64.b64decode(image)
138138

139139
tool_context.state[f"{image_name}_url"] = (
140-
f"data:image/jpeg;base64,{image}"
140+
f"data:image/png;base64,{image}"
141141
)
142142

143143
report_artifact = types.Part.from_bytes(

veadk/tools/builtin_tools/video_generate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
202202
for idx, item in enumerate(params):
203203
input_part[f"parts.{idx}.type"] = "text"
204204
input_part[f"parts.{idx}.text"] = json.dumps(item, ensure_ascii=False)
205-
205+
total_tokens = 0
206206
for start_idx in range(0, len(params), batch_size):
207207
batch = params[start_idx : start_idx + batch_size]
208208
task_dict = {}
@@ -223,7 +223,6 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
223223
logger.error(f"Error: {e}")
224224
error_list.append(video_name)
225225

226-
total_tokens = 0
227226
while True:
228227
task_list = list(task_dict.keys())
229228
if len(task_list) == 0:

0 commit comments

Comments
 (0)