Skip to content

Commit 59734d2

Browse files
author
BitsAdmin
committed
Merge branch 'feat/add_thinking' into 'integration_2025-05-15_905360636418'
feat: [development task] ark runtime (1224979) See merge request iaasng/volcengine-python-sdk!609
2 parents 97ec46f + 7a6e101 commit 59734d2

32 files changed

+691
-386
lines changed

volcenginesdkarkruntime/_base_client.py

Lines changed: 195 additions & 187 deletions
Large diffs are not rendered by default.

volcenginesdkarkruntime/_client.py

Lines changed: 105 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_DEFAULT_MANDATORY_REFRESH_TIMEOUT,
2525
_DEFAULT_STS_TIMEOUT,
2626
_DEFAULT_RESOURCE_TYPE,
27-
DEFAULT_TIMEOUT
27+
DEFAULT_TIMEOUT,
2828
)
2929
from ._streaming import Stream
3030

@@ -42,6 +42,7 @@ class Ark(SyncAPIClient):
4242
context: resources.Context
4343
multimodal_embeddings: resources.MultimodalEmbeddings
4444
content_generation: resources.ContentGeneration
45+
images: resources.Images
4546
batch_chat: resources.BatchChat
4647
model_breaker_map: dict[str, ModelBreaker]
4748
model_breaker_lock: threading.Lock
@@ -71,7 +72,6 @@ def __init__(
7172
Returns:
7273
ark client
7374
"""
74-
7575
if ak is None:
7676
ak = os.environ.get("VOLC_ACCESSKEY")
7777
if sk is None:
@@ -84,7 +84,9 @@ def __init__(
8484
self.api_key = api_key
8585
self.region = region
8686

87-
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
87+
assert (api_key is not None) or (ak is not None and sk is not None), (
88+
"you need to support api_key or ak&sk"
89+
)
8890

8991
super().__init__(
9092
base_url=base_url,
@@ -105,6 +107,7 @@ def __init__(
105107
self.context = resources.Context(self)
106108
self.multimodal_embeddings = resources.MultimodalEmbeddings(self)
107109
self.content_generation = resources.ContentGeneration(self)
110+
self.images = resources.Images(self)
108111
self.batch_chat = resources.BatchChat(self)
109112
self.model_breaker_map = defaultdict(ModelBreaker)
110113
self.model_breaker_lock = threading.Lock()
@@ -120,10 +123,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
120123
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
121124
if self._certificate_manager is None:
122125
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
123-
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
124-
raise ArkAPIError("must set (api_key) or (ak and sk) \
125-
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
126-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
126+
if (
127+
(self.ak is None or self.sk is None)
128+
and cert_path is None
129+
and self.api_key is None
130+
):
131+
raise ArkAPIError(
132+
"must set (api_key) or (ak and sk) \
133+
or (E2E_CERTIFICATE_PATH) before get endpoint token."
134+
)
135+
self._certificate_manager = E2ECertificateManager(
136+
self.ak, self.sk, self.region, self._base_url, self.api_key
137+
)
127138
return self._certificate_manager.get(endpoint_id)
128139

129140
def _get_bot_sts_token(self, bot_id: str):
@@ -142,6 +153,7 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
142153
with self.model_breaker_lock:
143154
return self.model_breaker_map[model_name]
144155

156+
145157
class AsyncArk(AsyncAPIClient):
146158
chat: resources.AsyncChat
147159
bot_chat: resources.AsyncBotChat
@@ -150,6 +162,7 @@ class AsyncArk(AsyncAPIClient):
150162
context: resources.AsyncContext
151163
multimodal_embeddings: resources.AsyncMultimodalEmbeddings
152164
content_generation: resources.AsyncContentGeneration
165+
images: resources.AsyncImages
153166
batch_chat: resources.AsyncBatchChat
154167
model_breaker_map: dict[str, ModelBreaker]
155168
model_breaker_lock: asyncio.Lock
@@ -168,15 +181,15 @@ def __init__(
168181
) -> None:
169182
"""init async ark client, this client is thread unsafe
170183
171-
Args:
172-
ak: access key id
173-
sk: secret access key
174-
api_key: api key,this api key will not be refreshed
175-
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
176-
max_retries: times of retry when request failed. default 1
177-
http_client: specify customized http_client
178-
Returns:
179-
async ark client
184+
Args:
185+
ak: access key id
186+
sk: secret access key
187+
api_key: api key,this api key will not be refreshed
188+
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
189+
max_retries: times of retry when request failed. default 1
190+
http_client: specify customized http_client
191+
Returns:
192+
async ark client
180193
"""
181194

182195
if ak is None:
@@ -191,7 +204,9 @@ def __init__(
191204
self.api_key = api_key
192205
self.region = region
193206

194-
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
207+
assert (api_key is not None) or (ak is not None and sk is not None), (
208+
"you need to support api_key or ak&sk"
209+
)
195210

196211
super().__init__(
197212
base_url=base_url,
@@ -212,6 +227,7 @@ def __init__(
212227
self.context = resources.AsyncContext(self)
213228
self.multimodal_embeddings = resources.AsyncMultimodalEmbeddings(self)
214229
self.content_generation = resources.AsyncContentGeneration(self)
230+
self.images = resources.AsyncImages(self)
215231
self.batch_chat = resources.AsyncBatchChat(self)
216232
self.model_breaker_map = defaultdict(ModelBreaker)
217233
self.model_breaker_lock = asyncio.Lock()
@@ -227,10 +243,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
227243
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
228244
if self._certificate_manager is None:
229245
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
230-
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
231-
raise ArkAPIError("must set (api_key) or (ak and sk) \
232-
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
233-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
246+
if (
247+
(self.ak is None or self.sk is None)
248+
and cert_path is None
249+
and self.api_key is None
250+
):
251+
raise ArkAPIError(
252+
"must set (api_key) or (ak and sk) \
253+
or (E2E_CERTIFICATE_PATH) before get endpoint token."
254+
)
255+
self._certificate_manager = E2ECertificateManager(
256+
self.ak, self.sk, self.region, self._base_url, self.api_key
257+
)
234258
return self._certificate_manager.get(endpoint_id)
235259

236260
@property
@@ -252,7 +276,9 @@ class StsTokenManager(object):
252276
_mandatory_refresh_timeout: int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
253277

254278
def __init__(self, ak: str, sk: str, region: str):
255-
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(lambda: ("", 0))
279+
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(
280+
lambda: ("", 0)
281+
)
256282
self._refresh_lock = threading.Lock()
257283

258284
import volcenginesdkcore
@@ -272,10 +298,19 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
272298

273299
return self._endpoint_sts_tokens[ep][1] - time.time() < refresh_in
274300

275-
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False,
276-
resource_type: str = _DEFAULT_RESOURCE_TYPE):
301+
def _protected_refresh(
302+
self,
303+
ep: str,
304+
ttl: int = _DEFAULT_STS_TIMEOUT,
305+
is_mandatory: bool = False,
306+
resource_type: str = _DEFAULT_RESOURCE_TYPE,
307+
):
277308
if ttl < self._advisory_refresh_timeout * 2:
278-
raise ArkAPIError("ttl should not be under {} seconds.".format(self._advisory_refresh_timeout * 2))
309+
raise ArkAPIError(
310+
"ttl should not be under {} seconds.".format(
311+
self._advisory_refresh_timeout * 2
312+
)
313+
)
279314

280315
try:
281316
api_key, expired_time = self._load_api_key(
@@ -301,7 +336,9 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
301336
ep, self._mandatory_refresh_timeout
302337
)
303338

304-
self._protected_refresh(ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type)
339+
self._protected_refresh(
340+
ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type
341+
)
305342
return
306343
finally:
307344
self._refresh_lock.release()
@@ -310,14 +347,20 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
310347
if not self._need_refresh(ep, self._mandatory_refresh_timeout):
311348
return
312349

313-
self._protected_refresh(ep, is_mandatory=True, resource_type=resource_type)
350+
self._protected_refresh(
351+
ep, is_mandatory=True, resource_type=resource_type
352+
)
314353

315354
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE) -> str:
316355
self._refresh(ep, resource_type=resource_type)
317356
return self._endpoint_sts_tokens[ep][0]
318357

319-
def _load_api_key(self, ep: str, duration_seconds: int,
320-
resource_type: str = _DEFAULT_RESOURCE_TYPE) -> Tuple[str, int]:
358+
def _load_api_key(
359+
self,
360+
ep: str,
361+
duration_seconds: int,
362+
resource_type: str = _DEFAULT_RESOURCE_TYPE,
363+
) -> Tuple[str, int]:
321364
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
322365
duration_seconds=duration_seconds,
323366
resource_type=resource_type,
@@ -331,19 +374,26 @@ def _load_api_key(self, ep: str, duration_seconds: int,
331374

332375

333376
class E2ECertificateManager(object):
334-
335-
class CertificateResponse():
377+
class CertificateResponse:
336378
Certificate: str
337379
"""The certificate content."""
338380

339-
def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL, api_key: str | None = None):
381+
def __init__(
382+
self,
383+
ak: str,
384+
sk: str,
385+
region: str,
386+
base_url: str | URL = BASE_URL,
387+
api_key: str | None = None,
388+
):
340389
self._certificate_manager: Dict[str, key_agreement_client] = {}
341390

342391
# local cache prepare
343392
self._init_local_cert_cache()
344393

345394
# api instance prepare
346395
import volcenginesdkcore
396+
347397
configuration = volcenginesdkcore.Configuration()
348398
configuration.ak = ak
349399
configuration.sk = sk
@@ -365,38 +415,47 @@ def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL
365415
api_key=api_key,
366416
)
367417
self._e2e_uri = "/e2e/get/certificate"
368-
self._x_session_token = {'X-Session-Token': self._e2e_uri}
418+
self._x_session_token = {"X-Session-Token": self._e2e_uri}
369419

370420
def _load_cert_by_cert_path(self) -> str:
371-
with open(self.cert_path, 'r') as f:
421+
with open(self.cert_path, "r") as f:
372422
cert_pem = f.read()
373423
return cert_pem
374424

375425
def _load_cert_by_ak_sk(self, ep: str) -> str:
376-
get_endpoint_certificate_request = volcenginesdkark.GetEndpointCertificateRequest(
377-
id=ep
426+
get_endpoint_certificate_request = (
427+
volcenginesdkark.GetEndpointCertificateRequest(id=ep)
378428
)
379429
try:
380-
resp: volcenginesdkark.GetEndpointCertificateResponse = self.api_instance.get_endpoint_certificate(
381-
get_endpoint_certificate_request)
430+
resp: volcenginesdkark.GetEndpointCertificateResponse = (
431+
self.api_instance.get_endpoint_certificate(
432+
get_endpoint_certificate_request
433+
)
434+
)
382435
except ApiException as e:
383-
raise ArkAPIError("Getting model vendor encryption certificate failed: %s\n" % e)
436+
raise ArkAPIError(
437+
"Getting model vendor encryption certificate failed: %s\n" % e
438+
)
384439

385440
return resp.pca_instance_certificate
386441

387442
def _sync_load_cert_by_auth(self, ep: str) -> str:
388443
try: # try to make request with session header (used for header statistic)
389-
resp = self.client.post(self._e2e_uri, options={"headers": self._x_session_token},
390-
body={"model": ep}, cast_to=self.CertificateResponse)
444+
resp = self.client.post(
445+
self._e2e_uri,
446+
options={"headers": self._x_session_token},
447+
body={"model": ep},
448+
cast_to=self.CertificateResponse,
449+
)
391450
except Exception as e:
392451
raise ArkAPIError("Getting Certificate failed: %s\n" % e)
393-
if 'error' in resp:
394-
raise ArkAPIError("Getting Certificate failed: %s\n" % resp['error'])
395-
return resp['Certificate']
452+
if "error" in resp:
453+
raise ArkAPIError("Getting Certificate failed: %s\n" % resp["error"])
454+
return resp["Certificate"]
396455

397456
def _save_cert_to_file(self, ep: str, cert_pem: str):
398457
cert_file_path = os.path.join(self._cert_storage_path, f"{ep}.pem")
399-
with open(cert_file_path, 'w') as f:
458+
with open(cert_file_path, "w") as f:
400459
f.write(cert_pem)
401460

402461
def _load_cert_locally(self, ep: str) -> str | None:
@@ -406,7 +465,7 @@ def _load_cert_locally(self, ep: str) -> str | None:
406465
current_time = time.time()
407466
time_difference = current_time - last_modified_time
408467
if time_difference <= self._cert_expiration_seconds:
409-
with open(cert_file_path, 'r') as f:
468+
with open(cert_file_path, "r") as f:
410469
return f.read()
411470
else:
412471
os.remove(cert_file_path)

volcenginesdkarkruntime/_compat.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
2525
...
2626

27-
def parse_datetime(
28-
value: Union[datetime, StrBytesIntFloat]
29-
) -> datetime: # noqa: ARG001
27+
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
3028
...
3129

3230
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
@@ -87,9 +85,7 @@ def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
8785
if PYDANTIC_V2:
8886
return model.model_validate(value)
8987
else:
90-
return cast(
91-
_ModelT, model.parse_obj(value)
92-
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
88+
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
9389

9490

9591
def field_is_required(field: FieldInfo) -> bool:

volcenginesdkarkruntime/_constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
DEFAULT_TIMEOUT_SECONDS = 600.0
1414
DEFAULT_CONNECT_TIMEOUT_SECONDS = 60.0
1515
# default timeout is 1 minutes
16-
DEFAULT_TIMEOUT = httpx.Timeout(timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS)
16+
DEFAULT_TIMEOUT = httpx.Timeout(
17+
timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS
18+
)
1719

1820
DEFAULT_MAX_RETRIES = 2
1921
DEFAULT_CONNECTION_LIMITS = httpx.Limits(

0 commit comments

Comments
 (0)