2222 _DEFAULT_ADVISORY_REFRESH_TIMEOUT ,
2323 _DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
2424 _DEFAULT_STS_TIMEOUT ,
25+ _DEFAULT_RESOURCE_TYPE ,
2526 DEFAULT_TIMEOUT
2627)
2728from ._streaming import Stream
3132
3233class Ark (SyncAPIClient ):
3334 chat : resources .Chat
35+ bot_chat : resources .BotChat
36+ embeddings : resources .Embeddings
3437
3538 def __init__ (
3639 self ,
@@ -83,6 +86,7 @@ def __init__(
8386 self ._sts_token_manager : StsTokenManager | None = None
8487
8588 self .chat = resources .Chat (self )
89+ self .bot_chat = resources .BotChat (self )
8690 self .embeddings = resources .Embeddings (self )
8791 # self.tokenization = resources.Tokenization(self)
8892 # self.classification = resources.Classification(self)
@@ -94,6 +98,13 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
9498 self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self .region )
9599 return self ._sts_token_manager .get (endpoint_id )
96100
101+ def _get_bot_sts_token (self , bot_id : str ):
102+ if self ._sts_token_manager is None :
103+ if self .ak is None or self .sk is None :
104+ raise ArkAPIError ("must set ak and sk before get endpoint token." )
105+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self .region )
106+ return self ._sts_token_manager .get (bot_id , resource_type = "bot" )
107+
97108 @property
98109 def auth_headers (self ) -> dict [str , str ]:
99110 api_key = self .api_key
@@ -102,6 +113,8 @@ def auth_headers(self) -> dict[str, str]:
102113
103114class AsyncArk (AsyncAPIClient ):
104115 chat : resources .AsyncChat
116+ bot_chat : resources .AsyncBotChat
117+ embeddings : resources .AsyncEmbeddings
105118
106119 def __init__ (
107120 self ,
@@ -153,6 +166,7 @@ def __init__(
153166 self ._sts_token_manager : StsTokenManager | None = None
154167
155168 self .chat = resources .AsyncChat (self )
169+ self .bot_chat = resources .AsyncBotChat (self )
156170 self .embeddings = resources .AsyncEmbeddings (self )
157171 # self.tokenization = resources.AsyncTokenization(self)
158172 # self.classification = resources.AsyncClassification(self)
@@ -171,7 +185,6 @@ def auth_headers(self) -> dict[str, str]:
171185
172186
173187class StsTokenManager (object ):
174-
175188 # The time at which we'll attempt to refresh, but not
176189 # block if someone else is refreshing.
177190 _advisory_refresh_timeout : int = _DEFAULT_ADVISORY_REFRESH_TIMEOUT
@@ -200,13 +213,14 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
200213
201214 return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
202215
203- def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ):
216+ def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ,
217+ resource_type : str = _DEFAULT_RESOURCE_TYPE ):
204218 if ttl < self ._advisory_refresh_timeout * 2 :
205219 raise ArkAPIError ("ttl should not be under {} seconds." .format (self ._advisory_refresh_timeout * 2 ))
206220
207221 try :
208222 api_key , expired_time = self ._load_api_key (
209- ep , ttl
223+ ep , ttl , resource_type = resource_type
210224 )
211225 self ._endpoint_sts_tokens [ep ] = (api_key , expired_time )
212226 except ApiException as e :
@@ -215,7 +229,7 @@ def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandat
215229 else :
216230 logging .error ("load api key cause error: e={}" .format (e ))
217231
218- def _refresh (self , ep : str ):
232+ def _refresh (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ):
219233 if not self ._need_refresh (ep , self ._advisory_refresh_timeout ):
220234 return
221235
@@ -228,7 +242,7 @@ def _refresh(self, ep: str):
228242 ep , self ._mandatory_refresh_timeout
229243 )
230244
231- self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh )
245+ self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type )
232246 return
233247 finally :
234248 self ._refresh_lock .release ()
@@ -237,16 +251,17 @@ def _refresh(self, ep: str):
237251 if not self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
238252 return
239253
240- self ._protected_refresh (ep , is_mandatory = True )
254+ self ._protected_refresh (ep , is_mandatory = True , resource_type = resource_type )
241255
242- def get (self , ep : str ) -> str :
243- self ._refresh (ep )
256+ def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> str :
257+ self ._refresh (ep , resource_type = resource_type )
244258 return self ._endpoint_sts_tokens [ep ][0 ]
245259
246- def _load_api_key (self , ep : str , duration_seconds : int ) -> Tuple [str , int ]:
260+ def _load_api_key (self , ep : str , duration_seconds : int ,
261+ resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> Tuple [str , int ]:
247262 get_api_key_request = volcenginesdkark .GetApiKeyRequest (
248263 duration_seconds = duration_seconds ,
249- resource_type = "endpoint" ,
264+ resource_type = resource_type ,
250265 resource_ids = [ep ],
251266 )
252267 resp : volcenginesdkark .GetApiKeyResponse = self .api_instance .get_api_key (
0 commit comments