|
2 | 2 | import sys |
3 | 3 | from datetime import datetime |
4 | 4 |
|
5 | | -import uvloop |
6 | | - |
7 | 5 | from volcenginesdkarkruntime import AsyncArk |
8 | 6 |
|
9 | 7 |
|
|
18 | 16 | # To get your ak&sk, please refer to this document(https://www.volcengine.com/docs/6291/65568) |
19 | 17 | # For more information,please check this document(https://www.volcengine.com/docs/82379/1263279) |
20 | 18 |
|
21 | | -async def worker(worker_id, task_num): |
22 | | - client = AsyncArk() |
| 19 | +async def worker( |
| 20 | + worker_id: int, |
| 21 | + client: AsyncArk, |
| 22 | + requests: asyncio.Queue[dict], |
| 23 | +): |
23 | 24 | print(f"Worker {worker_id} is starting.") |
24 | | - for i in range(task_num): |
25 | | - print(f"Worker {worker_id} task {i} is running.") |
| 25 | + |
| 26 | + while True: |
| 27 | + request = await requests.get() |
26 | 28 | try: |
27 | | - completion = await client.batch_chat.completions.create( |
28 | | - model="${YOUR_ENDPOINT_ID}", |
29 | | - messages=[ |
30 | | - {"role": "system", "content": "你是豆包,是由字节跳动开发的 AI 人工智能助手"}, |
31 | | - {"role": "user", "content": "常见的十字花科植物有哪些?"}, |
32 | | - ], |
33 | | - ) |
34 | | - print(completion.choices[0].message.content) |
| 29 | + completion = await client.batch_chat.completions.create(**request) |
| 30 | + print(completion) |
35 | 31 | except Exception as e: |
36 | | - print(f"Worker {worker_id} task {i} failed with error: {e}") |
37 | | - else: |
38 | | - print(f"Worker {worker_id} task {i} is completed.") |
39 | | - print(f"Worker {worker_id} is completed.") |
| 32 | + print(e, file=sys.stderr) |
| 33 | + finally: |
| 34 | + requests.task_done() |
40 | 35 |
|
41 | 36 |
|
42 | 37 | async def main(): |
43 | 38 | start = datetime.now() |
44 | | - max_concurrent_tasks = 1000 |
45 | | - task_num = 5 |
| 39 | + max_concurrent_tasks, task_num = 1000, 10000 |
| 40 | + |
| 41 | + requests = asyncio.Queue() |
| 42 | + client = AsyncArk(timeout=24 * 3600) |
| 43 | + |
| 44 | + # mock `task_num` tasks |
| 45 | + for _ in range(task_num): |
| 46 | + await requests.put( |
| 47 | + { |
| 48 | + "model": "${YOUR_ENDPOINT_ID}", |
| 49 | + "messages": [ |
| 50 | + { |
| 51 | + "role": "system", |
| 52 | + "content": "你是豆包,是由字节跳动开发的 AI 人工智能助手", |
| 53 | + }, |
| 54 | + {"role": "user", "content": "常见的十字花科植物有哪些?"}, |
| 55 | + ], |
| 56 | + } |
| 57 | + ) |
| 58 | + |
| 59 | + # create `max_concurrent_tasks` workers and start them |
| 60 | + tasks = [ |
| 61 | + asyncio.create_task(worker(i, client, requests)) |
| 62 | + for i in range(max_concurrent_tasks) |
| 63 | + ] |
| 64 | + |
| 65 | + # wait for all requests is done |
| 66 | + await requests.join() |
| 67 | + |
| 68 | + # stop workers |
| 69 | + for task in tasks: |
| 70 | + task.cancel() |
46 | 71 |
|
47 | | - # 创建任务列表 |
48 | | - tasks = [worker(i, task_num) for i in range(max_concurrent_tasks)] |
| 72 | + # wait for all workers is canceled |
| 73 | + await asyncio.gather(*tasks, return_exceptions=True) |
| 74 | + await client.close() |
49 | 75 |
|
50 | | - # 等待所有任务完成 |
51 | | - await asyncio.gather(*tasks) |
52 | 76 | end = datetime.now() |
53 | | - print(f"Total time: {end - start}, Total task: {max_concurrent_tasks * task_num}") |
| 77 | + print(f"Total time: {end - start}, Total task: {task_num}") |
54 | 78 |
|
55 | 79 |
|
56 | 80 | if __name__ == "__main__": |
57 | | - if sys.version_info >= (3, 11): |
58 | | - with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: |
59 | | - runner.run(main()) |
60 | | - else: |
61 | | - uvloop.install() |
62 | | - asyncio.run(main()) |
| 81 | + asyncio.run(main()) |
0 commit comments