Skip to content

Commit 7fbdda4

Browse files
chore(tools): support artifact in image generation tools (#226)
1 parent 1ea6e5e commit 7fbdda4

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

veadk/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
import functools
16+
import os
1717
from types import MethodType
1818
from typing import Union
1919

veadk/tools/builtin_tools/generate_image.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import base64
16+
import json
17+
import mimetypes
18+
import traceback
1519
from typing import Dict
1620

1721
from google.adk.tools import ToolContext
18-
from veadk.config import getenv
19-
from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
20-
21-
import base64
22-
from volcenginesdkarkruntime import Ark
22+
from google.genai.types import Blob, Part
2323
from opentelemetry import trace
24-
import traceback
25-
from veadk.version import VERSION
2624
from opentelemetry.trace import Span
27-
from veadk.utils.logger import get_logger
25+
from volcenginesdkarkruntime import Ark
2826
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
29-
import json
3027

28+
from veadk.config import getenv
29+
from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
30+
from veadk.utils.logger import get_logger
31+
from veadk.utils.misc import formatted_timestamp, read_png_to_bytes
32+
from veadk.version import VERSION
3133

3234
logger = get_logger(__name__)
3335

@@ -121,7 +123,7 @@ async def image_generate(
121123
- size 推荐使用 2048x2048 或表格里的标准比例,确保生成质量。
122124
"""
123125

124-
success_list = []
126+
success_list: list[dict] = []
125127
error_list = []
126128

127129
for idx, item in enumerate(tasks):
@@ -280,6 +282,28 @@ async def image_generate(
280282
"error_list": error_list,
281283
}
282284
else:
285+
app_name = tool_context._invocation_context.app_name
286+
user_id = tool_context._invocation_context.user_id
287+
session_id = tool_context._invocation_context.session.id
288+
289+
artifact_service = tool_context._invocation_context.artifact_service
290+
if artifact_service:
291+
for image in success_list:
292+
for _, image_tos_url in image.items():
293+
filename = f"artifact_{formatted_timestamp()}"
294+
await artifact_service.save_artifact(
295+
app_name=app_name,
296+
user_id=user_id,
297+
session_id=session_id,
298+
filename=filename,
299+
artifact=Part(
300+
inline_data=Blob(
301+
display_name=filename,
302+
data=read_png_to_bytes(image_tos_url),
303+
mime_type=mimetypes.guess_type(image_tos_url)[0],
304+
)
305+
),
306+
)
283307
return {
284308
"status": "success",
285309
"success_list": success_list,
@@ -341,10 +365,11 @@ def add_span_attributes(
341365

342366
def _upload_image_to_tos(image_bytes: bytes, object_key: str) -> None:
343367
try:
344-
from veadk.integrations.ve_tos.ve_tos import VeTOS
345368
import os
346369
from datetime import datetime
347370

371+
from veadk.integrations.ve_tos.ve_tos import VeTOS
372+
348373
timestamp: str = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
349374
object_key = f"{timestamp}-{object_key}"
350375
bucket_name = os.getenv("DATABASE_TOS_BUCKET")

0 commit comments

Comments
 (0)