Skip to content

Commit bd1f6c8

Browse files
committed
feat(*): support ark batch chat
1 parent 2245750 commit bd1f6c8

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed
Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
2+
import sys
23
from datetime import datetime
34

5+
import uvloop
6+
47
from volcenginesdkarkruntime import AsyncArk
58

69
# Authentication
@@ -16,35 +19,33 @@
1619
client = AsyncArk()
1720

1821

19-
async def worker(semaphore, worker_id, task_num):
20-
async with semaphore:
21-
print(f"Worker {worker_id} is starting.")
22-
for i in range(task_num):
23-
print(f"Worker {worker_id} task {i} is running.")
24-
try:
25-
completion = await client.batch_chat.completions.create(
26-
model="${YOUR_ENDPOINT_ID}",
27-
messages=[
28-
{"role": "system", "content": "你是豆包,是由字节跳动开发的 AI 人工智能助手"},
29-
{"role": "user", "content": "常见的十字花科植物有哪些?"},
30-
],
31-
)
32-
print(completion.choices[0].message.content)
33-
except Exception as e:
34-
print(f"Worker {worker_id} task {i} failed with error: {e}")
35-
else:
36-
print(f"Worker {worker_id} task {i} is completed.")
37-
print(f"Worker {worker_id} is completed.")
22+
async def worker(worker_id, task_num):
23+
print(f"Worker {worker_id} is starting.")
24+
for i in range(task_num):
25+
print(f"Worker {worker_id} task {i} is running.")
26+
try:
27+
completion = await client.batch_chat.completions.create(
28+
model="${YOUR_ENDPOINT_ID}",
29+
messages=[
30+
{"role": "system", "content": "你是豆包,是由字节跳动开发的 AI 人工智能助手"},
31+
{"role": "user", "content": "常见的十字花科植物有哪些?"},
32+
],
33+
)
34+
print(completion.choices[0].message.content)
35+
except Exception as e:
36+
print(f"Worker {worker_id} task {i} failed with error: {e}")
37+
else:
38+
print(f"Worker {worker_id} task {i} is completed.")
39+
print(f"Worker {worker_id} is completed.")
3840

3941

4042
async def main():
4143
start = datetime.now()
4244
max_concurrent_tasks = 1000
4345
task_num = 5
44-
semaphore = asyncio.Semaphore(max_concurrent_tasks)
4546

4647
# 创建任务列表
47-
tasks = [worker(semaphore, i, task_num) for i in range(max_concurrent_tasks)]
48+
tasks = [worker(i, task_num) for i in range(max_concurrent_tasks)]
4849

4950
# 等待所有任务完成
5051
await asyncio.gather(*tasks)
@@ -53,4 +54,9 @@ async def main():
5354

5455

5556
if __name__ == "__main__":
56-
asyncio.run(main())
57+
if sys.version_info >= (3, 11):
58+
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
59+
runner.run(main())
60+
else:
61+
uvloop.install()
62+
asyncio.run(main())

0 commit comments

Comments
 (0)