Skip to content

Commit ff19874

Browse files
exiaohuliyuxuan-bd
authored and committed
feat: batch embeddings
1 parent e7cfd26 commit ff19874

File tree

5 files changed

+351
-233
lines changed

5 files changed

+351
-233
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import asyncio
2+
import logging
3+
import time
4+
from datetime import datetime, timedelta
5+
from random import random
6+
from typing import Any, Awaitable, Callable, Optional, TypeVar
7+
8+
import httpx
9+
10+
from ..._constants import INITIAL_RETRY_DELAY, MAX_RETRY_DELAY
11+
from ..._exceptions import ArkAPIConnectionError, ArkAPIStatusError, ArkAPITimeoutError
12+
from ..._utils._model_breaker import ModelBreaker
13+
14+
# Module-level logger, named after this module's import path.
log: logging.Logger = logging.getLogger(__name__)
15+
16+
17+
def _calculate_retry_timeout(retry_times: int) -> float:
    """Return the number of seconds to wait before the next retry.

    The base delay is ``INITIAL_RETRY_DELAY * 2 ** retry_times``, capped at
    ``MAX_RETRY_DELAY``. It is then multiplied by a random factor in the
    range (0.75, 1.0] so that concurrent callers do not retry in lockstep.

    Args:
        retry_times: Number of retries already performed.

    Returns:
        A non-negative delay.
    """
    # Bound the exponent so the power below cannot grow without limit.
    exponent = min(retry_times, MAX_RETRY_DELAY / INITIAL_RETRY_DELAY)

    backoff = INITIAL_RETRY_DELAY * 2 ** exponent
    if backoff > MAX_RETRY_DELAY:
        backoff = MAX_RETRY_DELAY

    # Multiplicative jitter: shave off up to 25% of the delay.
    jittered = backoff * (1 - 0.25 * random())
    if jittered >= 0:
        return jittered
    return 0
24+
25+
26+
def _get_retry_after(response: httpx.Response) -> Optional[int]:
    """Return the integer value of the ``Retry-After`` header, if usable.

    Args:
        response: HTTP response whose headers are inspected.

    Returns:
        The header value converted to ``int`` when it consists solely of
        decimal digits. ``None`` when the header is missing, empty, or in
        any other form (for example an HTTP-date), which is not parsed here.
    """
    retry_after = response.headers.get("Retry-After")
    # ``isdecimal`` accepts exactly the characters ``int`` can parse;
    # ``isdigit`` would also let through digit-like characters such as
    # superscripts, on which ``int`` raises ValueError.
    if retry_after is not None and retry_after.isdecimal():
        return int(retry_after)
    return None
32+
33+
34+
def _should_retry(response: httpx.Response) -> bool:
    """Tell whether a failed response is worth another attempt.

    Retryable status codes are 408 (request timeout), 409 (lock timeout),
    429 (rate limit) and every code from 500 upwards (internal errors).
    """
    status = response.status_code
    return status in (408, 409, 429) or status >= 500
52+
53+
54+
def get_request_last_time(
    client: httpx.Client, timeout: Optional[Any] = None
) -> datetime:
    """Compute the deadline by which a request must have completed.

    Args:
        client: Client whose ``timeout`` attribute is used when *timeout*
            is ``None``.
        timeout: Either an ``httpx.Timeout`` (its ``read`` component is
            used) or a number of seconds given as ``int`` or ``float``.

    Returns:
        ``datetime.now()`` plus the timeout, as a naive datetime. When the
        read timeout is disabled (``None``), ``datetime.max`` is returned so
        that deadline comparisons never trigger.

    Raises:
        TypeError: If the timeout is of an unsupported type.
    """
    if timeout is None:
        timeout = client.timeout

    if isinstance(timeout, httpx.Timeout):
        timeout_seconds = timeout.read
    elif isinstance(timeout, (int, float)):
        timeout_seconds = timeout
    else:
        raise TypeError("timeout type {} is not supported".format(type(timeout)))

    if timeout_seconds is None:
        # httpx represents "no read timeout" as None; passing that to
        # timedelta(seconds=...) would raise TypeError, so treat it as
        # "no deadline" instead.
        return datetime.max
    return datetime.now() + timedelta(seconds=timeout_seconds)
69+
70+
71+
# Result type of the callable wrapped by the retry helpers in this module.
R = TypeVar("R")
72+
73+
74+
def with_batch_retry(
    deadline: datetime,
    breaker: ModelBreaker,
    func: Callable[..., R],
    *args,
    **kwargs,
) -> R:
    """Call ``func(*args, **kwargs)`` until it succeeds or the deadline passes.

    Every attempt is preceded by ``breaker.wait()``. A connection error is
    followed by a jittered exponential backoff sleep. A status error with a
    positive ``Retry-After`` value forwards that value to ``breaker.reset``;
    if the response is retryable the loop continues without sleeping here,
    otherwise the error is re-raised. Any other exception propagates.

    Args:
        deadline: Point in time after which no further attempt is made.
        breaker: Gate consulted before each attempt and reset on
            ``Retry-After``.
        func: Callable to invoke.
        *args: Positional arguments forwarded to *func*.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        Whatever *func* returns on its first successful call.

    Raises:
        ArkAPITimeoutError: If the deadline has passed before an attempt,
            or the next backoff sleep would end after the deadline.
        ArkAPIStatusError: If the response is not retryable.
    """
    attempt = 0
    while True:
        breaker.wait()
        if datetime.now() > deadline:
            raise ArkAPITimeoutError(None, None)

        try:
            return func(*args, **kwargs)
        except ArkAPIConnectionError:
            backoff = _calculate_retry_timeout(attempt)
            wake_at = datetime.now() + timedelta(seconds=backoff)
            # Give up now rather than sleeping past the deadline.
            if wake_at > deadline:
                raise ArkAPITimeoutError(None, None)
            time.sleep(backoff)
        except ArkAPIStatusError as err:
            server_delay = _get_retry_after(err.response)
            if server_delay is not None and server_delay > 0:
                breaker.reset(server_delay)
            retryable = _should_retry(err.response)
            if not retryable:
                raise err

        attempt += 1
101+
102+
103+
async def async_with_batch_retry(
    deadline: datetime,
    breaker: ModelBreaker,
    func: Callable[..., Awaitable[R]],
    *args,
    **kwargs,
) -> R:
    """Await ``func(*args, **kwargs)`` until it succeeds or the deadline passes.

    Every attempt is preceded by ``await breaker.asyncwait()``. A connection
    error is followed by a jittered exponential backoff using
    ``asyncio.sleep``. A status error with a positive ``Retry-After`` value
    forwards that value to ``breaker.reset``; if the response is retryable
    the loop continues without sleeping here, otherwise the error is
    re-raised. Any other exception propagates.

    Args:
        deadline: Point in time after which no further attempt is made.
        breaker: Gate awaited before each attempt and reset on
            ``Retry-After``.
        func: Callable returning an awaitable.
        *args: Positional arguments forwarded to *func*.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        The awaited result of the first successful call.

    Raises:
        ArkAPITimeoutError: If the deadline has passed before an attempt,
            or the next backoff sleep would end after the deadline.
        ArkAPIStatusError: If the response is not retryable.
    """
    attempt = 0
    while True:
        await breaker.asyncwait()
        if datetime.now() > deadline:
            raise ArkAPITimeoutError(None, None)

        try:
            return await func(*args, **kwargs)
        except ArkAPIConnectionError:
            backoff = _calculate_retry_timeout(attempt)
            wake_at = datetime.now() + timedelta(seconds=backoff)
            # Give up now rather than sleeping past the deadline.
            if wake_at > deadline:
                raise ArkAPITimeoutError(None, None)
            await asyncio.sleep(backoff)
        except ArkAPIStatusError as err:
            server_delay = _get_retry_after(err.response)
            if server_delay is not None and server_delay > 0:
                breaker.reset(server_delay)
            retryable = _should_retry(err.response)
            if not retryable:
                raise err

        attempt += 1
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from ..._compat import cached_property
4-
from .chat.chat import Chat, AsyncChat
5-
from ..._resource import SyncAPIResource, AsyncAPIResource
4+
from ..._resource import AsyncAPIResource, SyncAPIResource
5+
from .chat.chat import AsyncChat, Chat
6+
from .embeddings import AsyncEmbeddings, Embeddings
67

78
# Names exported by a star-import of this module.
__all__ = ["Batch", "AsyncBatch"]
89

@@ -12,8 +13,16 @@ class Batch(SyncAPIResource):
1213
def chat(self) -> Chat:
1314
return Chat(self._client)
1415

16+
@cached_property
def embeddings(self) -> Embeddings:
    """Return an ``Embeddings`` resource built from this resource's client."""
    client = self._client
    return Embeddings(client)
19+
1520

1621
class AsyncBatch(AsyncAPIResource):
    """Asynchronous batch resource exposing the chat and embeddings sub-resources."""

    @cached_property
    def chat(self) -> AsyncChat:
        """Return an ``AsyncChat`` resource built from this resource's client."""
        client = self._client
        return AsyncChat(client)

    @cached_property
    def embeddings(self) -> AsyncEmbeddings:
        """Return an ``AsyncEmbeddings`` resource built from this resource's client."""
        client = self._client
        return AsyncEmbeddings(client)

volcenginesdkarkruntime/resources/batch/chat/chat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from ...._compat import cached_property
4-
from .completions import Completions, AsyncCompletions
5-
from ...._resource import SyncAPIResource, AsyncAPIResource
4+
from ...._resource import AsyncAPIResource, SyncAPIResource
5+
from .completions import AsyncCompletions, Completions
66

77
# Names exported by a star-import of this module.
__all__ = ["Chat", "AsyncChat"]
88

0 commit comments

Comments
 (0)