Skip to content

Commit 0108de3

Browse files
author
wangjiaju.716
committed
Add buildin tools:image_generate,image_edit,video_generate
1 parent 7cef667 commit 0108de3

File tree

3 files changed

+290
-0
lines changed

3 files changed

+290
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Dict
2+
from google.adk.tools import ToolContext
3+
from google.genai import types
4+
from volcenginesdkarkruntime import Ark
5+
from veadk.config import getenv
6+
import base64
7+
8+
client = Ark(
9+
api_key=getenv("MODEL_IMAGE_API_KEY"),
10+
base_url=getenv("MODEL_IMAGE_API_BASE"),
11+
)
12+
13+
async def image_edit(
14+
origin_image: str,
15+
image_name: str,
16+
image_prompt: str,
17+
response_format: str,
18+
guidance_scale: float,
19+
watermark: bool,
20+
seed: int,
21+
tool_context: ToolContext) -> Dict:
22+
"""Edit an image accoding to the prompt.
23+
24+
Args:
25+
origin_image: The url or the base64 string of the edited image.
26+
image_name: The name of the generated image.
27+
image_prompt: The prompt that describes the image.
28+
response_format: str, b64_json or url, default url.
29+
guidance_scale: default 2.5.
30+
watermark: default True.
31+
seed: default -1.
32+
33+
"""
34+
try:
35+
response = client.images.generate(
36+
model=getenv("MODEL_EDIT_NAME"),
37+
image=origin_image,
38+
prompt=image_prompt,
39+
response_format=response_format,
40+
guidance_scale=guidance_scale,
41+
watermark=watermark,
42+
seed=seed
43+
)
44+
45+
if response.data and len(response.data) > 0:
46+
for item in response.data:
47+
if response_format == "url":
48+
image = item.url
49+
tool_context.state["generated_image_url"] = image
50+
51+
elif response_format == "b64_json":
52+
image = item.b64_json
53+
image_bytes = base64.b64decode(image)
54+
55+
tool_context.state["generated_image_url"] = (
56+
f"data:image/jpeg;base64,{image}"
57+
)
58+
59+
report_artifact = types.Part.from_bytes(
60+
data=image_bytes, mime_type="image/png"
61+
)
62+
await tool_context.save_artifact(image_name, report_artifact)
63+
logger.debug(f"Image saved as ADK artifact: {image_name}")
64+
65+
return {
66+
"status": "success",
67+
"image_name": image_name,
68+
"image": image
69+
}
70+
else:
71+
error_details = f"No images returned by Doubao model: {response}"
72+
logger.error(error_details)
73+
return {"status": "error", "message": error_details}
74+
75+
except Exception as e:
76+
return {
77+
"status": "error",
78+
"message": f"Doubao image generation failed: {str(e)}",
79+
}
80+
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Dict
2+
3+
from google.genai import types
4+
from google.adk.tools import ToolContext
5+
from veadk.config import getenv
6+
import base64
7+
from volcenginesdkarkruntime import Ark
8+
9+
from veadk.utils.logger import get_logger
10+
11+
logger = get_logger(__name__)
12+
13+
client = Ark(
14+
api_key=getenv("MODEL_IMAGE_API_KEY"),
15+
base_url=getenv("MODEL_IMAGE_API_BASE"),
16+
)
17+
18+
async def image_generate(
19+
image_name: str,
20+
image_prompt: str,
21+
response_format: str,
22+
size: str,
23+
guidance_scale: float,
24+
watermark: bool,
25+
seed: int,
26+
tool_context: ToolContext) -> Dict:
27+
"""Generate an image accoding to the prompt.
28+
29+
Args:
30+
image_name: The name of the generated image.
31+
image_prompt: The prompt that describes the image.
32+
response_format: str, b64_json or url, default url.
33+
size: default 1024x1024.
34+
guidance_scale: default 2.5.
35+
watermark: default True.
36+
seed: default -1.
37+
38+
"""
39+
try:
40+
response = client.images.generate(
41+
model=getenv("MODEL_IMAGE_NAME"),
42+
prompt=image_prompt,
43+
response_format=response_format,
44+
size=size,
45+
guidance_scale=guidance_scale,
46+
watermark=watermark,
47+
seed=seed
48+
)
49+
50+
if response.data and len(response.data) > 0:
51+
for item in response.data:
52+
if response_format == "url":
53+
image = item.url
54+
tool_context.state["generated_image_url"] = image
55+
56+
elif response_format == "b64_json":
57+
image = item.b64_json
58+
image_bytes = base64.b64decode(image)
59+
60+
tool_context.state["generated_image_url"] = (
61+
f"data:image/jpeg;base64,{image}"
62+
)
63+
64+
report_artifact = types.Part.from_bytes(
65+
data=image_bytes, mime_type="image/png"
66+
)
67+
await tool_context.save_artifact(image_name, report_artifact)
68+
logger.debug(f"Image saved as ADK artifact: {image_name}")
69+
70+
return {
71+
"status": "success",
72+
"image_name": image_name,
73+
"image": image
74+
}
75+
else:
76+
error_details = f"No images returned by Doubao model: {response}"
77+
logger.error(error_details)
78+
return {"status": "error", "message": error_details}
79+
80+
except Exception as e:
81+
return {
82+
"status": "error",
83+
"message": f"Doubao image generation failed: {str(e)}",
84+
}
85+
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import Dict
2+
from google.adk.tools import ToolContext
3+
from volcenginesdkarkruntime import Ark
4+
from veadk.config import getenv
5+
import time
6+
import traceback
7+
import base64
8+
9+
from veadk.utils.logger import get_logger
10+
11+
logger = get_logger(__name__)
12+
13+
client = Ark(
14+
api_key=getenv("MODEL_VIDEO_API_KEY"),
15+
base_url=getenv("MODEL_VIDEO_API_BASE"),
16+
)
17+
18+
async def generate(tool_context, prompt, first_frame_image=None, last_frame_image=None):
19+
try:
20+
if first_frame_image is None:
21+
logger.debug("text generation")
22+
response = client.content_generation.tasks.create(
23+
model=getenv("MODEL_VIDEO_NAME"),
24+
content=[
25+
{"type": "text", "text": prompt},
26+
],
27+
)
28+
elif last_frame_image is None:
29+
logger.debug("first frame generation")
30+
response = client.content_generation.tasks.create(
31+
model=getenv("MODEL_VIDEO_NAME"),
32+
content=[
33+
{"type": "text", "text": prompt},
34+
{
35+
"type": "image_url",
36+
"image_url": {"url": first_frame_image},
37+
},
38+
],
39+
)
40+
else:
41+
logger.debug("last frame generation")
42+
response = client.content_generation.tasks.create(
43+
model=getenv("MODEL_VIDEO_NAME"),
44+
content=[
45+
{"type": "text", "text": prompt},
46+
{
47+
"type": "image_url",
48+
"image_url": {"url": first_frame_image},
49+
"role": "first_frame",
50+
},
51+
{
52+
"type": "image_url",
53+
"image_url": {"url": last_frame_image},
54+
"role": "last_frame",
55+
},
56+
],
57+
)
58+
except:
59+
traceback.print_exc()
60+
raise
61+
return response
62+
63+
async def video_generate(
64+
params: list,
65+
tool_context: ToolContext) -> Dict:
66+
"""Generate video in batch according to the prompt.
67+
68+
Args:
69+
params:
70+
video_name: The name of the generated video.
71+
first_frame: The first frame of the video, url or base64 string, or None.
72+
last_frame:The last frame of the video, url or base64 string, or None.
73+
prompt:The prompt of the video.
74+
"""
75+
batch_size = 10
76+
success_list = []
77+
error_list = []
78+
for start_idx in range(0, len(params), batch_size):
79+
batch = params[start_idx : start_idx + batch_size]
80+
task_dict = {}
81+
for item in batch:
82+
video_name = item["video_name"]
83+
first_frame = item["first_frame"]
84+
last_frame = item["last_frame"]
85+
prompt = item["prompt"]
86+
try:
87+
if not first_frame:
88+
response = await generate(tool_context, prompt)
89+
elif not last_frame:
90+
response = await generate(tool_context, prompt, first_frame)
91+
else:
92+
response = await generate(
93+
tool_context, prompt, first_frame, last_frame
94+
)
95+
task_dict[response.id] = video_name
96+
except Exception:
97+
traceback.print_exc()
98+
while True:
99+
task_list = list(task_dict.keys())
100+
if len(task_list) == 0:
101+
break
102+
for task_id in task_list:
103+
result = client.content_generation.tasks.get(task_id=task_id)
104+
status = result.status
105+
if status == "succeeded":
106+
logger.debug("----- task succeeded -----")
107+
tool_context.state[f"{task_dict[task_id]}_video_url"] = result.content.video_url
108+
success_list.append({task_dict[task_id]: result.content.video_url})
109+
task_dict.pop(task_id, None)
110+
elif status == "failed":
111+
logger.debug("----- task failed -----")
112+
logger.debug(f"Error: {result.error}")
113+
error_list.append(task_dict[task_id])
114+
task_dict.pop(task_id, None)
115+
else:
116+
logger.debug(f"Current status: {status}, Retrying after 10 seconds...")
117+
time.sleep(10)
118+
119+
if len(success_list) == 0:
120+
return {"status": "error", "message": f"Following videos failed: {error_list}"}
121+
else:
122+
return {
123+
"status": "success",
124+
"message": f"Following videos generated: {success_list}\nFollowing videos failed: {error_list}",
125+
}

0 commit comments

Comments
 (0)