|
| 1 | +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Dict |
| 16 | + |
| 17 | +from google.adk.tools import ToolContext |
| 18 | +from veadk.config import getenv |
| 19 | +from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE |
| 20 | + |
| 21 | +import base64 |
| 22 | +from volcenginesdkarkruntime import Ark |
| 23 | +from opentelemetry import trace |
| 24 | +import traceback |
| 25 | +from veadk.version import VERSION |
| 26 | +from opentelemetry.trace import Span |
| 27 | +from veadk.utils.logger import get_logger |
| 28 | +from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions |
| 29 | +import json |
| 30 | + |
| 31 | + |
# Module-level logger shared by every function in this file.
logger = get_logger(__name__)

# Shared Ark runtime client used for all image-generation requests.
# It is created once at import time; the API key is read from the
# MODEL_AGENT_API_KEY setting at that moment.
client = Ark(
    base_url=DEFAULT_MODEL_AGENT_API_BASE,
    api_key=getenv("MODEL_AGENT_API_KEY"),
)
| 38 | + |
| 39 | + |
async def image_generate(
    tasks: list,
    tool_context: ToolContext,
) -> Dict:
    """Generate images in batch with Seedream 4.0.

    Every task is sent to the image model as its own request and traced in
    its own "call_llm" span. A task that fails (invalid arguments, API error,
    upload failure) is recorded in ``error_list`` and does not prevent the
    remaining tasks from running.

    Args:
        tasks (list[dict]):
            A list of image-generation tasks. Each task is a dict.

            Per-task schema
            ---------------
            Required:
                - task_type (str):
                    One of:
                      * "multi_image_to_group"    # several images -> image group
                      * "single_image_to_group"   # one image -> image group
                      * "text_to_group"           # text -> image group
                      * "multi_image_to_single"   # several images -> one image
                      * "single_image_to_single"  # one image -> one image
                      * "text_to_single"          # text -> one image (default)
                - prompt (str)
                    Text description of the desired image(s). Chinese and
                    English are both accepted. To ask for a specific number
                    of images, say so in the prompt, e.g. "generate N images".
            Optional:
                - size (str)
                    Size of the generated image. Two mutually exclusive forms:
                    Form 1: resolution level
                        One of "1K", "2K", "4K". The model infers a suitable
                        aspect ratio and width/height from the prompt.
                    Form 2: explicit width and height
                        Format "<width>x<height>", e.g. "2048x2048",
                        "2304x1728".
                        Constraints:
                            * total pixels within [1024x1024, 4096x4096]
                            * aspect ratio within [1/16, 16]
                        Recommended values:
                            - 1:1   -> 2048x2048
                            - 4:3   -> 2304x1728
                            - 3:4   -> 1728x2304
                            - 16:9  -> 2560x1440
                            - 9:16  -> 1440x2560
                            - 3:2   -> 2496x1664
                            - 2:3   -> 1664x2496
                            - 21:9  -> 3024x1296
                    Model default when omitted: "2048x2048".
                - response_format (str)
                    "url" (default; the URL expires after 24h) or "b64_json".
                    Base64 results are uploaded to TOS and returned as a URL.
                - watermark (bool)
                    Whether to add a watermark. Model default: true. An
                    explicit ``False`` is forwarded to the model.
                - image (str | list[str])
                    Reference image(s) as URL or Base64. Only for tasks that
                    are not text-to-image; do not provide it for text_* tasks
                    (it is ignored there).
                      * single_* task types: a string (exactly 1 image).
                      * multi_* task types: a list (2-10 images).
                - sequential_image_generation (str)
                    Controls image-group generation. Default: "disabled".
                      * Must be "auto" to generate an image group.
                - max_images (int)
                    Only effective for image groups. Upper bound on how many
                    images the model may generate, range [1, 15]; the model
                    default is 15. This is a maximum, not the exact count.
                    At most 14 for single-image-to-group; for
                    multi-image-to-group, len(images) + max_images <= 15.

            How the model infers the mode from the parameters
            (S = sequential_image_generation)
            -------------------------------------------------
            1) text -> single:   no image, S unset or "disabled" -> 1 image.
            2) text -> group:    no image, S="auto" -> group, bounded by
                                 max_images.
            3) single -> single: image=str, S unset or "disabled" -> 1 image.
            4) single -> group:  image=str, S="auto" -> group, at most 14.
            5) multi -> single:  image=list (2-10), S unset or "disabled"
                                 -> 1 image.
            6) multi -> group:   image=list (2-10), S="auto" -> group,
                                 total <= 15.
        tool_context (ToolContext):
            ADK tool context. The URL of every generated image is stored in
            ``tool_context.state`` under the key "<image_name>_url".

    Returns:
        Dict with a generation summary. ``status`` is "success" when at
        least one image was produced, otherwise "error". Example:
            {
                "status": "success",
                "success_list": [
                    {"image_name": "url"}
                ],
                "error_list": ["image_name"]
            }
        ``error_list`` holds "task_<idx>" for a task that failed as a whole
        and "task_<idx>_image_<i>" for a single failed image.

    Notes:
        - Image-group tasks require sequential_image_generation="auto".
        - To control how many images a group contains, state the number in
          the prompt, e.g. "generate 3 images".
        - Prefer 2048x2048 or one of the recommended sizes for best quality.
    """

    success_list = []
    error_list = []
    # The tracer does not depend on the task, so look it up once.
    tracer = trace.get_tracer("gcp.vertex.agent")

    for idx, item in enumerate(tasks):
        input_part = {"role": "user"}
        output_part = {"message.role": "model"}
        total_tokens = 0
        output_tokens = 0
        with tracer.start_as_current_span("call_llm") as span:
            # Everything task-specific runs inside the try so that a single
            # malformed task is reported in error_list instead of aborting
            # the whole batch.
            try:
                task_type = item.get("task_type", "text_to_single")
                prompt = item.get("prompt", "")
                response_format = item.get("response_format", None)
                size = item.get("size", None)
                watermark = item.get("watermark", None)
                reference_image = item.get("image", None)
                sequential_image_generation = item.get(
                    "sequential_image_generation", None
                )
                max_images = item.get("max_images", None)

                input_part["parts.0.type"] = "text"
                input_part["parts.0.text"] = json.dumps(item, ensure_ascii=False)

                inputs = {"prompt": prompt}
                if size:
                    inputs["size"] = size
                if response_format:
                    inputs["response_format"] = response_format
                # Compare against None so an explicit watermark=False is
                # forwarded instead of silently falling back to the default.
                if watermark is not None:
                    inputs["watermark"] = watermark

                if reference_image:
                    if task_type.startswith("single"):
                        # Explicit raises instead of assert: asserts are
                        # stripped when Python runs with -O.
                        if not isinstance(reference_image, str):
                            raise TypeError(
                                f"single_* task_type image must be str, got {type(reference_image)}"
                            )
                        input_part["parts.1.type"] = "image_url"
                        input_part["parts.1.image_url.name"] = "origin_image"
                        input_part["parts.1.image_url.url"] = reference_image
                        # Forward the reference image to the model, not only
                        # to the trace.
                        inputs["image"] = reference_image
                    elif task_type.startswith("multi"):
                        if not isinstance(reference_image, list):
                            raise TypeError(
                                f"multi_* task_type image must be list, got {type(reference_image)}"
                            )
                        if len(reference_image) > 10:
                            raise ValueError(
                                f"multi_* task_type image list length must be <= 10, got {len(reference_image)}"
                            )
                        for i, image_url in enumerate(reference_image):
                            input_part[f"parts.{i + 1}.type"] = "image_url"
                            input_part[f"parts.{i + 1}.image_url.name"] = (
                                f"origin_image_{i}"
                            )
                            input_part[f"parts.{i + 1}.image_url.url"] = image_url
                        inputs["image"] = reference_image
                    else:
                        # text_* tasks must not carry a reference image.
                        logger.warning(
                            f"Ignoring image for task_{idx} with task_type {task_type}"
                        )

                if sequential_image_generation:
                    inputs["sequential_image_generation"] = (
                        sequential_image_generation
                    )
                    # max_images only applies when group generation is on.
                    if sequential_image_generation == "auto" and max_images:
                        inputs["sequential_image_generation_options"] = (
                            SequentialImageGenerationOptions(max_images=max_images)
                        )

                # NOTE(review): this call is not awaited, so it appears to
                # block the event loop while the request is in flight --
                # confirm whether an async client should be used here.
                response = client.images.generate(
                    model=DEFAULT_IMAGE_GENERATE_MODEL_NAME, **inputs
                )

                if response.error:
                    error_details = (
                        f"No images returned by Doubao model: {response.error}"
                    )
                    logger.error(error_details)
                    error_list.append(f"task_{idx}")
                elif not response.data:
                    # No error reported but nothing generated either: make
                    # the task visible to the caller instead of dropping it.
                    logger.error("No images returned by Doubao model: empty data")
                    error_list.append(f"task_{idx}")
                else:
                    for i, image_data in enumerate(response.data):
                        image_name = f"task_{idx}_image_{i}"

                        # image_data is accessed by attribute below, so look
                        # the error up the same way rather than with `in`.
                        image_error = getattr(image_data, "error", None)
                        if image_error:
                            error_details = (
                                f"Image {image_name} error: {image_error}"
                            )
                            logger.error(error_details)
                            error_list.append(image_name)
                            continue

                        if image_data.url:
                            result_url = image_data.url
                        else:
                            # Base64 payload: upload it and hand back a URL.
                            image_bytes = base64.b64decode(image_data.b64_json)
                            result_url = _upload_image_to_tos(
                                image_bytes=image_bytes,
                                object_key=f"{image_name}.png",
                            )
                            if not result_url:
                                logger.error(
                                    f"Upload image to TOS failed: {image_name}"
                                )
                                error_list.append(image_name)
                                continue

                        tool_context.state[f"{image_name}_url"] = result_url
                        output_part[f"message.parts.{i}.type"] = "image_url"
                        output_part[f"message.parts.{i}.image_url.name"] = (
                            image_name
                        )
                        output_part[f"message.parts.{i}.image_url.url"] = result_url

                        logger.debug(f"Image saved as ADK artifact: {image_name}")
                        success_list.append({image_name: result_url})

                    # Usage is reported per response, so count it once per
                    # task rather than once per generated image.
                    usage = getattr(response, "usage", None)
                    if usage is not None:
                        total_tokens += getattr(usage, "total_tokens", 0) or 0
                        output_tokens += getattr(usage, "output_tokens", 0) or 0

            except Exception as e:
                error_details = f"Error: {e}"
                logger.error(error_details)
                traceback.print_exc()
                error_list.append(f"task_{idx}")

            # Record the trace attributes whether the task succeeded or not.
            add_span_attributes(
                span,
                tool_context,
                input_part=input_part,
                output_part=output_part,
                output_tokens=output_tokens,
                total_tokens=total_tokens,
                request_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
                response_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
            )

    return {
        "status": "success" if success_list else "error",
        "success_list": success_list,
        "error_list": error_list,
    }
| 279 | + |
| 280 | + |
def add_span_attributes(
    span: Span,
    tool_context: ToolContext,
    input_part: dict = None,
    output_part: dict = None,
    input_tokens: int = None,
    output_tokens: int = None,
    total_tokens: int = None,
    request_model: str = None,
    response_model: str = None,
):
    """Attach agent/session identity and LLM-call metadata to a span.

    Identity values (app name, user id, agent name, session id) are read
    from ``tool_context`` and written under several attribute-name aliases.
    Model names and token counts are written only when truthy, and the
    input/output parts are added as span events only when non-empty.

    Any exception raised while doing so is printed via ``traceback`` and
    swallowed, so a tracing failure does not propagate to the caller.

    Args:
        span: Span that receives the attributes and events.
        tool_context: Tool context the identity values are read from.
        input_part: Payload for the "gen_ai.user.message" event.
        output_part: Payload for the "gen_ai.choice" event.
        input_tokens: Value for "gen_ai.usage.input_tokens".
        output_tokens: Value for "gen_ai.usage.output_tokens".
        total_tokens: Value for "gen_ai.usage.total_tokens".
        request_model: Value for "gen_ai.request.model".
        response_model: Value for "gen_ai.response.model".
    """
    try:
        # Resolve the identity values first; these lookups are the part
        # most likely to fail.
        invocation = tool_context._invocation_context
        app_name = invocation.app_name
        user_id = invocation.user_id
        agent_name = tool_context.agent_name
        session_id = invocation.session.id

        # Attributes that are always written, in a fixed order.
        unconditional = (
            # common attributes
            ("gen_ai.agent.name", agent_name),
            ("openinference.instrumentation.veadk", VERSION),
            ("gen_ai.app.name", app_name),
            ("gen_ai.user.id", user_id),
            ("gen_ai.session.id", session_id),
            ("agent_name", agent_name),
            ("agent.name", agent_name),
            ("app_name", app_name),
            ("app.name", app_name),
            ("user.id", user_id),
            ("session.id", session_id),
            ("cozeloop.report.source", "veadk"),
            # llm attributes
            ("gen_ai.system", "openai"),
            ("gen_ai.operation.name", "chat"),
        )
        for key, value in unconditional:
            span.set_attribute(key, value)

        # Attributes that are written only when a truthy value was given.
        conditional = (
            ("gen_ai.request.model", request_model),
            ("gen_ai.response.model", response_model),
            ("gen_ai.usage.total_tokens", total_tokens),
            ("gen_ai.usage.output_tokens", output_tokens),
            ("gen_ai.usage.input_tokens", input_tokens),
        )
        for key, value in conditional:
            if value:
                span.set_attribute(key, value)

        # Message payloads are attached as events, not attributes.
        events = (
            ("gen_ai.user.message", input_part),
            ("gen_ai.choice", output_part),
        )
        for event_name, payload in events:
            if payload:
                span.add_event(event_name, payload)

    except Exception:
        traceback.print_exc()
| 331 | + |
| 332 | + |
def _upload_image_to_tos(image_bytes: bytes, object_key: str) -> "str | None":
    """Upload image bytes to TOS and return a signed URL for the object.

    The object key is prefixed with a millisecond-resolution timestamp
    before uploading. The bucket name is read from the
    ``DATABASE_TOS_BUCKET`` environment variable.

    Args:
        image_bytes: Raw image content to upload.
        object_key: Base object key; the stored key is
            "<timestamp>-<object_key>".

    Returns:
        The signed URL built for the uploaded object, or ``None`` when any
        step fails (the failure is logged, not raised).
    """
    try:
        # Imported lazily so the TOS integration is only loaded when a
        # Base64 result actually has to be uploaded.
        import os
        from datetime import datetime

        from veadk.integrations.ve_tos.ve_tos import VeTOS

        # %f yields microseconds (6 digits); dropping the last three
        # leaves milliseconds.
        timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
        object_key = f"{timestamp}-{object_key}"
        # NOTE(review): bucket_name is None when the variable is unset --
        # confirm how VeTOS treats a None bucket.
        bucket_name = os.getenv("DATABASE_TOS_BUCKET")
        ve_tos = VeTOS()

        # The URL is built before the upload, but it is only returned if
        # the upload below does not raise.
        tos_url = ve_tos.build_tos_signed_url(
            object_key=object_key, bucket_name=bucket_name
        )

        ve_tos.upload_bytes(
            data=image_bytes, object_key=object_key, bucket_name=bucket_name
        )

        return tos_url
    except Exception as e:
        logger.error(f"Upload to TOS failed: {e}")
        return None
0 commit comments