3737 _DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
3838 _DEFAULT_STS_TIMEOUT ,
3939 _DEFAULT_RESOURCE_TYPE ,
40+ _PRESETENDPOINT_RESOURCE_TYPE ,
4041 DEFAULT_TIMEOUT ,
42+ _BOT_RESOURCE_TYPE ,
4143)
4244from ._streaming import Stream
4345
@@ -137,12 +139,15 @@ def __init__(
137139 self .files = resources .Files (self )
138140 # self.classification = resources.Classification(self)
139141
140- def _get_endpoint_sts_token (self , endpoint_id : str ):
142+ def _get_endpoint_sts_token (self , endpoint_id : str , project_name : str = None ):
141143 if self ._sts_token_manager is None :
142144 if self .ak is None or self .sk is None :
143145 raise ArkAPIError ("must set ak and sk before get endpoint token." )
144146 self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self .region )
145- return self ._sts_token_manager .get (endpoint_id )
147+ resource_type : str = self .get_resource_type_by_endpoint_id (endpoint_id )
148+ if resource_type == _PRESETENDPOINT_RESOURCE_TYPE and (project_name is None or project_name .strip () == "" ):
149+ raise ArkAPIError ("must set project_name when get preset endpoint token." )
150+ return self ._sts_token_manager .get (endpoint_id , resource_type = resource_type , project_name = project_name )
146151
147152 def _get_endpoint_certificate (
148153 self , endpoint_id : str
@@ -179,6 +184,16 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
179184 with self .model_breaker_lock :
180185 return self .model_breaker_map [model_name ]
181186
187+ def get_resource_type_by_endpoint_id (self , endpoint_id : str ) -> str :
188+ if endpoint_id .startswith ("ep-m-" ):
189+ return _PRESETENDPOINT_RESOURCE_TYPE
190+ if endpoint_id .startswith ("ep-" ):
191+ return _DEFAULT_RESOURCE_TYPE
192+ if endpoint_id .startswith ("bot-" ):
193+ return _BOT_RESOURCE_TYPE
194+ # for model id, default to preset endpoint
195+ return _PRESETENDPOINT_RESOURCE_TYPE
196+
182197
183198class AsyncArk (AsyncAPIClient ):
184199 beta : beta .AsyncBeta
@@ -350,6 +365,7 @@ def _protected_refresh(
350365 ttl : int = _DEFAULT_STS_TIMEOUT ,
351366 is_mandatory : bool = False ,
352367 resource_type : str = _DEFAULT_RESOURCE_TYPE ,
368+ project_name : str = None ,
353369 ):
354370 if ttl < self ._advisory_refresh_timeout * 2 :
355371 raise ArkAPIError (
@@ -360,7 +376,7 @@ def _protected_refresh(
360376
361377 try :
362378 api_key , expired_time = self ._load_api_key (
363- ep , ttl , resource_type = resource_type
379+ ep , ttl , resource_type = resource_type , project_name = project_name
364380 )
365381 self ._endpoint_sts_tokens [ep ] = (api_key , expired_time )
366382 except ApiException as e :
@@ -369,7 +385,7 @@ def _protected_refresh(
369385 else :
370386 logging .error ("load api key cause error: e={}" .format (e ))
371387
372- def _refresh (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ):
388+ def _refresh (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE , project_name : str = None ):
373389 if not self ._need_refresh (ep , self ._advisory_refresh_timeout ):
374390 return
375391
@@ -383,7 +399,7 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
383399 )
384400
385401 self ._protected_refresh (
386- ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type
402+ ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type , project_name = project_name
387403 )
388404 return
389405 finally :
@@ -394,24 +410,27 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
394410 return
395411
396412 self ._protected_refresh (
397- ep , is_mandatory = True , resource_type = resource_type
413+ ep , is_mandatory = True , resource_type = resource_type , project_name = project_name
398414 )
399415
400- def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> str :
401- self ._refresh (ep , resource_type = resource_type )
416+ def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE , project_name : str = None ) -> str :
417+ self ._refresh (ep , resource_type = resource_type , project_name = project_name )
402418 return self ._endpoint_sts_tokens [ep ][0 ]
403419
404420 def _load_api_key (
405421 self ,
406422 ep : str ,
407423 duration_seconds : int ,
408424 resource_type : str = _DEFAULT_RESOURCE_TYPE ,
425+ project_name : str = None ,
409426 ) -> Tuple [str , int ]:
410427 get_api_key_request = volcenginesdkark .GetApiKeyRequest (
411428 duration_seconds = duration_seconds ,
412429 resource_type = resource_type ,
413430 resource_ids = [ep ],
414431 )
432+ if project_name is not None and project_name .strip () != "" :
433+ get_api_key_request .project_name = project_name
415434 resp : volcenginesdkark .GetApiKeyResponse = self .api_instance .get_api_key (
416435 get_api_key_request
417436 )
0 commit comments