Skip to content

Commit a6446b0

Browse files
committed
feat(*): support ark batch chat
1 parent 0006578 commit a6446b0

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

volcenginesdkarkruntime/resources/batch_chat/completions.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import time
45
from datetime import timedelta, datetime
56
from random import random
@@ -142,9 +143,7 @@ def create(
142143
is_encrypt = True
143144
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
144145
retryTimes = 0
145-
if timeout is None:
146-
timeout = self._client.timeout
147-
last_time = datetime.now() + timedelta(seconds=timeout.read)
146+
last_time = self._get_request_last_time(timeout)
148147
model_breaker = self._client.get_model_breaker(model)
149148
while True:
150149
while not model_breaker.allow():
@@ -203,6 +202,19 @@ def create(
203202
resp = self._decrypt(e2e_key, e2e_nonce, resp)
204203
return resp
205204

205+
def _get_request_last_time(self, timeout):
206+
if timeout is None:
207+
timeout = self._client.timeout
208+
timeoutSeconds = 0
209+
if isinstance(timeout, httpx.Timeout):
210+
timeoutSeconds = timeout.read
211+
elif isinstance(timeout, float):
212+
timeoutSeconds = timeout
213+
elif isinstance(self._client.timeout, int):
214+
timeoutSeconds = timeout
215+
else:
216+
raise TypeError("timeout type {} is not supported".format(type(self._client.timeout)))
217+
return datetime.now() + timedelta(seconds=timeoutSeconds)
206218

207219
class AsyncCompletions(AsyncAPIResource):
208220
@cached_property
@@ -272,15 +284,13 @@ async def create(
272284
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
273285

274286
retryTimes = 0
275-
if timeout is None:
276-
timeout = self._client.timeout
277-
last_time = datetime.now() + timedelta(seconds=timeout.read)
287+
last_time = self._get_request_last_time(timeout)
278288
model_breaker = self._client.get_model_breaker(model)
279289
while True:
280290
while not model_breaker.allow():
281291
if datetime.now() + timedelta(seconds=model_breaker.get_allowed_duration().total_seconds()) > last_time:
282292
raise ArkAPITimeoutError()
283-
time.sleep(model_breaker.get_allowed_duration().total_seconds())
293+
await asyncio.sleep(model_breaker.get_allowed_duration().total_seconds())
284294
if datetime.now() > last_time:
285295
raise ArkAPITimeoutError()
286296
try:
@@ -318,7 +328,7 @@ async def create(
318328
waitTime = _calculate_retry_timeout(retryTimes)
319329
if datetime.now() + timedelta(seconds=waitTime) > last_time:
320330
raise ArkAPITimeoutError()
321-
time.sleep(waitTime)
331+
await asyncio.sleep(waitTime)
322332
retryTimes = retryTimes + 1
323333
continue
324334
except ArkAPIStatusError as err:
@@ -335,6 +345,19 @@ async def create(
335345
resp = await self._decrypt(e2e_key, e2e_nonce, resp)
336346
return resp
337347

348+
def _get_request_last_time(self, timeout):
349+
if timeout is None:
350+
timeout = self._client.timeout
351+
timeoutSeconds = 0
352+
if isinstance(timeout, httpx.Timeout):
353+
timeoutSeconds = timeout.read
354+
elif isinstance(timeout, float):
355+
timeoutSeconds = timeout
356+
elif isinstance(self._client.timeout, int):
357+
timeoutSeconds = timeout
358+
else:
359+
raise TypeError("timeout type {} is not supported".format(type(self._client.timeout)))
360+
return datetime.now() + timedelta(seconds=timeoutSeconds)
338361

339362
class CompletionsWithRawResponse:
340363
def __init__(self, completions: Completions) -> None:

0 commit comments

Comments
 (0)