Skip to content

Commit 751b882

Browse files
committed
chore: add vision model request headers
1 parent f05bd89 commit 751b882

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

veadk/tools/builtin_tools/generate_image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from veadk.tools.builtin_tools.image_generate import image_generate # noqa: F401
15+
from veadk.tools.builtin_tools.image_generate import (
16+
image_generate, # noqa: F401
17+
)
1618
from veadk.utils.logger import get_logger
1719

1820
logger = get_logger(__name__)

veadk/tools/builtin_tools/image_generate.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from opentelemetry import trace
2727
from opentelemetry.trace import Span
2828
from volcenginesdkarkruntime import Ark
29-
from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
29+
from volcenginesdkarkruntime.types.images.images import (
30+
SequentialImageGenerationOptions,
31+
)
3032

3133
from veadk.config import getenv, settings
3234
from veadk.consts import (
@@ -41,7 +43,8 @@
4143

4244
client = Ark(
4345
api_key=getenv(
44-
"MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
46+
"MODEL_IMAGE_API_KEY",
47+
getenv("MODEL_AGENT_API_KEY", settings.model.api_key),
4548
),
4649
base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_IMAGE_GENERATE_MODEL_API_BASE),
4750
)
@@ -119,11 +122,24 @@ def handle_single_task_sync(
119122
and sequential_image_generation == "auto"
120123
and max_images
121124
):
122-
response = client.images.generate(
123-
model=getenv("MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME),
124-
**inputs,
125-
sequential_image_generation_options=SequentialImageGenerationOptions(
126-
max_images=max_images
125+
response = (
126+
client.images.generate(
127+
model=getenv(
128+
"MODEL_IMAGE_NAME",
129+
DEFAULT_IMAGE_GENERATE_MODEL_NAME,
130+
),
131+
**inputs,
132+
sequential_image_generation_options=SequentialImageGenerationOptions(
133+
max_images=max_images
134+
),
135+
extra_headers={
136+
"veadk-source": "veadk",
137+
"veadk-version": VERSION,
138+
"User-Agent": f"VeADK/{VERSION}",
139+
"X-Client-Request-Id": getenv(
140+
"MODEL_AGENT_CLIENT_REQ_ID", f"veadk/{VERSION}"
141+
),
142+
},
127143
),
128144
)
129145
else:
@@ -157,7 +173,8 @@ def handle_single_task_sync(
157173
continue
158174
image_bytes = base64.b64decode(b64)
159175
image_url = _upload_image_to_tos(
160-
image_bytes=image_bytes, object_key=f"{image_name}.png"
176+
image_bytes=image_bytes,
177+
object_key=f"{image_name}.png",
161178
)
162179
if not image_url:
163180
logger.error(f"Upload image to TOS failed: {image_name}")
@@ -367,7 +384,11 @@ def make_task(idx, item):
367384
logger.debug(
368385
f"image_generate success_list: {success_list}\nerror_list: {error_list}"
369386
)
370-
return {"status": "success", "success_list": success_list, "error_list": error_list}
387+
return {
388+
"status": "success",
389+
"success_list": success_list,
390+
"error_list": error_list,
391+
}
371392

372393

373394
def add_span_attributes(

veadk/tools/builtin_tools/video_generate.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535
client = Ark(
3636
api_key=getenv(
37-
"MODEL_VIDEO_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
37+
"MODEL_VIDEO_API_KEY",
38+
getenv("MODEL_AGENT_API_KEY", settings.model.api_key),
3839
),
3940
base_url=getenv("MODEL_VIDEO_API_BASE", DEFAULT_VIDEO_MODEL_API_BASE),
4041
)
@@ -48,6 +49,14 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
4849
content=[
4950
{"type": "text", "text": prompt},
5051
],
52+
extra_headers={
53+
"veadk-source": "veadk",
54+
"veadk-version": VERSION,
55+
"User-Agent": f"VeADK/{VERSION}",
56+
"X-Client-Request-Id": getenv(
57+
"MODEL_AGENT_CLIENT_REQ_ID", f"veadk/{VERSION}"
58+
),
59+
},
5160
)
5261
elif last_frame_image is None:
5362
response = client.content_generation.tasks.create(
@@ -62,6 +71,14 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
6271
},
6372
],
6473
),
74+
extra_headers={
75+
"veadk-source": "veadk",
76+
"veadk-version": VERSION,
77+
"User-Agent": f"VeADK/{VERSION}",
78+
"X-Client-Request-Id": getenv(
79+
"MODEL_AGENT_CLIENT_REQ_ID", f"veadk/{VERSION}"
80+
),
81+
},
6582
)
6683
else:
6784
response = client.content_generation.tasks.create(
@@ -79,6 +96,14 @@ async def generate(prompt, first_frame_image=None, last_frame_image=None):
7996
"role": "last_frame",
8097
},
8198
],
99+
extra_headers={
100+
"veadk-source": "veadk",
101+
"veadk-version": VERSION,
102+
"User-Agent": f"VeADK/{VERSION}",
103+
"X-Client-Request-Id": getenv(
104+
"MODEL_AGENT_CLIENT_REQ_ID", f"veadk/{VERSION}"
105+
),
106+
},
82107
)
83108
except:
84109
traceback.print_exc()

0 commit comments

Comments
 (0)