2424 _DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
2525 _DEFAULT_STS_TIMEOUT ,
2626 _DEFAULT_RESOURCE_TYPE ,
27- DEFAULT_TIMEOUT
27+ DEFAULT_TIMEOUT ,
2828)
2929from ._streaming import Stream
3030
@@ -42,6 +42,7 @@ class Ark(SyncAPIClient):
4242 context : resources .Context
4343 multimodal_embeddings : resources .MultimodalEmbeddings
4444 content_generation : resources .ContentGeneration
45+ images : resources .Images
4546 batch_chat : resources .BatchChat
4647 model_breaker_map : dict [str , ModelBreaker ]
4748 model_breaker_lock : threading .Lock
@@ -71,7 +72,6 @@ def __init__(
7172 Returns:
7273 ark client
7374 """
74-
7575 if ak is None :
7676 ak = os .environ .get ("VOLC_ACCESSKEY" )
7777 if sk is None :
@@ -84,7 +84,9 @@ def __init__(
8484 self .api_key = api_key
8585 self .region = region
8686
87- assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
87+ assert (api_key is not None ) or (ak is not None and sk is not None ), (
88+ "you need to support api_key or ak&sk"
89+ )
8890
8991 super ().__init__ (
9092 base_url = base_url ,
@@ -105,6 +107,7 @@ def __init__(
105107 self .context = resources .Context (self )
106108 self .multimodal_embeddings = resources .MultimodalEmbeddings (self )
107109 self .content_generation = resources .ContentGeneration (self )
110+ self .images = resources .Images (self )
108111 self .batch_chat = resources .BatchChat (self )
109112 self .model_breaker_map = defaultdict (ModelBreaker )
110113 self .model_breaker_lock = threading .Lock ()
@@ -120,10 +123,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
120123 def _get_endpoint_certificate (self , endpoint_id : str ) -> key_agreement_client :
121124 if self ._certificate_manager is None :
122125 cert_path = os .environ .get ("E2E_CERTIFICATE_PATH" )
123- if (self .ak is None or self .sk is None ) and cert_path is None and self .api_key is None :
124- raise ArkAPIError ("must set (api_key) or (ak and sk) \
125- or (E2E_CERTIFICATE_PATH) before get endpoint token." )
126- self ._certificate_manager = E2ECertificateManager (self .ak , self .sk , self .region , self ._base_url , self .api_key )
126+ if (
127+ (self .ak is None or self .sk is None )
128+ and cert_path is None
129+ and self .api_key is None
130+ ):
131+ raise ArkAPIError (
132+ "must set (api_key) or (ak and sk) \
133+ or (E2E_CERTIFICATE_PATH) before get endpoint token."
134+ )
135+ self ._certificate_manager = E2ECertificateManager (
136+ self .ak , self .sk , self .region , self ._base_url , self .api_key
137+ )
127138 return self ._certificate_manager .get (endpoint_id )
128139
129140 def _get_bot_sts_token (self , bot_id : str ):
@@ -142,6 +153,7 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
142153 with self .model_breaker_lock :
143154 return self .model_breaker_map [model_name ]
144155
156+
145157class AsyncArk (AsyncAPIClient ):
146158 chat : resources .AsyncChat
147159 bot_chat : resources .AsyncBotChat
@@ -150,6 +162,7 @@ class AsyncArk(AsyncAPIClient):
150162 context : resources .AsyncContext
151163 multimodal_embeddings : resources .AsyncMultimodalEmbeddings
152164 content_generation : resources .AsyncContentGeneration
165+ images : resources .AsyncImages
153166 batch_chat : resources .AsyncBatchChat
154167 model_breaker_map : dict [str , ModelBreaker ]
155168 model_breaker_lock : asyncio .Lock
@@ -168,15 +181,15 @@ def __init__(
168181 ) -> None :
169182 """init async ark client, this client is thread unsafe
170183
171- Args:
172- ak: access key id
173- sk: secret access key
174- api_key: api key,this api key will not be refreshed
175- timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
176- max_retries: times of retry when request failed. default 1
177- http_client: specify customized http_client
178- Returns:
179- async ark client
184+ Args:
185+ ak: access key id
186+ sk: secret access key
187+ api_key: api key,this api key will not be refreshed
188+ timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
189+ max_retries: times of retry when request failed. default 1
190+ http_client: specify customized http_client
191+ Returns:
192+ async ark client
180193 """
181194
182195 if ak is None :
@@ -191,7 +204,9 @@ def __init__(
191204 self .api_key = api_key
192205 self .region = region
193206
194- assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
207+ assert (api_key is not None ) or (ak is not None and sk is not None ), (
208+ "you need to support api_key or ak&sk"
209+ )
195210
196211 super ().__init__ (
197212 base_url = base_url ,
@@ -212,6 +227,7 @@ def __init__(
212227 self .context = resources .AsyncContext (self )
213228 self .multimodal_embeddings = resources .AsyncMultimodalEmbeddings (self )
214229 self .content_generation = resources .AsyncContentGeneration (self )
230+ self .images = resources .AsyncImages (self )
215231 self .batch_chat = resources .AsyncBatchChat (self )
216232 self .model_breaker_map = defaultdict (ModelBreaker )
217233 self .model_breaker_lock = asyncio .Lock ()
@@ -227,10 +243,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
227243 def _get_endpoint_certificate (self , endpoint_id : str ) -> key_agreement_client :
228244 if self ._certificate_manager is None :
229245 cert_path = os .environ .get ("E2E_CERTIFICATE_PATH" )
230- if (self .ak is None or self .sk is None ) and cert_path is None and self .api_key is None :
231- raise ArkAPIError ("must set (api_key) or (ak and sk) \
232- or (E2E_CERTIFICATE_PATH) before get endpoint token." )
233- self ._certificate_manager = E2ECertificateManager (self .ak , self .sk , self .region , self ._base_url , self .api_key )
246+ if (
247+ (self .ak is None or self .sk is None )
248+ and cert_path is None
249+ and self .api_key is None
250+ ):
251+ raise ArkAPIError (
252+ "must set (api_key) or (ak and sk) \
253+ or (E2E_CERTIFICATE_PATH) before get endpoint token."
254+ )
255+ self ._certificate_manager = E2ECertificateManager (
256+ self .ak , self .sk , self .region , self ._base_url , self .api_key
257+ )
234258 return self ._certificate_manager .get (endpoint_id )
235259
236260 @property
@@ -252,7 +276,9 @@ class StsTokenManager(object):
252276 _mandatory_refresh_timeout : int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
253277
254278 def __init__ (self , ak : str , sk : str , region : str ):
255- self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (lambda : ("" , 0 ))
279+ self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (
280+ lambda : ("" , 0 )
281+ )
256282 self ._refresh_lock = threading .Lock ()
257283
258284 import volcenginesdkcore
@@ -272,10 +298,19 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
272298
273299 return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
274300
275- def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ,
276- resource_type : str = _DEFAULT_RESOURCE_TYPE ):
301+ def _protected_refresh (
302+ self ,
303+ ep : str ,
304+ ttl : int = _DEFAULT_STS_TIMEOUT ,
305+ is_mandatory : bool = False ,
306+ resource_type : str = _DEFAULT_RESOURCE_TYPE ,
307+ ):
277308 if ttl < self ._advisory_refresh_timeout * 2 :
278- raise ArkAPIError ("ttl should not be under {} seconds." .format (self ._advisory_refresh_timeout * 2 ))
309+ raise ArkAPIError (
310+ "ttl should not be under {} seconds." .format (
311+ self ._advisory_refresh_timeout * 2
312+ )
313+ )
279314
280315 try :
281316 api_key , expired_time = self ._load_api_key (
@@ -301,7 +336,9 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
301336 ep , self ._mandatory_refresh_timeout
302337 )
303338
304- self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type )
339+ self ._protected_refresh (
340+ ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type
341+ )
305342 return
306343 finally :
307344 self ._refresh_lock .release ()
@@ -310,14 +347,20 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
310347 if not self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
311348 return
312349
313- self ._protected_refresh (ep , is_mandatory = True , resource_type = resource_type )
350+ self ._protected_refresh (
351+ ep , is_mandatory = True , resource_type = resource_type
352+ )
314353
315354 def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> str :
316355 self ._refresh (ep , resource_type = resource_type )
317356 return self ._endpoint_sts_tokens [ep ][0 ]
318357
319- def _load_api_key (self , ep : str , duration_seconds : int ,
320- resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> Tuple [str , int ]:
358+ def _load_api_key (
359+ self ,
360+ ep : str ,
361+ duration_seconds : int ,
362+ resource_type : str = _DEFAULT_RESOURCE_TYPE ,
363+ ) -> Tuple [str , int ]:
321364 get_api_key_request = volcenginesdkark .GetApiKeyRequest (
322365 duration_seconds = duration_seconds ,
323366 resource_type = resource_type ,
@@ -331,19 +374,26 @@ def _load_api_key(self, ep: str, duration_seconds: int,
331374
332375
333376class E2ECertificateManager (object ):
334-
335- class CertificateResponse ():
377+ class CertificateResponse :
336378 Certificate : str
337379 """The certificate content."""
338380
339- def __init__ (self , ak : str , sk : str , region : str , base_url : str | URL = BASE_URL , api_key : str | None = None ):
381+ def __init__ (
382+ self ,
383+ ak : str ,
384+ sk : str ,
385+ region : str ,
386+ base_url : str | URL = BASE_URL ,
387+ api_key : str | None = None ,
388+ ):
340389 self ._certificate_manager : Dict [str , key_agreement_client ] = {}
341390
342391 # local cache prepare
343392 self ._init_local_cert_cache ()
344393
345394 # api instance prepare
346395 import volcenginesdkcore
396+
347397 configuration = volcenginesdkcore .Configuration ()
348398 configuration .ak = ak
349399 configuration .sk = sk
@@ -365,38 +415,47 @@ def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL
365415 api_key = api_key ,
366416 )
367417 self ._e2e_uri = "/e2e/get/certificate"
368- self ._x_session_token = {' X-Session-Token' : self ._e2e_uri }
418+ self ._x_session_token = {" X-Session-Token" : self ._e2e_uri }
369419
370420 def _load_cert_by_cert_path (self ) -> str :
371- with open (self .cert_path , 'r' ) as f :
421+ with open (self .cert_path , "r" ) as f :
372422 cert_pem = f .read ()
373423 return cert_pem
374424
375425 def _load_cert_by_ak_sk (self , ep : str ) -> str :
376- get_endpoint_certificate_request = volcenginesdkark . GetEndpointCertificateRequest (
377- id = ep
426+ get_endpoint_certificate_request = (
427+ volcenginesdkark . GetEndpointCertificateRequest ( id = ep )
378428 )
379429 try :
380- resp : volcenginesdkark .GetEndpointCertificateResponse = self .api_instance .get_endpoint_certificate (
381- get_endpoint_certificate_request )
430+ resp : volcenginesdkark .GetEndpointCertificateResponse = (
431+ self .api_instance .get_endpoint_certificate (
432+ get_endpoint_certificate_request
433+ )
434+ )
382435 except ApiException as e :
383- raise ArkAPIError ("Getting model vendor encryption certificate failed: %s\n " % e )
436+ raise ArkAPIError (
437+ "Getting model vendor encryption certificate failed: %s\n " % e
438+ )
384439
385440 return resp .pca_instance_certificate
386441
387442 def _sync_load_cert_by_auth (self , ep : str ) -> str :
388443 try : # try to make request with session header (used for header statistic)
389- resp = self .client .post (self ._e2e_uri , options = {"headers" : self ._x_session_token },
390- body = {"model" : ep }, cast_to = self .CertificateResponse )
444+ resp = self .client .post (
445+ self ._e2e_uri ,
446+ options = {"headers" : self ._x_session_token },
447+ body = {"model" : ep },
448+ cast_to = self .CertificateResponse ,
449+ )
391450 except Exception as e :
392451 raise ArkAPIError ("Getting Certificate failed: %s\n " % e )
393- if ' error' in resp :
394- raise ArkAPIError ("Getting Certificate failed: %s\n " % resp [' error' ])
395- return resp [' Certificate' ]
452+ if " error" in resp :
453+ raise ArkAPIError ("Getting Certificate failed: %s\n " % resp [" error" ])
454+ return resp [" Certificate" ]
396455
397456 def _save_cert_to_file (self , ep : str , cert_pem : str ):
398457 cert_file_path = os .path .join (self ._cert_storage_path , f"{ ep } .pem" )
399- with open (cert_file_path , 'w' ) as f :
458+ with open (cert_file_path , "w" ) as f :
400459 f .write (cert_pem )
401460
402461 def _load_cert_locally (self , ep : str ) -> str | None :
@@ -406,7 +465,7 @@ def _load_cert_locally(self, ep: str) -> str | None:
406465 current_time = time .time ()
407466 time_difference = current_time - last_modified_time
408467 if time_difference <= self ._cert_expiration_seconds :
409- with open (cert_file_path , 'r' ) as f :
468+ with open (cert_file_path , "r" ) as f :
410469 return f .read ()
411470 else :
412471 os .remove (cert_file_path )
0 commit comments