Skip to content

Commit 0593923

Browse files
author
BitsAdmin
committed
Merge branch 'feat/ark-runtime-batch' into 'integration_2025-01-16_675380848130'
feat: [development task] ark-runtime-manual-Python (978354) See merge request iaasng/volcengine-python-sdk!491
2 parents 4a5068c + 3ded380 commit 0593923

File tree

11 files changed

+585
-1
lines changed

11 files changed

+585
-1
lines changed

volcenginesdkarkruntime/_base_client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,28 @@ def delete(
595595

596596
return cast(ResponseT, self.request(cast_to, opts))
597597

598+
def post_without_retry(
599+
self,
600+
path: str,
601+
*,
602+
cast_to: Type[ResponseT],
603+
body: Dict | None = None,
604+
options: ExtraRequestOptions = {},
605+
files: RequestFiles | None = None,
606+
stream: bool = False,
607+
stream_cls: type[_StreamT] | None = None,
608+
) -> ResponseT | _StreamT:
609+
opts = RequestOptions.construct( # type: ignore
610+
method="post",
611+
url=path,
612+
body=body,
613+
**options,
614+
)
615+
616+
return cast(
617+
ResponseT, self.request(cast_to, opts, remaining_retries=0, stream=stream, stream_cls=stream_cls)
618+
)
619+
598620
def request(
599621
self,
600622
cast_to: Type[ResponseT],
@@ -755,6 +777,26 @@ async def delete(
755777

756778
return await self.request(cast_to, opts)
757779

780+
async def post_without_retry(
781+
self,
782+
path: str,
783+
*,
784+
cast_to: Type[ResponseT],
785+
body: Dict | None = None,
786+
options: ExtraRequestOptions = {},
787+
files: RequestFiles | None = None,
788+
stream: bool = False,
789+
stream_cls: type[_AsyncStreamT] | None = None,
790+
) -> ResponseT | _AsyncStreamT:
791+
opts = RequestOptions.construct(
792+
method="post",
793+
url=path,
794+
body=body,
795+
**options,
796+
)
797+
798+
return await self.request(cast_to, opts, remaining_retries=0, stream=stream, stream_cls=stream_cls)
799+
758800
async def request(
759801
self,
760802
cast_to: Type[ResponseT],

volcenginesdkarkruntime/_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import logging
45
import os
56
import threading
@@ -28,6 +29,7 @@
2829
from ._streaming import Stream
2930

3031
from ._utils._key_agreement import key_agreement_client
32+
from ._utils._model_breaker import ModelBreaker
3133

3234
__all__ = ["Ark", "AsyncArk"]
3335

@@ -39,6 +41,9 @@ class Ark(SyncAPIClient):
3941
tokenization: resources.Tokenization
4042
context: resources.Context
4143
content_generation: resources.ContentGeneration
44+
batch_chat: resources.BatchChat
45+
model_breaker_map: dict[str, ModelBreaker]
46+
model_breaker_lock: threading.Lock
4247

4348
def __init__(
4449
self,
@@ -98,6 +103,9 @@ def __init__(
98103
self.tokenization = resources.Tokenization(self)
99104
self.context = resources.Context(self)
100105
self.content_generation = resources.ContentGeneration(self)
106+
self.batch_chat = resources.BatchChat(self)
107+
self.model_breaker_map = defaultdict(ModelBreaker)
108+
self.model_breaker_lock = threading.Lock()
101109
# self.classification = resources.Classification(self)
102110

103111
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -128,6 +136,9 @@ def auth_headers(self) -> dict[str, str]:
128136
api_key = self.api_key
129137
return {"Authorization": f"Bearer {api_key}"}
130138

139+
def get_model_breaker(self, model_name: str) -> ModelBreaker:
140+
with self.model_breaker_lock:
141+
return self.model_breaker_map[model_name]
131142

132143
class AsyncArk(AsyncAPIClient):
133144
chat: resources.AsyncChat
@@ -136,6 +147,9 @@ class AsyncArk(AsyncAPIClient):
136147
tokenization: resources.AsyncTokenization
137148
context: resources.AsyncContext
138149
content_generation: resources.AsyncContentGeneration
150+
batch_chat: resources.AsyncBatchChat
151+
model_breaker_map: dict[str, ModelBreaker]
152+
model_breaker_lock: asyncio.Lock
139153

140154
def __init__(
141155
self,
@@ -194,6 +208,9 @@ def __init__(
194208
self.tokenization = resources.AsyncTokenization(self)
195209
self.context = resources.AsyncContext(self)
196210
self.content_generation = resources.AsyncContentGeneration(self)
211+
self.batch_chat = resources.AsyncBatchChat(self)
212+
self.model_breaker_map = defaultdict(ModelBreaker)
213+
self.model_breaker_lock = asyncio.Lock()
197214
# self.classification = resources.AsyncClassification(self)
198215

199216
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -217,6 +234,10 @@ def auth_headers(self) -> dict[str, str]:
217234
api_key = self.api_key
218235
return {"Authorization": f"Bearer {api_key}"}
219236

237+
async def get_model_breaker(self, model_name: str) -> ModelBreaker:
238+
async with self.model_breaker_lock:
239+
return self.model_breaker_map[model_name]
240+
220241

221242
class StsTokenManager(object):
222243
# The time at which we'll attempt to refresh, but not

volcenginesdkarkruntime/_constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
SERVER_REQUEST_HEADER = "X-Request-Id"
1111
ARK_E2E_ENCRYPTION_HEADER = "x-is-encrypted"
1212

13+
DEFAULT_TIMEOUT_SECONDS = 600.0
14+
DEFAULT_CONNECT_TIMEOUT_SECONDS = 60.0
1315
# default timeout is 1 minutes
14-
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=60.0)
16+
DEFAULT_TIMEOUT = httpx.Timeout(timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS)
17+
1518
DEFAULT_MAX_RETRIES = 2
1619
DEFAULT_CONNECTION_LIMITS = httpx.Limits(
1720
max_connections=1000, max_keepalive_connections=100

volcenginesdkarkruntime/_resource.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(self, client: "Ark") -> None:
1212
self._post = client.post
1313
self._get = client.get
1414
self._delete = client.delete
15+
self._post_without_retry = client.post_without_retry
1516

1617

1718
class AsyncAPIResource:
@@ -22,3 +23,4 @@ def __init__(self, client: "AsyncArk") -> None:
2223
self._post = client.post
2324
self._get = client.get
2425
self._delete = client.delete
26+
self._post_without_retry = client.post_without_retry
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from datetime import datetime, timedelta
2+
3+
4+
class ModelBreaker:
5+
def __init__(self):
6+
# 初始化 allow_time 为当前时间
7+
self.allow_time = datetime.now()
8+
9+
def allow(self):
10+
# 检查当前时间是否在 allow_time 之后
11+
return datetime.now() > self.allow_time
12+
13+
def reset(self, duration):
14+
# 将 allow_time 重置为当前时间加上指定的持续时间
15+
self.allow_time = datetime.now() + timedelta(seconds=duration.total_seconds())
16+
17+
def get_allowed_duration(self):
18+
# 计算当前时间与 allow_time 之间的持续时间
19+
allow_duration = self.allow_time - datetime.now()
20+
# 如果持续时间为负,则返回一个零时长的 timedelta 对象
21+
if allow_duration.total_seconds() < 0:
22+
return timedelta(0)
23+
return allow_duration

volcenginesdkarkruntime/resources/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .bot import BotChat, AsyncBotChat
66
from .context import Context, AsyncContext
77
from .content_generation import ContentGeneration, AsyncContentGeneration
8+
from .batch_chat import BatchChat, AsyncBatchChat
89

910
__all__ = [
1011
"Chat",
@@ -19,4 +20,6 @@
1920
"AsyncContext",
2021
"ContentGeneration",
2122
"AsyncContentGeneration"
23+
"BatchChat",
24+
"AsyncBatchChat"
2225
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .chat import BatchChat, AsyncBatchChat
2+
3+
__all__ = ["BatchChat", "AsyncBatchChat"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
from .completions import Completions, AsyncCompletions
6+
from ..._compat import cached_property
7+
from ..._resource import SyncAPIResource, AsyncAPIResource
8+
9+
__all__ = ["BatchChat", "AsyncBatchChat"]
10+
11+
12+
class BatchChat(SyncAPIResource):
13+
@cached_property
14+
def completions(self) -> Completions:
15+
return Completions(self._client)
16+
17+
18+
class AsyncBatchChat(AsyncAPIResource):
19+
@cached_property
20+
def completions(self) -> AsyncCompletions:
21+
return AsyncCompletions(self._client)

0 commit comments

Comments
 (0)