Skip to content

Commit cf302c0

Browse files
committed
fix: improve ark batch chat example code
1 parent eab96c7 commit cf302c0

File tree

1 file changed

+50
-31
lines changed

1 file changed

+50
-31
lines changed

volcenginesdkexamples/volcenginesdkarkruntime/async_batch_chat_completions.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import sys
33
from datetime import datetime
44

5-
import uvloop
6-
75
from volcenginesdkarkruntime import AsyncArk
86

97

@@ -18,45 +16,66 @@
1816
# To get your AK and SK, please refer to this document (https://www.volcengine.com/docs/6291/65568)
1917
# For more information, please check this document (https://www.volcengine.com/docs/82379/1263279)
2018

21-
async def worker(worker_id, task_num):
22-
client = AsyncArk()
19+
async def worker(
20+
worker_id: int,
21+
client: AsyncArk,
22+
requests: asyncio.Queue[dict],
23+
):
2324
print(f"Worker {worker_id} is starting.")
24-
for i in range(task_num):
25-
print(f"Worker {worker_id} task {i} is running.")
25+
26+
while True:
27+
request = await requests.get()
2628
try:
27-
completion = await client.batch_chat.completions.create(
28-
model="${YOUR_ENDPOINT_ID}",
29-
messages=[
30-
{"role": "system", "content": "你是豆包,是由字节跳动开发的 AI 人工智能助手"},
31-
{"role": "user", "content": "常见的十字花科植物有哪些?"},
32-
],
33-
)
34-
print(completion.choices[0].message.content)
29+
completion = await client.batch_chat.completions.create(**request)
30+
print(completion)
3531
except Exception as e:
36-
print(f"Worker {worker_id} task {i} failed with error: {e}")
37-
else:
38-
print(f"Worker {worker_id} task {i} is completed.")
39-
print(f"Worker {worker_id} is completed.")
32+
print(e, file=sys.stderr)
33+
finally:
34+
requests.task_done()
4035

4136

4237
async def main():
4338
start = datetime.now()
44-
max_concurrent_tasks = 1000
45-
task_num = 5
39+
max_concurrent_tasks, task_num = 1000, 10000
40+
41+
requests = asyncio.Queue()
42+
client = AsyncArk(timeout=24 * 3600)
43+
44+
# mock `task_num` tasks
45+
for _ in range(task_num):
46+
await requests.put(
47+
{
48+
"model": "${YOUR_ENDPOINT_ID}",
49+
"messages": [
50+
{
51+
"role": "system",
52+
"content": "你是豆包,是由字节跳动开发的 AI 人工智能助手",
53+
},
54+
{"role": "user", "content": "常见的十字花科植物有哪些?"},
55+
],
56+
}
57+
)
58+
59+
# create `max_concurrent_tasks` workers and start them
60+
tasks = [
61+
asyncio.create_task(worker(i, client, requests))
62+
for i in range(max_concurrent_tasks)
63+
]
64+
65+
# wait until all requests are done
66+
await requests.join()
67+
68+
# stop workers
69+
for task in tasks:
70+
task.cancel()
4671

47-
# 创建任务列表
48-
tasks = [worker(i, task_num) for i in range(max_concurrent_tasks)]
72+
# wait until all workers are canceled
73+
await asyncio.gather(*tasks, return_exceptions=True)
74+
await client.close()
4975

50-
# 等待所有任务完成
51-
await asyncio.gather(*tasks)
5276
end = datetime.now()
53-
print(f"Total time: {end - start}, Total task: {max_concurrent_tasks * task_num}")
77+
print(f"Total time: {end - start}, Total task: {task_num}")
5478

5579

5680
if __name__ == "__main__":
57-
if sys.version_info >= (3, 11):
58-
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
59-
runner.run(main())
60-
else:
61-
uvloop.install()
62-
asyncio.run(main())
81+
asyncio.run(main())

0 commit comments

Comments
 (0)