Skip to content

Commit c4e75d8

Browse files
author
wangjiaju.716
committed
Add seedream-4-0
1 parent 71ac1bc commit c4e75d8

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

veadk/tools/builtin_tools/generate_image.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from opentelemetry.trace import Span
2626
from veadk.utils.logger import get_logger
2727
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
28+
import json
29+
2830

2931
logger = get_logger(__name__)
3032

@@ -148,7 +150,9 @@ async def image_generate(
148150
image = item.get("image", None)
149151
sequential_image_generation = item.get("sequential_image_generation", None)
150152
max_images = item.get("max_images", None)
151-
print(f"item: {item}")
153+
154+
input_part[f"parts.{idx}.type"] = "text"
155+
input_part[f"parts.{idx}.text"] = json.dumps(item, ensure_ascii=False)
152156
inputs = {
153157
"prompt": prompt,
154158
}
@@ -160,16 +164,16 @@ async def image_generate(
160164
if watermark:
161165
inputs["watermark"] = watermark
162166
if image:
163-
if task_type == "multi_image_to_single":
167+
if task_type.startswith("single"):
164168
assert isinstance(image, str), (
165-
f"multi_image_to_single task_type image must be str, got {type(image)}"
169+
f"single_* task_type image must be str, got {type(image)}"
166170
)
167-
elif task_type == "multi_image_to_multi":
171+
elif task_type.startswith("multi"):
168172
assert isinstance(image, list), (
169-
f"multi_image_to_multi task_type image must be list, got {type(image)}"
173+
f"multi_* task_type image must be list, got {type(image)}"
170174
)
171175
assert len(image) <= 10, (
172-
f"multi_image_to_multi task_type image list length must be <= 10, got {len(image)}"
176+
f"multi_* task_type image list length must be <= 10, got {len(image)}"
173177
)
174178

175179
if sequential_image_generation:
@@ -194,7 +198,6 @@ async def image_generate(
194198
response = client.images.generate(
195199
model=getenv("MODEL_IMAGE_NAME"), **inputs
196200
)
197-
print(f"response: {response}")
198201
if not response.error:
199202
for i, image_data in enumerate(response.data):
200203
image_name = f"task_{idx}_image_{i}"

0 commit comments

Comments
 (0)