|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import base64 |
| 16 | +import json |
| 17 | +import mimetypes |
| 18 | +import traceback |
15 | 19 | from typing import Dict |
16 | 20 |
|
17 | 21 | from google.adk.tools import ToolContext |
18 | | -from veadk.config import getenv |
19 | | -from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE |
20 | | - |
21 | | -import base64 |
22 | | -from volcenginesdkarkruntime import Ark |
| 22 | +from google.genai.types import Blob, Part |
23 | 23 | from opentelemetry import trace |
24 | | -import traceback |
25 | | -from veadk.version import VERSION |
26 | 24 | from opentelemetry.trace import Span |
27 | | -from veadk.utils.logger import get_logger |
| 25 | +from volcenginesdkarkruntime import Ark |
28 | 26 | from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions |
29 | | -import json |
30 | 27 |
|
| 28 | +from veadk.config import getenv |
| 29 | +from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE |
| 30 | +from veadk.utils.logger import get_logger |
| 31 | +from veadk.utils.misc import formatted_timestamp, read_png_to_bytes |
| 32 | +from veadk.version import VERSION |
31 | 33 |
|
32 | 34 | logger = get_logger(__name__) |
33 | 35 |
|
@@ -121,7 +123,7 @@ async def image_generate( |
121 | 123 | - size 推荐使用 2048x2048 或表格里的标准比例,确保生成质量。 |
122 | 124 | """ |
123 | 125 |
|
124 | | - success_list = [] |
| 126 | + success_list: list[dict] = [] |
125 | 127 | error_list = [] |
126 | 128 |
|
127 | 129 | for idx, item in enumerate(tasks): |
@@ -280,6 +282,28 @@ async def image_generate( |
280 | 282 | "error_list": error_list, |
281 | 283 | } |
282 | 284 | else: |
| 285 | + app_name = tool_context._invocation_context.app_name |
| 286 | + user_id = tool_context._invocation_context.user_id |
| 287 | + session_id = tool_context._invocation_context.session.id |
| 288 | + |
| 289 | + artifact_service = tool_context._invocation_context.artifact_service |
| 290 | + if artifact_service: |
| 291 | + for image in success_list: |
| 292 | + for _, image_tos_url in image.items(): |
| 293 | + filename = f"artifact_{formatted_timestamp()}" |
| 294 | + await artifact_service.save_artifact( |
| 295 | + app_name=app_name, |
| 296 | + user_id=user_id, |
| 297 | + session_id=session_id, |
| 298 | + filename=filename, |
| 299 | + artifact=Part( |
| 300 | + inline_data=Blob( |
| 301 | + display_name=filename, |
| 302 | + data=read_png_to_bytes(image_tos_url), |
| 303 | + mime_type=mimetypes.guess_type(image_tos_url)[0], |
| 304 | + ) |
| 305 | + ), |
| 306 | + ) |
283 | 307 | return { |
284 | 308 | "status": "success", |
285 | 309 | "success_list": success_list, |
@@ -341,10 +365,11 @@ def add_span_attributes( |
341 | 365 |
|
342 | 366 | def _upload_image_to_tos(image_bytes: bytes, object_key: str) -> None: |
343 | 367 | try: |
344 | | - from veadk.integrations.ve_tos.ve_tos import VeTOS |
345 | 368 | import os |
346 | 369 | from datetime import datetime |
347 | 370 |
|
| 371 | + from veadk.integrations.ve_tos.ve_tos import VeTOS |
| 372 | + |
348 | 373 | timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] |
349 | 374 | object_key = f"{timestamp}-{object_key}" |
350 | 375 | bucket_name = os.getenv("DATABASE_TOS_BUCKET") |
|
0 commit comments