11from __future__ import annotations
22
3+ import logging
34import os
4- from typing import Dict
5+ import threading
6+ import time
7+ from collections import defaultdict
58
6- import httpx
7- from httpx import Timeout , URL
9+ from httpx import Timeout , URL , Client , AsyncClient
10+ from typing import Dict , Tuple
811
9- from . import resources , _exceptions
12+ from volcenginesdkcore .rest import ApiException
13+ from ._exceptions import ArkAPIError
14+
15+ import volcenginesdkark
16+
17+ from . import resources
1018from ._base_client import SyncAPIClient , AsyncAPIClient
11- from ._constants import DEFAULT_MAX_RETRIES , BASE_URL
12- from ._exceptions import ArkError , ArkAPIStatusError
19+ from ._constants import (
20+ DEFAULT_MAX_RETRIES ,
21+ BASE_URL ,
22+ _DEFAULT_ADVISORY_REFRESH_TIMEOUT ,
23+ _DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
24+ _DEFAULT_STS_TIMEOUT ,
25+ DEFAULT_TIMEOUT
26+ )
1327from ._streaming import Stream
1428
1529__all__ = ["Ark" , "AsyncArk" ]
@@ -22,32 +36,51 @@ def __init__(
2236 self ,
2337 * ,
2438 base_url : str | URL = BASE_URL ,
25- api_key : str | None = None ,
26- timeout : float | Timeout | None = None ,
39+ ak : str | None = None ,
40+ sk : str | None = None ,
41+ timeout : float | Timeout | None = DEFAULT_TIMEOUT ,
2742 max_retries : int = DEFAULT_MAX_RETRIES ,
28- default_query : Dict [str , object ] | None = None ,
29- http_client : httpx .Client | None = None ,
43+ http_client : Client | None = None ,
3044 ) -> None :
31- if api_key is None :
32- api_key = os .environ .get ("ARK_API_KEY" )
33- self .api_key = api_key
45+ """init ark client, this client is thread unsafe. If need to use in multi thread, init a new `Ark` client in
46+ each thread
47+
48+ Args:
49+ ak: access key id
50+ sk: secret access key
51+ timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
52+ max_retries: times of retry when request failed. default 1
53+ http_client: specify customized http_client
54+ Returns:
55+ ark client
56+ """
57+
58+ if ak is None :
59+ ak = os .environ .get ("VOLC_ACCESSKEY" )
60+ if sk is None :
61+ sk = os .environ .get ("VOLC_SECRETKEY" )
62+ self .ak = ak
63+ self .sk = sk
3464
3565 super ().__init__ (
3666 base_url = base_url ,
3767 max_retries = max_retries ,
3868 timeout = timeout ,
3969 http_client = http_client ,
40- custom_query = default_query ,
70+ custom_query = None ,
4171 )
4272
4373 self ._default_stream_cls = Stream
74+ self ._sts_token_manager : StsTokenManager | None = None
4475
4576 self .chat = resources .Chat (self )
4677
47- @property
48- def auth_headers (self ) -> dict [str , str ]:
49- api_key = self .api_key
50- return {"Authorization" : f"Bearer { api_key } " }
78+ def _get_endpoint_sts_token (self , endpoint_id : str ):
79+ if self ._sts_token_manager is None :
80+ if self .ak is None or self .sk is None :
81+ raise ArkAPIError ("must set ak and sk before get endpoint token." )
82+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk )
83+ return self ._sts_token_manager .get (endpoint_id )
5184
5285
5386class AsyncArk (AsyncAPIClient ):
@@ -56,31 +89,133 @@ class AsyncArk(AsyncAPIClient):
5689 def __init__ (
5790 self ,
5891 * ,
92+ ak : str | None = None ,
93+ sk : str | None = None ,
5994 base_url : str | URL = BASE_URL ,
60- api_key : str | None = None ,
61- timeout : float | Timeout | None = None ,
95+ timeout : float | Timeout | None = DEFAULT_TIMEOUT ,
6296 max_retries : int = DEFAULT_MAX_RETRIES ,
63- default_query : Dict [str , object ] | None = None ,
64- http_client : httpx .Client | None = None ,
97+ http_client : AsyncClient | None = None ,
6598 ) -> None :
66- if api_key is None :
67- api_key = os .environ .get ("ARK_API_KEY" )
68- self .api_key = api_key
99+ """init async ark client, this client is thread unsafe
100+
101+ Args:
102+ ak: access key id
103+ sk: secret access key
104+ timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
105+ max_retries: times of retry when request failed. default 1
106+ http_client: specify customized http_client
107+ Returns:
108+ async ark client
109+ """
110+
111+ if ak is None :
112+ ak = os .environ .get ("VOLC_ACCESSKEY" )
113+ if sk is None :
114+ sk = os .environ .get ("VOLC_SECRETKEY" )
115+ self .ak = ak
116+ self .sk = sk
69117
70118 super ().__init__ (
71119 base_url = base_url ,
72120 max_retries = max_retries ,
73121 timeout = timeout ,
74122 http_client = http_client ,
75- custom_query = default_query ,
123+ custom_query = None ,
76124 )
77125
78126 self ._default_stream_cls = Stream
127+ self ._sts_token_manager : StsTokenManager | None = None
79128
80129 self .chat = resources .AsyncChat (self )
81130
131+ def _get_endpoint_sts_token (self , endpoint_id : str ):
132+ if self ._sts_token_manager is None :
133+ if self .ak is None or self .sk is None :
134+ raise ArkAPIError ("must set ak and sk before get endpoint token." )
135+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk )
136+ return self ._sts_token_manager .get (endpoint_id )
137+
138+
139+ class StsTokenManager (object ):
140+
141+ # The time at which we'll attempt to refresh, but not
142+ # block if someone else is refreshing.
143+ _advisory_refresh_timeout : int = _DEFAULT_ADVISORY_REFRESH_TIMEOUT
144+ # The time at which all threads will block waiting for
145+ # refreshed credentials.
146+ _mandatory_refresh_timeout : int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
147+
148+ def __init__ (self , ak : str , sk : str ):
149+ self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (lambda : ("" , 0 ))
150+ self ._refresh_lock = threading .Lock ()
151+
152+ import volcenginesdkcore
153+
154+ configuration = volcenginesdkcore .Configuration ()
155+ configuration .ak = ak
156+ configuration .sk = sk
157+ configuration .region = "cn-beijing"
158+
159+ volcenginesdkcore .Configuration .set_default (configuration )
160+ self .api_instance = volcenginesdkark .ARKApi ()
161+
162+ def _need_refresh (self , ep : str , refresh_in : int | None = None ) -> bool :
163+ if refresh_in is None :
164+ refresh_in = self ._advisory_refresh_timeout
165+
166+ return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
167+
168+ def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ):
169+ if ttl < _DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2 :
170+ raise ArkAPIError ("ttl should not be under {} seconds." .format (_DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2 ))
171+
172+ try :
173+ api_key , expired_time = self ._load_api_key (
174+ ep , ttl
175+ )
176+ self ._endpoint_sts_tokens [ep ] = (api_key , expired_time )
177+ except ApiException as e :
178+ if is_mandatory :
179+ raise ArkAPIError ("load api key cause error: e={}" .format (e ))
180+ else :
181+ logging .error ("load api key cause error: e={}" .format (e ))
182+
183+ def _refresh (self , ep : str ):
184+ if not self ._need_refresh (ep , self ._advisory_refresh_timeout ):
185+ return
186+
187+ if self ._refresh_lock .acquire (False ):
188+ if not self ._need_refresh (ep , self ._advisory_refresh_timeout ):
189+ return
190+
191+ try :
192+ is_mandatory_refresh = self ._need_refresh (
193+ ep , self ._mandatory_refresh_timeout
194+ )
195+
196+ self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh )
197+ return
198+ finally :
199+ self ._refresh_lock .release ()
200+ elif self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
201+ with self ._refresh_lock :
202+ if not self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
203+ return
204+
205+ self ._protected_refresh (ep , is_mandatory = True )
206+
207+ def get (self , ep : str ) -> str :
208+ self ._refresh (ep )
209+ return self ._endpoint_sts_tokens [ep ][0 ]
210+
211+ def _load_api_key (self , ep : str , duration_seconds : int ) -> Tuple [str , int ]:
212+ get_api_key_request = volcenginesdkark .GetApiKeyRequest (
213+ duration_seconds = duration_seconds ,
214+ resource_type = "endpoint" ,
215+ resource_ids = [ep ],
216+ )
217+ resp : volcenginesdkark .GetApiKeyResponse = self .api_instance .get_api_key (
218+ get_api_key_request
219+ )
82220
83- @property
84- def auth_headers (self ) -> dict [str , str ]:
85- api_key = self .api_key
86- return {"Authorization" : f"Bearer { api_key } " }
221+ return resp .api_key , resp .expired_time
0 commit comments