@@ -39,6 +39,7 @@ def __init__(
3939 ak : str | None = None ,
4040 sk : str | None = None ,
4141 api_key : str | None = None ,
42+ region : str = "cn-beijing" ,
4243 timeout : float | Timeout | None = DEFAULT_TIMEOUT ,
4344 max_retries : int = DEFAULT_MAX_RETRIES ,
4445 http_client : Client | None = None ,
@@ -66,6 +67,7 @@ def __init__(
6667 self .ak = ak
6768 self .sk = sk
6869 self .api_key = api_key
70+ self .region = region
6971
7072 assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
7173
@@ -81,12 +83,15 @@ def __init__(
8183 self ._sts_token_manager : StsTokenManager | None = None
8284
8385 self .chat = resources .Chat (self )
86+ self .embeddings = resources .Embeddings (self )
87+ # self.tokenization = resources.Tokenization(self)
88+ # self.classification = resources.Classification(self)
8489
8590 def _get_endpoint_sts_token (self , endpoint_id : str ):
8691 if self ._sts_token_manager is None :
8792 if self .ak is None or self .sk is None :
8893 raise ArkAPIError ("must set ak and sk before get endpoint token." )
89- self ._sts_token_manager = StsTokenManager (self .ak , self .sk )
94+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self . region )
9095 return self ._sts_token_manager .get (endpoint_id )
9196
9297 @property
@@ -105,6 +110,7 @@ def __init__(
105110 sk : str | None = None ,
106111 api_key : str | None = None ,
107112 base_url : str | URL = BASE_URL ,
113+ region : str = "cn-beijing" ,
108114 timeout : float | Timeout | None = DEFAULT_TIMEOUT ,
109115 max_retries : int = DEFAULT_MAX_RETRIES ,
110116 http_client : AsyncClient | None = None ,
@@ -131,6 +137,7 @@ def __init__(
131137 self .ak = ak
132138 self .sk = sk
133139 self .api_key = api_key
140+ self .region = region
134141
135142 assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
136143
@@ -146,12 +153,15 @@ def __init__(
146153 self ._sts_token_manager : StsTokenManager | None = None
147154
148155 self .chat = resources .AsyncChat (self )
156+ self .embeddings = resources .AsyncEmbeddings (self )
157+ # self.tokenization = resources.AsyncTokenization(self)
158+ # self.classification = resources.AsyncClassification(self)
149159
150160 def _get_endpoint_sts_token (self , endpoint_id : str ):
151161 if self ._sts_token_manager is None :
152162 if self .ak is None or self .sk is None :
153163 raise ArkAPIError ("must set ak and sk before get endpoint token." )
154- self ._sts_token_manager = StsTokenManager (self .ak , self .sk )
164+ self ._sts_token_manager = StsTokenManager (self .ak , self .sk , self . region )
155165 return self ._sts_token_manager .get (endpoint_id )
156166
157167 @property
@@ -169,7 +179,7 @@ class StsTokenManager(object):
169179 # refreshed credentials.
170180 _mandatory_refresh_timeout : int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
171181
172- def __init__ (self , ak : str , sk : str ):
182+ def __init__ (self , ak : str , sk : str , region : str ):
173183 self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (lambda : ("" , 0 ))
174184 self ._refresh_lock = threading .Lock ()
175185
@@ -178,7 +188,8 @@ def __init__(self, ak: str, sk: str):
178188 configuration = volcenginesdkcore .Configuration ()
179189 configuration .ak = ak
180190 configuration .sk = sk
181- configuration .region = "cn-beijing"
191+ configuration .region = region
192+ configuration .schema = "https"
182193
183194 volcenginesdkcore .Configuration .set_default (configuration )
184195 self .api_instance = volcenginesdkark .ARKApi ()
@@ -190,8 +201,8 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
190201 return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
191202
192203 def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ):
193- if ttl < _DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2 :
194- raise ArkAPIError ("ttl should not be under {} seconds." .format (_DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2 ))
204+ if ttl < self . _advisory_refresh_timeout * 2 :
205+ raise ArkAPIError ("ttl should not be under {} seconds." .format (self . _advisory_refresh_timeout * 2 ))
195206
196207 try :
197208 api_key , expired_time = self ._load_api_key (
0 commit comments