11from __future__ import annotations
22
3+ import asyncio
34import logging
45import os
56import threading
2829from ._streaming import Stream
2930
3031from ._utils ._key_agreement import key_agreement_client
32+ from ._utils ._model_breaker import ModelBreaker
3133
3234__all__ = ["Ark" , "AsyncArk" ]
3335
@@ -39,6 +41,9 @@ class Ark(SyncAPIClient):
3941 tokenization : resources .Tokenization
4042 context : resources .Context
4143 content_generation : resources .ContentGeneration
44+ batch_chat : resources .BatchChat
45+ model_breaker_map : dict [str , ModelBreaker ]
46+ model_breaker_lock : threading .Lock
4247
4348 def __init__ (
4449 self ,
@@ -98,6 +103,9 @@ def __init__(
98103 self .tokenization = resources .Tokenization (self )
99104 self .context = resources .Context (self )
100105 self .content_generation = resources .ContentGeneration (self )
106+ self .batch_chat = resources .BatchChat (self )
107+ self .model_breaker_map = defaultdict (ModelBreaker )
108+ self .model_breaker_lock = threading .Lock ()
101109 # self.classification = resources.Classification(self)
102110
103111 def _get_endpoint_sts_token (self , endpoint_id : str ):
@@ -128,6 +136,9 @@ def auth_headers(self) -> dict[str, str]:
128136 api_key = self .api_key
129137 return {"Authorization" : f"Bearer { api_key } " }
130138
139+ def get_model_breaker (self , model_name : str ) -> ModelBreaker :
140+ with self .model_breaker_lock :
141+ return self .model_breaker_map [model_name ]
131142
132143class AsyncArk (AsyncAPIClient ):
133144 chat : resources .AsyncChat
@@ -136,6 +147,9 @@ class AsyncArk(AsyncAPIClient):
136147 tokenization : resources .AsyncTokenization
137148 context : resources .AsyncContext
138149 content_generation : resources .AsyncContentGeneration
150+ batch_chat : resources .AsyncBatchChat
151+ model_breaker_map : dict [str , ModelBreaker ]
152+ model_breaker_lock : asyncio .Lock
139153
140154 def __init__ (
141155 self ,
@@ -194,6 +208,9 @@ def __init__(
194208 self .tokenization = resources .AsyncTokenization (self )
195209 self .context = resources .AsyncContext (self )
196210 self .content_generation = resources .AsyncContentGeneration (self )
211+ self .batch_chat = resources .AsyncBatchChat (self )
212+ self .model_breaker_map = defaultdict (ModelBreaker )
213+ self .model_breaker_lock = asyncio .Lock ()
197214 # self.classification = resources.AsyncClassification(self)
198215
199216 def _get_endpoint_sts_token (self , endpoint_id : str ):
@@ -217,6 +234,10 @@ def auth_headers(self) -> dict[str, str]:
217234 api_key = self .api_key
218235 return {"Authorization" : f"Bearer { api_key } " }
219236
237+ async def get_model_breaker (self , model_name : str ) -> ModelBreaker :
238+ async with self .model_breaker_lock :
239+ return self .model_breaker_map [model_name ]
240+
220241
221242class StsTokenManager (object ):
222243 # The time at which we'll attempt to refresh, but not
0 commit comments