Skip to content

Commit c0a21a7

Browse files
committed
feat(*): support ark batch chat
1 parent 6cf728f commit c0a21a7

File tree

11 files changed

+687
-1
lines changed

11 files changed

+687
-1
lines changed

volcenginesdkarkruntime/_base_client.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,57 @@ def _request(
478478
stream_cls=stream_cls,
479479
)
480480

481+
def _request_without_retry(self,
    *,
    cast_to: Type[ResponseT],
    options: RequestOptions,
    stream: bool,
    stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
    """Send a single synchronous HTTP request with no retry attempts.

    Unlike ``_request``, any timeout, transport failure, or non-2xx status
    is raised immediately rather than retried — intended for batch-style
    calls where the caller owns retry policy.

    Args:
        cast_to: Target type the response body is deserialized into.
        options: Fully-built request options (method, url, body, headers).
        stream: Whether the response should be streamed.
        stream_cls: Stream wrapper class used when ``stream`` is true.

    Raises:
        ArkAPITimeoutError: On an httpx timeout.
        ArkAPIConnectionError: On any other transport-level failure.
        A status error (via ``_make_status_error_from_response``) on 4xx/5xx.
    """
    request = self._build_request(options)
    # Client-side request id header, used to correlate errors with this request.
    req_id = request.headers.get(CLIENT_REQUEST_HEADER, "")
    try:
        response = self._client.send(
            request,
            stream=stream or self._should_stream_response_body(request=request),
        )
    except httpx.TimeoutException as err:
        log.debug("Raising timeout error")
        raise ArkAPITimeoutError(request=request, request_id=req_id) from err
    except Exception as err:
        # Any other transport-level failure is surfaced as a connection error.
        log.debug("Encountered Exception", exc_info=True)
        log.debug("Raising connection error")
        raise ArkAPIConnectionError(request=request, request_id=req_id) from err

    log.debug(
        'HTTP Request: %s %s "%i %s"',
        request.method,
        request.url,
        response.status_code,
        response.reason_phrase,
    )

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
        log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
        # If the response is streamed then we need to explicitly read the response
        # to completion before attempting to access the response text.
        if not err.response.is_closed:
            err.response.read()

        log.debug("Re-raising status error")
        # `from None` suppresses the httpx traceback chain; the Ark status
        # error already carries the relevant response details.
        raise self._make_status_error_from_response(
            err.response, request_id=req_id
        ) from None

    return self._process_response(
        cast_to=cast_to,
        response=response,
        stream=stream,
        stream_cls=stream_cls,
    )
531+
481532
def _retry_request(
482533
self,
483534
options: RequestOptions,
@@ -595,6 +646,28 @@ def delete(
595646

596647
return cast(ResponseT, self.request(cast_to, opts))
597648

649+
def post_without_retry(
    self,
    path: str,
    *,
    cast_to: Type[ResponseT],
    body: Dict | None = None,
    options: ExtraRequestOptions = {},
    files: RequestFiles | None = None,
    stream: bool = False,
    stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
    """POST to *path* bypassing the client's automatic retry behavior.

    Args:
        path: Relative request path.
        cast_to: Target type the response is deserialized into.
        body: JSON-serializable request body.
        options: Extra request options merged into the request. The mutable
            ``{}`` default is safe here because it is only read via
            ``**options``, never mutated.
        files: NOTE(review): accepted but never forwarded into the request
            options below, so file uploads are silently dropped — confirm
            whether it should be passed through (compare with ``post``).
        stream: Whether to stream the response.
        stream_cls: Stream wrapper class used when ``stream`` is true.
    """
    opts = RequestOptions.construct(  # type: ignore
        method="post",
        url=path,
        body=body,
        **options,
    )

    return cast(
        ResponseT, self.request_without_retry(cast_to, opts, stream=stream, stream_cls=stream_cls)
    )
670+
598671
def request(
599672
self,
600673
cast_to: Type[ResponseT],
@@ -612,6 +685,21 @@ def request(
612685
remaining_retries=remaining_retries,
613686
)
614687

688+
def request_without_retry(
    self,
    cast_to: Type[ResponseT],
    options: RequestOptions,
    *,
    stream: bool = False,
    stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
    """Public single-attempt entry point; delegates to ``_request_without_retry``."""
    return self._request_without_retry(
        cast_to=cast_to,
        options=options,
        stream=stream,
        stream_cls=stream_cls,
    )
702+
615703
def is_closed(self) -> bool:
616704
return self._client.is_closed
617705

@@ -755,6 +843,27 @@ async def delete(
755843

756844
return await self.request(cast_to, opts)
757845

846+
847+
async def post_without_retry(
    self,
    path: str,
    *,
    cast_to: Type[ResponseT],
    body: Dict | None = None,
    options: ExtraRequestOptions = {},
    files: RequestFiles | None = None,
    stream: bool = False,
    stream_cls: type[_AsyncStreamT] | None = None,
) -> ResponseT | _AsyncStreamT:
    """Async POST to *path* bypassing the client's automatic retry behavior.

    Args:
        path: Relative request path.
        cast_to: Target type the response is deserialized into.
        body: JSON-serializable request body.
        options: Extra request options merged into the request. The mutable
            ``{}`` default is safe here because it is only read via
            ``**options``, never mutated.
        files: NOTE(review): accepted but never forwarded into the request
            options below, so file uploads are silently dropped — confirm
            whether it should be passed through.
        stream: Whether to stream the response.
        stream_cls: Async stream wrapper class used when ``stream`` is true.
    """
    opts = RequestOptions.construct(
        method="post",
        url=path,
        body=body,
        **options,
    )

    return await self.request_without_retry(cast_to, opts, stream=stream, stream_cls=stream_cls)
866+
758867
async def request(
759868
self,
760869
cast_to: Type[ResponseT],
@@ -772,6 +881,21 @@ async def request(
772881
remaining_retries=remaining_retries,
773882
)
774883

884+
async def request_without_retry(
    self,
    cast_to: Type[ResponseT],
    options: RequestOptions,
    *,
    # Annotation improved from `_StreamT` to `_AsyncStreamT` for consistency
    # with the async client's other signatures (e.g. `post_without_retry`).
    stream: bool = False,
    stream_cls: type[_AsyncStreamT] | None = None,
) -> ResponseT | _AsyncStreamT:
    """Async public single-attempt entry point; delegates to ``_request_without_retry``."""
    return await self._request_without_retry(
        cast_to=cast_to,
        options=options,
        stream=stream,
        stream_cls=stream_cls,
    )
898+
775899
async def _request(
776900
self,
777901
*,
@@ -859,6 +983,57 @@ async def _request(
859983
stream_cls=stream_cls,
860984
)
861985

986+
async def _request_without_retry(
    self,
    *,
    cast_to: Type[ResponseT],
    options: RequestOptions,
    stream: bool,
    stream_cls: type[_AsyncStreamT] | None,
) -> ResponseT | _AsyncStreamT:
    """Send a single asynchronous HTTP request with no retry attempts.

    Async counterpart of the sync ``_request_without_retry``: any timeout,
    transport failure, or non-2xx status is raised immediately rather than
    retried.

    Raises:
        ArkAPITimeoutError: On an httpx timeout.
        ArkAPIConnectionError: On any other transport-level failure.
        A status error (via ``_make_status_error_from_response``) on 4xx/5xx.
    """
    request = self._build_request(options)
    # Client-side request id header, used to correlate errors with this request.
    req_id = request.headers.get(CLIENT_REQUEST_HEADER, "")
    try:
        response = await self._client.send(
            request,
            stream=stream or self._should_stream_response_body(request=request),
        )
    except httpx.TimeoutException as err:
        log.debug("Raising timeout error")
        raise ArkAPITimeoutError(request=request, request_id=req_id) from err
    except Exception as err:
        # Any other transport-level failure is surfaced as a connection error.
        log.debug("Encountered Exception", exc_info=True)
        log.debug("Raising connection error")
        raise ArkAPIConnectionError(request=request, request_id=req_id) from err
    log.debug(
        'HTTP Request: %s %s "%i %s"',
        request.method,
        request.url,
        response.status_code,
        response.reason_phrase,
    )
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:  # thrown on 4xx and 5xx status code
        log.debug("Encountered httpx.HTTPStatusError", exc_info=True)

        # If the response is streamed then we need to explicitly read the response
        # to completion before attempting to access the response text.
        if not err.response.is_closed:
            await err.response.aread()

        log.debug("Re-raising status error")
        # `from None` suppresses the httpx traceback chain; the Ark status
        # error already carries the relevant response details.
        raise self._make_status_error_from_response(
            err.response, request_id=req_id
        ) from None

    return await self._process_response(
        cast_to=cast_to,
        response=response,
        stream=stream,
        stream_cls=stream_cls,
    )
1036+
8621037
async def _retry_request(
8631038
self,
8641039
options: RequestOptions,

volcenginesdkarkruntime/_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ._streaming import Stream
2929

3030
from ._utils._key_agreement import key_agreement_client
31+
from ._utils._model_breaker import ModelBreaker
3132

3233
__all__ = ["Ark", "AsyncArk"]
3334

@@ -39,6 +40,9 @@ class Ark(SyncAPIClient):
3940
tokenization: resources.Tokenization
4041
context: resources.Context
4142
content_generation: resources.ContentGeneration
43+
batch_chat: resources.BatchChat
44+
model_breaker_map: dict[str, ModelBreaker]
45+
model_breaker_lock: threading.Lock
4246

4347
def __init__(
4448
self,
@@ -98,6 +102,9 @@ def __init__(
98102
self.tokenization = resources.Tokenization(self)
99103
self.context = resources.Context(self)
100104
self.content_generation = resources.ContentGeneration(self)
105+
self.batch_chat = resources.BatchChat(self)
106+
self.model_breaker_map = defaultdict(ModelBreaker)
107+
self.model_breaker_lock = threading.Lock()
101108
# self.classification = resources.Classification(self)
102109

103110
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -128,6 +135,9 @@ def auth_headers(self) -> dict[str, str]:
128135
api_key = self.api_key
129136
return {"Authorization": f"Bearer {api_key}"}
130137

138+
def get_model_breaker(self, model_name: str) -> ModelBreaker:
    """Return the circuit breaker for *model_name*, creating it on first use.

    ``model_breaker_map`` is a ``defaultdict(ModelBreaker)``; the lock makes
    the lazy insertion safe across threads.
    """
    with self.model_breaker_lock:
        return self.model_breaker_map[model_name]
131141

132142
class AsyncArk(AsyncAPIClient):
133143
chat: resources.AsyncChat
@@ -136,6 +146,9 @@ class AsyncArk(AsyncAPIClient):
136146
tokenization: resources.AsyncTokenization
137147
context: resources.AsyncContext
138148
content_generation: resources.AsyncContentGeneration
149+
batch_chat: resources.AsyncBatchChat
150+
model_breaker_map: dict[str, ModelBreaker]
151+
model_breaker_lock: threading.Lock
139152

140153
def __init__(
141154
self,
@@ -194,6 +207,9 @@ def __init__(
194207
self.tokenization = resources.AsyncTokenization(self)
195208
self.context = resources.AsyncContext(self)
196209
self.content_generation = resources.AsyncContentGeneration(self)
210+
self.batch_chat = resources.AsyncBatchChat(self)
211+
self.model_breaker_map = defaultdict(ModelBreaker)
212+
self.model_breaker_lock = threading.Lock()
197213
# self.classification = resources.AsyncClassification(self)
198214

199215
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -217,6 +233,10 @@ def auth_headers(self) -> dict[str, str]:
217233
api_key = self.api_key
218234
return {"Authorization": f"Bearer {api_key}"}
219235

236+
def get_model_breaker(self, model_name: str) -> ModelBreaker:
    """Return the circuit breaker for *model_name*, creating it on first use.

    ``model_breaker_map`` is a ``defaultdict(ModelBreaker)``; a
    ``threading.Lock`` guards the lazy insertion (this accessor itself is
    synchronous even on the async client).
    """
    with self.model_breaker_lock:
        return self.model_breaker_map[model_name]
239+
220240

221241
class StsTokenManager(object):
222242
# The time at which we'll attempt to refresh, but not

volcenginesdkarkruntime/_constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
SERVER_REQUEST_HEADER = "X-Request-Id"
1111
ARK_E2E_ENCRYPTION_HEADER = "x-is-encrypted"
1212

13+
DEFAULT_TIMEOUT_SECONDS = 600.0
14+
DEFAULT_CONNECT_TIMEOUT_SECONDS = 60.0
1315
# default overall timeout is 10 minutes; connect timeout is 1 minute
14-
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=60.0)
16+
DEFAULT_TIMEOUT = httpx.Timeout(timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS)
17+
1518
DEFAULT_MAX_RETRIES = 2
1619
DEFAULT_CONNECTION_LIMITS = httpx.Limits(
1720
max_connections=1000, max_keepalive_connections=100

volcenginesdkarkruntime/_resource.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(self, client: "Ark") -> None:
1212
self._post = client.post
1313
self._get = client.get
1414
self._delete = client.delete
15+
self._post_without_retry = client.post_without_retry
1516

1617

1718
class AsyncAPIResource:
@@ -22,3 +23,4 @@ def __init__(self, client: "AsyncArk") -> None:
2223
self._post = client.post
2324
self._get = client.get
2425
self._delete = client.delete
26+
self._post_without_retry = client.post_without_retry
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from datetime import datetime, timedelta
2+
3+
4+
class ModelBreaker:
    """A minimal time-based circuit breaker for a single model.

    While "open" (i.e. before ``allow_time``), requests to the model should
    be rejected; once the wall clock passes ``allow_time``, requests are
    allowed again.
    """

    def __init__(self):
        # Initialize allow_time to the current time, so the breaker starts
        # in the allowed (closed) state.
        self.allow_time = datetime.now()

    def allow(self):
        """Return True if the current time is past ``allow_time`` (requests allowed)."""
        return datetime.now() > self.allow_time

    def reset(self, duration):
        """Open the breaker: disallow requests for *duration* (a timedelta) from now."""
        # A timedelta adds to a datetime directly; no need to round-trip
        # through total_seconds().
        self.allow_time = datetime.now() + duration

    def get_allowed_duration(self):
        """Return the remaining time until requests are allowed, clamped to >= 0."""
        remaining = self.allow_time - datetime.now()
        # Negative means the breaker is already closed again.
        if remaining.total_seconds() < 0:
            return timedelta(0)
        return remaining

volcenginesdkarkruntime/resources/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .bot import BotChat, AsyncBotChat
66
from .context import Context, AsyncContext
77
from .content_generation import ContentGeneration, AsyncContentGeneration
8+
from .batch_chat import BatchChat, AsyncBatchChat
89

910
__all__ = [
1011
"Chat",
@@ -19,4 +20,6 @@
1920
"AsyncContext",
2021
"ContentGeneration",
2122
"AsyncContentGeneration"
23+
"BatchChat",
24+
"AsyncBatchChat"
2225
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .chat import BatchChat, AsyncBatchChat
2+
3+
__all__ = ["BatchChat", "AsyncBatchChat"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
from .completions import Completions, AsyncCompletions
6+
from ..._compat import cached_property
7+
from ..._resource import SyncAPIResource, AsyncAPIResource
8+
9+
__all__ = ["BatchChat", "AsyncBatchChat"]
10+
11+
12+
class BatchChat(SyncAPIResource):
    """Synchronous batch chat resource; exposes batch chat completions."""

    @cached_property
    def completions(self) -> Completions:
        # Lazily construct and cache the completions sub-resource.
        return Completions(self._client)
16+
17+
18+
class AsyncBatchChat(AsyncAPIResource):
    """Asynchronous batch chat resource; exposes batch chat completions."""

    @cached_property
    def completions(self) -> AsyncCompletions:
        # Lazily construct and cache the completions sub-resource.
        return AsyncCompletions(self._client)

0 commit comments

Comments
 (0)