Skip to content

Commit b9e2992

Browse files
author
wangjiaju
committed
Add default model name
1 parent 9a270d1 commit b9e2992

File tree

5 files changed

+33
-23
lines changed

5 files changed

+33
-23
lines changed

veadk/consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,8 @@
5959
DEFAULT_TOS_BUCKET_NAME = "veadk-default-bucket"
6060

6161
DEFAULT_COZELOOP_SPACE_NAME = "VeADK Space"
62+
63+
DEFAULT_TEXT_TO_IMAGE_MODEL_NAME = "doubao-seedream-3-0-t2i-250415"
64+
DEFAULT_IMAGE_EDIT_MODEL_NAME = "doubao-seededit-3-0-i2i-250628"
65+
DEFAULT_VIDEO_MODEL_NAME = "doubao-seedance-1-0-pro-250528"
66+
DEFAULT_IMAGE_GENERATE_MODEL_NAME = "doubao-seedream-4-0-250828"

veadk/tools/builtin_tools/generate_image.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from google.genai import types
1818
from google.adk.tools import ToolContext
1919
from veadk.config import getenv
20+
from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
21+
2022
import base64
2123
from volcenginesdkarkruntime import Ark
2224
from opentelemetry import trace
@@ -31,8 +33,8 @@
3133
logger = get_logger(__name__)
3234

3335
client = Ark(
34-
api_key=getenv("MODEL_IMAGE_API_KEY"),
35-
base_url=getenv("MODEL_IMAGE_API_BASE"),
36+
api_key=getenv("MODEL_API_KEY"),
37+
base_url=DEFAULT_MODEL_AGENT_API_BASE,
3638
)
3739

3840

@@ -175,7 +177,7 @@ async def image_generate(
175177
):
176178
print(f"generate multi image, max_images: {max_images}")
177179
response = client.images.generate(
178-
model=getenv("MODEL_IMAGE_NAME"),
180+
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
179181
**inputs,
180182
sequential_image_generation_options=SequentialImageGenerationOptions(
181183
max_images=max_images
@@ -184,7 +186,7 @@ async def image_generate(
184186
else:
185187
print("generate single image")
186188
response = client.images.generate(
187-
model=getenv("MODEL_IMAGE_NAME"), **inputs
189+
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME, **inputs
188190
)
189191
if not response.error:
190192
for i, image_data in enumerate(response.data):
@@ -244,8 +246,8 @@ async def image_generate(
244246
output_part=output_part,
245247
output_tokens=output_tokens,
246248
total_tokens=total_tokens,
247-
request_model=getenv("MODEL_IMAGE_NAME"),
248-
response_model=getenv("MODEL_IMAGE_NAME"),
249+
request_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
250+
response_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
249251
)
250252
if len(success_list) == 0:
251253
return {

veadk/tools/builtin_tools/image_edit.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.genai import types
1818
from volcenginesdkarkruntime import Ark
1919
from veadk.config import getenv
20+
from veadk.consts import DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_IMAGE_EDIT_MODEL_NAME
2021
import base64
2122
from opentelemetry import trace
2223
import traceback
@@ -28,8 +29,8 @@
2829
logger = get_logger(__name__)
2930

3031
client = Ark(
31-
api_key=getenv("MODEL_EDIT_API_KEY"),
32-
base_url=getenv("MODEL_EDIT_API_BASE"),
32+
api_key=getenv("MODEL_API_KEY"),
33+
base_url=DEFAULT_MODEL_AGENT_API_BASE,
3334
)
3435

3536

@@ -119,7 +120,7 @@ async def image_edit(
119120
"content": json.dumps(inputs, ensure_ascii=False),
120121
}
121122
response = client.images.generate(
122-
model=getenv("MODEL_EDIT_NAME"), **inputs
123+
model=DEFAULT_IMAGE_EDIT_MODEL_NAME, **inputs
123124
)
124125
output_part = None
125126
if response.data and len(response.data) > 0:
@@ -160,8 +161,8 @@ async def image_edit(
160161
output_part=output_part,
161162
output_tokens=response.usage.output_tokens,
162163
total_tokens=response.usage.total_tokens,
163-
request_model=getenv("MODEL_EDIT_NAME"),
164-
response_model=getenv("MODEL_EDIT_NAME"),
164+
request_model=DEFAULT_IMAGE_EDIT_MODEL_NAME,
165+
response_model=DEFAULT_IMAGE_EDIT_MODEL_NAME,
165166
)
166167

167168
except Exception as e:

veadk/tools/builtin_tools/image_generate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.genai import types
1818
from google.adk.tools import ToolContext
1919
from veadk.config import getenv
20+
from veadk.consts import DEFAULT_TEXT_TO_IMAGE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
2021
import base64
2122
from volcenginesdkarkruntime import Ark
2223
from opentelemetry import trace
@@ -29,8 +30,8 @@
2930
logger = get_logger(__name__)
3031

3132
client = Ark(
32-
api_key=getenv("MODEL_IMAGE_API_KEY"),
33-
base_url=getenv("MODEL_IMAGE_API_BASE"),
33+
api_key=getenv("MODEL_API_KEY"),
34+
base_url=DEFAULT_MODEL_AGENT_API_BASE,
3435
)
3536

3637

@@ -120,7 +121,7 @@ async def image_generate(
120121
"content": json.dumps(inputs, ensure_ascii=False),
121122
}
122123
response = client.images.generate(
123-
model=getenv("MODEL_IMAGE_NAME"), **inputs
124+
model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME, **inputs
124125
)
125126
output_part = None
126127
if response.data and len(response.data) > 0:
@@ -161,8 +162,8 @@ async def image_generate(
161162
output_part=output_part,
162163
output_tokens=response.usage.output_tokens,
163164
total_tokens=response.usage.total_tokens,
164-
request_model=getenv("MODEL_IMAGE_NAME"),
165-
response_model=getenv("MODEL_IMAGE_NAME"),
165+
request_model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
166+
response_model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
166167
)
167168

168169
except Exception as e:

veadk/tools/builtin_tools/video_generate.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626
)
2727

2828
from veadk.config import getenv
29+
from veadk.consts import DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_VIDEO_MODEL_NAME
2930
from veadk.utils.logger import get_logger
3031
from veadk.version import VERSION
3132

3233
logger = get_logger(__name__)
3334

3435
client = Ark(
35-
api_key=getenv("MODEL_VIDEO_API_KEY"),
36-
base_url=getenv("MODEL_VIDEO_API_BASE"),
36+
api_key=getenv("MODEL_API_KEY"),
37+
base_url=DEFAULT_MODEL_AGENT_API_BASE,
3738
)
3839

3940

@@ -42,15 +43,15 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
4243
if first_frame_image is None:
4344
logger.debug("text generation")
4445
response = client.content_generation.tasks.create(
45-
model=getenv("MODEL_VIDEO_NAME"),
46+
model=DEFAULT_VIDEO_MODEL_NAME,
4647
content=[
4748
{"type": "text", "text": prompt},
4849
],
4950
)
5051
elif last_frame_image is None:
5152
logger.debug("first frame generation")
5253
response = client.content_generation.tasks.create(
53-
model=getenv("MODEL_VIDEO_NAME"),
54+
model=DEFAULT_VIDEO_MODEL_NAME,
5455
content=cast(
5556
list[CreateTaskContentParam], # avoid IDE warning
5657
[
@@ -65,7 +66,7 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
6566
else:
6667
logger.debug("last frame generation")
6768
response = client.content_generation.tasks.create(
68-
model=getenv("MODEL_VIDEO_NAME"),
69+
model=DEFAULT_VIDEO_MODEL_NAME,
6970
content=[
7071
{"type": "text", "text": prompt},
7172
{
@@ -262,8 +263,8 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
262263
output_part=output_part,
263264
output_tokens=total_tokens,
264265
total_tokens=total_tokens,
265-
request_model=getenv("MODEL_VIDEO_NAME"),
266-
response_model=getenv("MODEL_VIDEO_NAME"),
266+
request_model=DEFAULT_VIDEO_MODEL_NAME,
267+
response_model=DEFAULT_VIDEO_MODEL_NAME,
267268
)
268269

269270
if len(success_list) == 0:

0 commit comments

Comments
 (0)