Skip to content

Commit 6540020

Browse files
doraemonlovewangjiaju
andauthored
fix(tools): fix default builtin tool model params (#221)
Co-authored-by: wangjiaju <[email protected]>
1 parent 66df2b0 commit 6540020

File tree

4 files changed

+38
-19
lines changed

4 files changed

+38
-19
lines changed

veadk/tools/builtin_tools/generate_image.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
client = Ark(
3535
api_key=getenv("MODEL_AGENT_API_KEY"),
36-
base_url=DEFAULT_MODEL_AGENT_API_BASE,
36+
base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
3737
)
3838

3939

@@ -184,15 +184,20 @@ async def image_generate(
184184
and max_images
185185
):
186186
response = client.images.generate(
187-
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
187+
model=getenv(
188+
"MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME
189+
),
188190
**inputs,
189191
sequential_image_generation_options=SequentialImageGenerationOptions(
190192
max_images=max_images
191193
),
192194
)
193195
else:
194196
response = client.images.generate(
195-
model=DEFAULT_IMAGE_GENERATE_MODEL_NAME, **inputs
197+
model=getenv(
198+
"MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME
199+
),
200+
**inputs,
196201
)
197202
if not response.error:
198203
for i, image_data in enumerate(response.data):
@@ -261,8 +266,12 @@ async def image_generate(
261266
output_part=output_part,
262267
output_tokens=output_tokens,
263268
total_tokens=total_tokens,
264-
request_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
265-
response_model=DEFAULT_IMAGE_GENERATE_MODEL_NAME,
269+
request_model=getenv(
270+
"MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME
271+
),
272+
response_model=getenv(
273+
"MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME
274+
),
266275
)
267276
if len(success_list) == 0:
268277
return {

veadk/tools/builtin_tools/image_edit.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
client = Ark(
3131
api_key=getenv("MODEL_AGENT_API_KEY"),
32-
base_url=DEFAULT_MODEL_AGENT_API_BASE,
32+
base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
3333
)
3434

3535

@@ -123,7 +123,8 @@ async def image_edit(
123123
"parts.1.image_url.url": origin_image,
124124
}
125125
response = client.images.generate(
126-
model=DEFAULT_IMAGE_EDIT_MODEL_NAME, **inputs
126+
model=getenv("MODEL_EDIT_NAME", DEFAULT_IMAGE_EDIT_MODEL_NAME),
127+
**inputs,
127128
)
128129
output_part = None
129130
if response.data and len(response.data) > 0:
@@ -175,8 +176,12 @@ async def image_edit(
175176
output_part=output_part,
176177
output_tokens=response.usage.output_tokens,
177178
total_tokens=response.usage.total_tokens,
178-
request_model=DEFAULT_IMAGE_EDIT_MODEL_NAME,
179-
response_model=DEFAULT_IMAGE_EDIT_MODEL_NAME,
179+
request_model=getenv(
180+
"MODEL_EDIT_NAME", DEFAULT_IMAGE_EDIT_MODEL_NAME
181+
),
182+
response_model=getenv(
183+
"MODEL_EDIT_NAME", DEFAULT_IMAGE_EDIT_MODEL_NAME
184+
),
180185
)
181186

182187
except Exception as e:

veadk/tools/builtin_tools/image_generate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
client = Ark(
3232
api_key=getenv("MODEL_AGENT_API_KEY"),
33-
base_url=DEFAULT_MODEL_AGENT_API_BASE,
33+
base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
3434
)
3535

3636

@@ -120,7 +120,8 @@ async def image_generate(
120120
"content": json.dumps(inputs, ensure_ascii=False),
121121
}
122122
response = client.images.generate(
123-
model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME, **inputs
123+
model=getenv("MODEL_IMAGE_NAME", DEFAULT_TEXT_TO_IMAGE_MODEL_NAME),
124+
**inputs,
124125
)
125126
output_part = None
126127
if response.data and len(response.data) > 0:
@@ -172,8 +173,12 @@ async def image_generate(
172173
output_part=output_part,
173174
output_tokens=response.usage.output_tokens,
174175
total_tokens=response.usage.total_tokens,
175-
request_model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
176-
response_model=DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
176+
request_model=getenv(
177+
"MODEL_IMAGE_NAME", DEFAULT_TEXT_TO_IMAGE_MODEL_NAME
178+
),
179+
response_model=getenv(
180+
"MODEL_IMAGE_NAME", DEFAULT_TEXT_TO_IMAGE_MODEL_NAME
181+
),
177182
)
178183

179184
except Exception as e:

veadk/tools/builtin_tools/video_generate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
client = Ark(
3636
api_key=getenv("MODEL_AGENT_API_KEY"),
37-
base_url=DEFAULT_MODEL_AGENT_API_BASE,
37+
base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
3838
)
3939

4040

@@ -43,15 +43,15 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
4343
if first_frame_image is None:
4444
logger.debug("text generation")
4545
response = client.content_generation.tasks.create(
46-
model=DEFAULT_VIDEO_MODEL_NAME,
46+
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
4747
content=[
4848
{"type": "text", "text": prompt},
4949
],
5050
)
5151
elif last_frame_image is None:
5252
logger.debug("first frame generation")
5353
response = client.content_generation.tasks.create(
54-
model=DEFAULT_VIDEO_MODEL_NAME,
54+
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
5555
content=cast(
5656
list[CreateTaskContentParam], # avoid IDE warning
5757
[
@@ -66,7 +66,7 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
6666
else:
6767
logger.debug("last frame generation")
6868
response = client.content_generation.tasks.create(
69-
model=DEFAULT_VIDEO_MODEL_NAME,
69+
model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
7070
content=[
7171
{"type": "text", "text": prompt},
7272
{
@@ -263,8 +263,8 @@ async def video_generate(params: list, tool_context: ToolContext) -> Dict:
263263
output_part=output_part,
264264
output_tokens=total_tokens,
265265
total_tokens=total_tokens,
266-
request_model=DEFAULT_VIDEO_MODEL_NAME,
267-
response_model=DEFAULT_VIDEO_MODEL_NAME,
266+
request_model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
267+
response_model=getenv("MODEL_VIDEO_NAME", DEFAULT_VIDEO_MODEL_NAME),
268268
)
269269

270270
if len(success_list) == 0:

0 commit comments

Comments
 (0)