
Commit 4d2a005

exiaohuliyuxuan-bd authored and committed

fix: add example for batch embeddings

1 parent 4ab4c26

9 files changed: 331 additions, 10 deletions

volcenginesdkarkruntime/resources/batch/_utils.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ def _calculate_retry_timeout(retry_times: int) -> float:
     return timeout if timeout >= 0 else 0
 
 
-def _get_retry_after(response: httpx.Response) -> int | None:
+def _get_retry_after(response: httpx.Response) -> Optional[int]:
     retry_after = response.headers.get("Retry-After")
     if retry_after is not None:
         if retry_after.isdigit():
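
The change swaps the PEP 604 union int | None for typing.Optional[int], which keeps the return annotation evaluable on Python versions older than 3.10 (the one-line diff does not show an import, so Optional is presumably already available in _utils.py). A minimal sketch of the idea, using a hypothetical helper rather than the SDK's own implementation:

from typing import Optional


def parse_retry_after(header_value: Optional[str]) -> Optional[int]:
    # Hypothetical helper: Optional[int] evaluates on every supported Python,
    # whereas a runtime-evaluated "int | None" needs Python 3.10+.
    if header_value is not None and header_value.isdigit():
        return int(header_value)
    return None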

volcenginesdkarkruntime/resources/batch/embeddings.py

Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None = None,
     ) -> CreateEmbeddingResponse:
-        deadline = get_request_last_time(timeout)
+        deadline = get_request_last_time(self._client, timeout)
         breaker = self._client.get_model_breaker(model)
 
         return with_batch_retry(
@@ -81,7 +81,7 @@ async def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None = None,
     ) -> CreateEmbeddingResponse:
-        deadline = get_request_last_time(timeout)
+        deadline = get_request_last_time(self._client, timeout)
         breaker = await self._client.get_model_breaker(model)
 
         return await with_batch_retry(
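
Both the sync and async create methods now pass self._client into get_request_last_time alongside the per-request timeout; the helper's internals are not shown in this diff, so presumably the deadline calculation can now fall back on the client's own configuration. Nothing changes at the call site. A usage-level sketch (the endpoint ID is a placeholder and the timeout override is illustrative):

from volcenginesdkarkruntime import Ark

client = Ark(timeout=24 * 3600)  # generous client-level timeout, as in the batch examples
response = client.batch.embeddings.create(
    model="${YOUR_ENDPOINT_ID}",
    input=["Cauliflower is a common vegetable."],
    timeout=600,  # optional per-request override, forwarded to the deadline calculation
)
print(response)
client.close()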

volcenginesdkarkruntime/resources/batch/multimodal_embeddings.py

Lines changed: 2 additions & 2 deletions

@@ -43,7 +43,7 @@ def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None = None,
     ) -> MultimodalEmbeddingResponse:
-        deadline = get_request_last_time(timeout)
+        deadline = get_request_last_time(self._client, timeout)
         breaker = self._client.get_model_breaker(model)
 
         return with_batch_retry(
@@ -87,7 +87,7 @@ async def create(
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None = None,
     ) -> MultimodalEmbeddingResponse:
-        deadline = get_request_last_time(timeout)
+        deadline = get_request_last_time(self._client, timeout)
         breaker = await self._client.get_model_breaker(model)
 
         return await async_with_batch_retry(
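
The multimodal resource gets the same two-argument deadline call. For reference, a minimal one-shot async call mirroring the new example added further down (the endpoint ID is a placeholder; the image URL is the one used in that example):

import asyncio

from volcenginesdkarkruntime import AsyncArk


async def main():
    client = AsyncArk()  # reads ARK_API_KEY or VOLC_ACCESSKEY/VOLC_SECRETKEY from the environment
    response = await client.batch.multimodal_embeddings.create(
        model="${YOUR_ENDPOINT_ID}",
        input=[
            {"type": "text", "text": "What is the weather like today?"},
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://ark-project.tos-cn-beijing.volces.com/images/view.jpeg"
                },
            },
        ],
    )
    print(response)
    await client.close()


if __name__ == "__main__":
    asyncio.run(main())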

volcenginesdkexamples/volcenginesdkarkruntime/async_batch_chat_completions.py

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
 async def worker(
     worker_id: int,
     client: AsyncArk,
-    requests: asyncio.Queue[dict],
+    requests: "asyncio.Queue[dict]",
 ):
     print(f"Worker {worker_id} is starting.")
 
@@ -36,7 +36,7 @@ async def worker(
 
 async def main():
     start = datetime.now()
-    max_concurrent_tasks, task_num = 1000, 10000
+    max_concurrent_tasks, task_num = 10, 100
 
     requests = asyncio.Queue()
     client = AsyncArk(timeout=24 * 3600)
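
Two small fixes in this example: the queue annotation is quoted so that asyncio.Queue[dict] is never evaluated on interpreters where asyncio.Queue is not subscriptable at runtime (it gained that support in Python 3.9), and the defaults drop from 1000 workers / 10000 tasks to a gentler 10 / 100. An equivalent way to keep the annotation unquoted is postponed evaluation, sketched below (a standalone illustration, not SDK code):

from __future__ import annotations

import asyncio


async def worker(worker_id: int, requests: asyncio.Queue[dict]) -> None:
    # With postponed evaluation the subscripted annotation is only seen by
    # type checkers and never evaluated at runtime.
    request = await requests.get()
    print(f"Worker {worker_id} got {request}")
    requests.task_done()


async def main() -> None:
    requests = asyncio.Queue()
    await requests.put({"model": "${YOUR_ENDPOINT_ID}"})
    await worker(0, requests)
    await requests.join()


if __name__ == "__main__":
    asyncio.run(main())
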
Lines changed: 72 additions & 0 deletions (new file)

import asyncio
import sys
from datetime import datetime

from volcenginesdkarkruntime import AsyncArk

# Authentication
# 1. If you authorize your endpoint using an API key, you can set your api key to the environment variable "ARK_API_KEY"
# or specify the api key by Ark(api_key="${YOUR_API_KEY}").
# Note: If you use an API key, this API key will not be refreshed.
# To prevent the API key from expiring and failing after some time, choose an API key with no expiration date.

# 2. If you authorize your endpoint with Volcengine Identity and Access Management (IAM), set your access key to the environment variables "VOLC_ACCESSKEY" and "VOLC_SECRETKEY",
# or specify ak&sk by Ark(ak="${YOUR_AK}", sk="${YOUR_SK}").
# To get your ak&sk, please refer to this document: https://www.volcengine.com/docs/6291/65568
# For more information, please check this document: https://www.volcengine.com/docs/82379/1263279


async def worker(
    worker_id: int,
    client: AsyncArk,
    requests: "asyncio.Queue[dict]",
):
    print(f"Worker {worker_id} is starting.")

    while True:
        request = await requests.get()
        try:
            completion = await client.batch.embeddings.create(**request)
            print(completion)
        except Exception as e:
            print(e, file=sys.stderr)
        finally:
            requests.task_done()


async def main():
    start = datetime.now()
    max_concurrent_tasks, task_num = 10, 100

    requests = asyncio.Queue()
    client = AsyncArk(timeout=24 * 3600)

    # mock `task_num` tasks
    for _ in range(task_num):
        await requests.put(
            {"model": "${YOUR_ENDPOINT_ID}", "input": ["花椰菜又称菜花、花菜,是一种常见的蔬菜。"]}
        )

    # create `max_concurrent_tasks` workers and start them
    tasks = [
        asyncio.create_task(worker(i, client, requests))
        for i in range(max_concurrent_tasks)
    ]

    # wait until all requests are done
    await requests.join()

    # stop the workers
    for task in tasks:
        task.cancel()

    # wait until all workers are cancelled
    await asyncio.gather(*tasks, return_exceptions=True)
    await client.close()

    end = datetime.now()
    print(f"Total time: {end - start}, Total task: {task_num}")


if __name__ == "__main__":
    asyncio.run(main())
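
A note on the shutdown sequence in this example: the workers never leave their while True loop on their own. requests.join() returns once every queued request has been marked done, the main coroutine then cancels each worker task, and asyncio.gather(..., return_exceptions=True) absorbs the resulting CancelledError so the client can be closed and the script exits cleanly.
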
Lines changed: 83 additions & 0 deletions (new file)

import asyncio
import sys
from datetime import datetime

from volcenginesdkarkruntime import AsyncArk

# Authentication
# 1. If you authorize your endpoint using an API key, you can set your api key to the environment variable "ARK_API_KEY"
# or specify the api key by Ark(api_key="${YOUR_API_KEY}").
# Note: If you use an API key, this API key will not be refreshed.
# To prevent the API key from expiring and failing after some time, choose an API key with no expiration date.

# 2. If you authorize your endpoint with Volcengine Identity and Access Management (IAM), set your access key to the environment variables "VOLC_ACCESSKEY" and "VOLC_SECRETKEY",
# or specify ak&sk by Ark(ak="${YOUR_AK}", sk="${YOUR_SK}").
# To get your ak&sk, please refer to this document: https://www.volcengine.com/docs/6291/65568
# For more information, please check this document: https://www.volcengine.com/docs/82379/1263279


async def worker(
    worker_id: int,
    client: AsyncArk,
    requests: "asyncio.Queue[dict]",
):
    print(f"Worker {worker_id} is starting.")

    while True:
        request = await requests.get()
        try:
            completion = await client.batch.multimodal_embeddings.create(**request)
            print(completion)
        except Exception as e:
            print(e, file=sys.stderr)
        finally:
            requests.task_done()


async def main():
    start = datetime.now()
    max_concurrent_tasks, task_num = 10, 100

    requests = asyncio.Queue()
    client = AsyncArk(timeout=24 * 3600)

    # mock `task_num` tasks
    for _ in range(task_num):
        await requests.put(
            {
                "model": "${YOUR_ENDPOINT_ID}",
                "input": [
                    {"type": "text", "text": "What is the weather like today?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://ark-project.tos-cn-beijing.volces.com/images/view.jpeg"
                        },
                    },
                ],
            }
        )

    # create `max_concurrent_tasks` workers and start them
    tasks = [
        asyncio.create_task(worker(i, client, requests))
        for i in range(max_concurrent_tasks)
    ]

    # wait until all requests are done
    await requests.join()

    # stop the workers
    for task in tasks:
        task.cancel()

    # wait until all workers are cancelled
    await asyncio.gather(*tasks, return_exceptions=True)
    await client.close()

    end = datetime.now()
    print(f"Total time: {end - start}, Total task: {task_num}")


if __name__ == "__main__":
    asyncio.run(main())

volcenginesdkexamples/volcenginesdkarkruntime/batch_chat_completions.py

Lines changed: 2 additions & 3 deletions

@@ -20,7 +20,7 @@
 def worker(
     worker_id: int,
     client: Ark,
-    requests: queue.Queue[dict],
+    requests: "queue.Queue[dict]",
 ):
     print(f"Worker {worker_id} is starting.")
 
@@ -45,7 +45,7 @@ def worker(
 
 def main():
     start = datetime.now()
-    max_concurrent_tasks, task_num = 1000, 10000
+    max_concurrent_tasks, task_num = 10, 100
 
     requests = queue.Queue()
     client = Ark(timeout=24 * 3600)
@@ -72,7 +72,6 @@ def main():
     with ThreadPool(max_concurrent_tasks) as pool:
         for i in range(max_concurrent_tasks):
             pool.apply_async(worker, args=(i, client, requests))
-            pool.apply_async(worker, args=(i, client, requests))
 
         # wait for all request to done
         pool.close()
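
The deleted line was a plain duplication: each loop iteration submitted the worker twice, queueing twice as many worker jobs as the pool has threads. Removing it, together with the reduced defaults, keeps the example at one worker per pool thread, as the max_concurrent_tasks name promises.
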
Lines changed: 78 additions & 0 deletions (new file)

import queue
import sys
from datetime import datetime
from multiprocessing.pool import ThreadPool

from volcenginesdkarkruntime import Ark

# Authentication
# 1. If you authorize your endpoint using an API key, you can set your api key to the environment variable "ARK_API_KEY"
# or specify the api key by Ark(api_key="${YOUR_API_KEY}").
# Note: If you use an API key, this API key will not be refreshed.
# To prevent the API key from expiring and failing after some time, choose an API key with no expiration date.

# 2. If you authorize your endpoint with Volcengine Identity and Access Management (IAM), set your access key to the environment variables "VOLC_ACCESSKEY" and "VOLC_SECRETKEY",
# or specify ak&sk by Ark(ak="${YOUR_AK}", sk="${YOUR_SK}").
# To get your ak&sk, please refer to this document: https://www.volcengine.com/docs/6291/65568
# For more information, please check this document: https://www.volcengine.com/docs/82379/1263279


def worker(
    worker_id: int,
    client: Ark,
    requests: "queue.Queue[dict]",
):
    print(f"Worker {worker_id} is starting.")

    while True:
        request = requests.get()

        # check for the signal that there are no more requests
        if not request:
            # put the signal back on the queue for the other workers
            requests.put(request)
            return

        try:
            # do the request
            completion = client.batch.embeddings.create(**request)
            print(completion)
        except Exception as e:
            print(e, file=sys.stderr)
        finally:
            requests.task_done()


def main():
    start = datetime.now()
    max_concurrent_tasks, task_num = 10, 100

    requests = queue.Queue()
    client = Ark(timeout=24 * 3600)

    # mock `task_num` tasks
    for _ in range(task_num):
        requests.put(
            {"model": "${YOUR_ENDPOINT_ID}", "input": ["花椰菜又称菜花、花菜,是一种常见的蔬菜。"]}
        )

    # put a signal that there are no more requests
    requests.put(None)

    # create `max_concurrent_tasks` workers and start them
    with ThreadPool(max_concurrent_tasks) as pool:
        for i in range(max_concurrent_tasks):
            pool.apply_async(worker, args=(i, client, requests))

        # wait until all requests are done
        pool.close()
        pool.join()

    client.close()

    end = datetime.now()
    print(f"Total time: {end - start}, Total task: {task_num}")


if __name__ == "__main__":
    main()
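
Unlike the asyncio variants, which cancel their workers, this threaded version shuts down through a None sentinel: the worker that pulls it puts it straight back so every other worker eventually sees it and returns, after which pool.close() and pool.join() let the pool drain and the client is closed.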
