Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions veadk/tools/builtin_tools/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _build_input_parts(item: dict, task_type: str, image_field):
def handle_single_task_sync(
idx: int, item: dict, tool_context
) -> tuple[list[dict], list[str]]:
logger.debug(f"handle_single_task_sync item {idx}: {item}")
success_list: list[dict] = []
error_list: list[str] = []
total_tokens = 0
Expand Down Expand Up @@ -131,6 +132,8 @@ def handle_single_task_sync(
)

if not response.error:
logger.debug(f"task {idx} Image generate response: {response}")

total_tokens += getattr(response.usage, "total_tokens", 0) or 0
output_tokens += getattr(response.usage, "output_tokens", 0) or 0

Expand Down Expand Up @@ -165,10 +168,14 @@ def handle_single_task_sync(
output_part[f"message.parts.{i}.type"] = "image_url"
output_part[f"message.parts.{i}.image_url.name"] = image_name
output_part[f"message.parts.{i}.image_url.url"] = image_url

logger.debug(
f"Image {image_name} generated successfully: {image_url}"
)
success_list.append({image_name: image_url})
else:
logger.error(f"No images returned by model: {response.error}")
logger.error(
f"Task {idx} No images returned by model: {response.error}"
)
error_list.append(f"task_{idx}")

except Exception as e:
Expand All @@ -191,7 +198,9 @@ def handle_single_task_sync(
"MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME
),
)

logger.debug(
f"task {idx} Image generate success_list: {success_list}\nerror_list: {error_list}"
)
return success_list, error_list


Expand Down Expand Up @@ -275,9 +284,12 @@ async def image_generate(tasks: list[dict], tool_context) -> Dict:
- 如果想要指定生成组图的数量,请在prompt里添加数量说明,例如:"生成3张图片"。
- size 推荐使用 2048x2048 或表格里的标准比例,确保生成质量。
"""
logger.debug(
f"Using model: {getenv('MODEL_IMAGE_NAME', DEFAULT_IMAGE_GENERATE_MODEL_NAME)}"
)
success_list: list[dict] = []
error_list: list[str] = []

logger.debug(f"image_generate tasks: {tasks}")
with tracer.start_as_current_span("image_generate"):
base_ctx = contextvars.copy_context()

Expand All @@ -303,12 +315,14 @@ def make_task(idx, item):
error_list.extend(e)

if not success_list:
logger.debug(
f"image_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "error",
"success_list": success_list,
"error_list": error_list,
}

app_name = tool_context._invocation_context.app_name
user_id = tool_context._invocation_context.user_id
session_id = tool_context._invocation_context.session.id
Expand All @@ -332,6 +346,9 @@ def make_task(idx, item):
),
)

logger.debug(
f"image_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {"status": "success", "success_list": success_list, "error_list": error_list}


Expand Down
16 changes: 15 additions & 1 deletion veadk/tools/builtin_tools/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,14 @@ async def image_edit(
- Provide the same `seed` for consistent outputs across runs.
- A high `guidance_scale` enforces stricter adherence to text prompt.
"""
logger.debug(
f"Using model: {getenv('MODEL_EDIT_NAME', DEFAULT_IMAGE_EDIT_MODEL_NAME)}"
)
success_list = []
error_list = []
logger.debug(f"image_edit params: {params}")
for idx, item in enumerate(params):
logger.debug(f"image_edit item {idx}: {item}")
image_name = item.get("image_name", f"generated_image_{idx}")
prompt = item.get("prompt")
origin_image = item.get("origin_image")
Expand Down Expand Up @@ -133,6 +138,7 @@ async def image_edit(
)
output_part = None
if response.data and len(response.data) > 0:
logger.debug(f"task {idx} Image edit response: {response}")
for item in response.data:
if response_format == "url":
image = item.url
Expand Down Expand Up @@ -167,7 +173,9 @@ async def image_edit(
continue

logger.debug(f"Image saved as ADK artifact: {image_name}")

logger.debug(
f"Image {image_name} generated successfully: {image}"
)
success_list.append({image_name: image})
else:
error_details = f"No images returned by Doubao model: {response}"
Expand Down Expand Up @@ -196,12 +204,18 @@ async def image_edit(
error_list.append(image_name)

if len(success_list) == 0:
logger.debug(
f"image_edit success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "error",
"success_list": success_list,
"error_list": error_list,
}
else:
logger.debug(
f"image_edit success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "success",
"success_list": success_list,
Expand Down
16 changes: 15 additions & 1 deletion veadk/tools/builtin_tools/image_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ async def image_generate(
- Use a fixed `seed` for reproducibility.
- Choose appropriate `size` for desired aspect ratio.
"""
logger.debug(
f"Using model: {getenv('MODEL_IMAGE_NAME', DEFAULT_TEXT_TO_IMAGE_MODEL_NAME)}"
)
success_list = []
error_list = []
logger.debug(f"image_generate params: {params}")
for idx, item in enumerate(params):
logger.debug(f"image_generate item {idx}: {item}")
prompt = item.get("prompt", "")
image_name = item.get("image_name", f"generated_image_{idx}")
response_format = item.get("response_format", "url")
Expand Down Expand Up @@ -130,6 +135,7 @@ async def image_generate(
)
output_part = None
if response.data and len(response.data) > 0:
logger.debug(f"task {idx} Image generate response: {response}")
for item in response.data:
if response_format == "url":
image = item.url
Expand Down Expand Up @@ -164,7 +170,9 @@ async def image_generate(
continue

logger.debug(f"Image saved as ADK artifact: {image_name}")

logger.debug(
f"Image {image_name} generated successfully: {image}"
)
success_list.append({image_name: image})
else:
error_details = f"No images returned by Doubao model: {response}"
Expand Down Expand Up @@ -192,12 +200,18 @@ async def image_generate(
error_list.append(image_name)

if len(success_list) == 0:
logger.debug(
f"image_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "error",
"success_list": success_list,
"error_list": error_list,
}
else:
logger.debug(
f"image_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "success",
"success_list": success_list,
Expand Down
39 changes: 32 additions & 7 deletions veadk/tools/builtin_tools/video_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@
async def generate(prompt, first_frame_image=None, last_frame_image=None):
try:
if first_frame_image is None:
logger.debug("text generation")
response = client.content_generation.tasks.create(
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
content=[
{"type": "text", "text": prompt},
],
)
elif last_frame_image is None:
logger.debug("first frame generation")
response = client.content_generation.tasks.create(
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
content=cast(
Expand All @@ -66,7 +64,6 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
),
)
else:
logger.debug("last frame generation")
response = client.content_generation.tasks.create(
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
content=[
Expand Down Expand Up @@ -197,9 +194,13 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
batch_size = 10
success_list = []
error_list = []
logger.debug(f"Using model: {getenv('MODEL_VIDEO_NAME', DEFAULT_VIDEO_MODEL_NAME)}")
logger.debug(f"video_generate params: {params}")

for start_idx in range(0, len(params), batch_size):
batch = params[start_idx : start_idx + batch_size]
logger.debug(f"video_generate batch {start_idx // batch_size}: {batch}")

task_dict = {}
tracer = trace.get_tracer("gcp.vertex.agent")
with tracer.start_as_current_span("call_llm") as span:
Expand All @@ -216,15 +217,30 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
last_frame = item.get("last_frame", None)
try:
if not first_frame:
logger.debug(
f"video_generate task_{idx} text generation: prompt={prompt}"
)
response = await generate(prompt)
elif not last_frame:
logger.debug(
f"video_generate task_{idx} first frame generation: prompt={prompt}, first_frame={first_frame}"
)
response = await generate(prompt, first_frame)
else:
logger.debug(
f"video_generate task_{idx} first and last frame generation: prompt={prompt}, first_frame={first_frame}, last_frame={last_frame}"
)
response = await generate(prompt, first_frame, last_frame)
logger.debug(
f"batch_{start_idx // batch_size} video_generate task_{idx} response: {response}"
)
task_dict[response.id] = video_name
except Exception as e:
logger.error(f"Error: {e}")
error_list.append(video_name)
continue

logger.debug("begin query video_generate task status...")

while True:
task_list = list(task_dict.keys())
Expand All @@ -234,7 +250,9 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
result = client.content_generation.tasks.get(task_id=task_id)
status = result.status
if status == "succeeded":
logger.debug("----- task succeeded -----")
logger.debug(
f"{task_dict[task_id]} video_generate {status}. Video URL: {result.content.video_url}"
)
tool_context.state[f"{task_dict[task_id]}_video_url"] = (
result.content.video_url
)
Expand All @@ -248,13 +266,14 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
)
task_dict.pop(task_id, None)
elif status == "failed":
logger.error("----- task failed -----")
logger.error(f"Error: {result.error}")
logger.error(
f"{task_dict[task_id]} video_generate {status}. Error: {result.error}"
)
error_list.append(task_dict[task_id])
task_dict.pop(task_id, None)
else:
logger.debug(
f"Current status: {status}, Retrying after 10 seconds..."
f"{task_dict[task_id]} video_generate current status: {status}, Retrying after 10 seconds..."
)
time.sleep(10)

Expand All @@ -270,12 +289,18 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
)

if len(success_list) == 0:
logger.debug(
f"video_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "error",
"success_list": success_list,
"error_list": error_list,
}
else:
logger.debug(
f"video_generate success_list: {success_list}\nerror_list: {error_list}"
)
return {
"status": "success",
"success_list": success_list,
Expand Down