Skip to content

Commit a143863

Browse files
author
xiangjie.meng
committed
feat: batch-worker content generation
1 parent a867758 commit a143863

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

volcenginesdkarkruntime/resources/content_generation/tasks.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def create(
4646
content: Iterable[CreateTaskContentParam],
4747
callback_url: Optional[str] = None,
4848
return_last_frame: Optional[bool] = None,
49+
service_tier: Optional[str] = None,
50+
execution_expires_after: Optional[int] = None,
4951
extra_headers: Headers | None = None,
5052
extra_query: Query | None = None,
5153
extra_body: Body | None = None,
@@ -58,6 +60,8 @@ def create(
5860
"content": content,
5961
"callback_url": callback_url,
6062
"return_last_frame": return_last_frame,
63+
"service_tier": service_tier,
64+
"execution_expires_after": execution_expires_after,
6165
},
6266
options=make_request_options(
6367
extra_headers=extra_headers,
@@ -99,6 +103,7 @@ def list(
99103
status: str | None = None,
100104
task_ids: Union[List[str], str] | None = None,
101105
model: str | None = None,
106+
service_tier: str | None = None,
102107
extra_headers: Headers | None = None,
103108
extra_body: Body | None = None,
104109
extra_query: Query | None = None,
@@ -113,6 +118,8 @@ def list(
113118
query_params.append(("filter.status", status))
114119
if model:
115120
query_params.append(("filter.model", model))
121+
if service_tier:
122+
query_params.append(("filter.service_tier", service_tier))
116123
if task_ids:
117124
if isinstance(task_ids, str):
118125
task_ids = [task_ids]
@@ -167,6 +174,8 @@ async def create(
167174
content: Iterable[CreateTaskContentParam],
168175
callback_url: Optional[str] = None,
169176
return_last_frame: Optional[bool] = None,
177+
service_tier: Optional[str] = None,
178+
execution_expires_after: Optional[int] = None,
170179
extra_headers: Headers | None = None,
171180
extra_query: Query | None = None,
172181
extra_body: Body | None = None,
@@ -179,6 +188,8 @@ async def create(
179188
"content": content,
180189
"callback_url": callback_url,
181190
"return_last_frame": return_last_frame,
191+
"service_tier": service_tier,
192+
"execution_expires_after": execution_expires_after,
182193
},
183194
options=make_request_options(
184195
extra_headers=extra_headers,
@@ -220,6 +231,7 @@ async def list(
220231
status: str | None = None,
221232
task_ids: Union[List[str], str] | None = None,
222233
model: str | None = None,
234+
service_tier: str | None = None,
223235
extra_headers: Headers | None = None,
224236
extra_body: Body | None = None,
225237
extra_query: Query | None = None,
@@ -234,6 +246,8 @@ async def list(
234246
query_params.append(("filter.status", status))
235247
if model:
236248
query_params.append(("filter.model", model))
249+
if service_tier:
250+
query_params.append(("filter.service_tier", service_tier))
237251
if task_ids:
238252
if isinstance(task_ids, str):
239253
task_ids = [task_ids]

volcenginesdkarkruntime/types/content_generation/content_generation_task.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,9 @@ class ContentGenerationTask(BaseModel):
8080

8181
revised_prompt: str
8282
"""The revised prompt the model uses to generate content"""
83+
84+
service_tier: str
85+
"""The service tier used to run the task (optional)."""
86+
87+
execution_expires_after: int
88+
"""The expiration time in seconds after which execution should end (optional)."""

volcenginesdkexamples/volcenginesdkarkruntime/content_generation_tasks.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,17 @@
2424
# "role": "first_frame"
2525
}
2626
],
27+
service_tier="default",
28+
execution_expires_after=3600,
2729
# callback_url="${YOUR_CALLBACK_URL}"
2830
)
2931
print(create_result)
3032

3133
print("----- get request -----")
3234
get_result = client.content_generation.tasks.get(task_id=create_result.id)
3335
print(get_result)
36+
print("ServiceTier:", getattr(get_result, "service_tier", None))
37+
print("ExecutionExpiresAfter:", getattr(get_result, "execution_expires_after", None))
3438

3539
print("----- list request -----")
3640
list_result = client.content_generation.tasks.list(
@@ -41,10 +45,52 @@
4145
# task_ids=["test-id-1", "test-id-2"] # Filter by task_ids
4246
)
4347
print(list_result)
48+
if list_result.items:
49+
print("List Item ServiceTier:", getattr(list_result.items[0], "service_tier", None))
50+
print("List Item ExecutionExpiresAfter:", getattr(list_result.items[0], "execution_expires_after", None))
4451

4552
print("----- delete request -----")
4653
try:
4754
client.content_generation.tasks.delete(task_id=create_result.id)
4855
print(create_result.id)
4956
except Exception as e:
5057
print(f"failed to delete task: {e}")
58+
59+
# ---- flex tier flow: create + GET + LIST + DELETE ----
60+
print("----- create request (flex) -----")
61+
create_result_flex = client.content_generation.tasks.create(
62+
model="${YOUR_MODEL_EP}",
63+
content=[
64+
{
65+
"type": "text",
66+
"text": "使用 flex 级别进行内容生成测试,验证 service_tier 与 expire 字段"
67+
}
68+
],
69+
service_tier="flex",
70+
execution_expires_after=3600,
71+
)
72+
print(create_result_flex)
73+
74+
print("----- get request (flex) -----")
75+
get_result_flex = client.content_generation.tasks.get(task_id=create_result_flex.id)
76+
print(get_result_flex)
77+
print("Flex ServiceTier:", getattr(get_result_flex, "service_tier", None))
78+
print("Flex ExecutionExpiresAfter:", getattr(get_result_flex, "execution_expires_after", None))
79+
80+
print("----- list request (flex) -----")
81+
list_result_flex = client.content_generation.tasks.list(
82+
page_num=1,
83+
page_size=10,
84+
service_tier="flex",
85+
)
86+
print(list_result_flex)
87+
if list_result_flex.items:
88+
print("Flex List Item ServiceTier:", getattr(list_result_flex.items[0], "service_tier", None))
89+
print("Flex List Item ExecutionExpiresAfter:", getattr(list_result_flex.items[0], "execution_expires_after", None))
90+
91+
print("----- delete request (flex) -----")
92+
try:
93+
client.content_generation.tasks.delete(task_id=create_result_flex.id)
94+
print(create_result_flex.id)
95+
except Exception as e:
96+
print(f"failed to delete flex task: {e}")

0 commit comments

Comments
 (0)