Skip to content

Commit f910c47

Browse files
author
BitsAdmin
committed
Merge branch 'feat/ark_local_cert_cache' into 'integration_2025-01-02_663619381762'
feat: [development task] ark-runtime-manual-Python (951048) See merge request iaasng/volcengine-python-sdk!475
2 parents b08b8f9 + 7a4c95f commit f910c47

File tree

3 files changed

+111
-73
lines changed

3 files changed

+111
-73
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
sk = os.environ.get("VOLC_SECRETKEY")
7272
if api_key is None:
7373
api_key = os.environ.get("ARK_API_KEY")
74+
self._base_url = base_url
7475
self.ak = ak
7576
self.sk = sk
7677
self.api_key = api_key
@@ -107,10 +108,10 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
107108
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
108109
if self._certificate_manager is None:
109110
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
110-
if (self.ak is None or self.sk is None) and cert_path is None:
111-
raise ArkAPIError("must set (ak and sk) or (E2E_CERTIFICATE_PATH) \
112-
before get endpoint token.")
113-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region)
111+
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
112+
raise ArkAPIError("must set (api_key) or (ak and sk) \
113+
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
114+
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
114115
return self._certificate_manager.get(endpoint_id)
115116

116117
def _get_bot_sts_token(self, bot_id: str):
@@ -164,6 +165,7 @@ def __init__(
164165
sk = os.environ.get("VOLC_SECRETKEY")
165166
if api_key is None:
166167
api_key = os.environ.get("ARK_API_KEY")
168+
self._base_url = base_url
167169
self.ak = ak
168170
self.sk = sk
169171
self.api_key = api_key
@@ -200,10 +202,10 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
200202
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
201203
if self._certificate_manager is None:
202204
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
203-
if (self.ak is None or self.sk is None) and cert_path is None:
204-
raise ArkAPIError("must set (ak and sk) or (E2E_CERTIFICATE_PATH) \
205-
before get endpoint token.")
206-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region)
205+
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
206+
raise ArkAPIError("must set (api_key) or (ak and sk) \
207+
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
208+
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
207209
return self._certificate_manager.get(endpoint_id)
208210

209211
@property
@@ -301,22 +303,46 @@ def _load_api_key(self, ep: str, duration_seconds: int,
301303

302304
class E2ECertificateManager(object):
303305

304-
def __init__(self, ak: str, sk: str, region: str):
306+
class CertificateResponse():
307+
Certificate: str
308+
"""The certificate content."""
309+
310+
def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL, api_key: str | None = None):
305311
self._certificate_manager: Dict[str, key_agreement_client] = {}
306312

307-
import volcenginesdkcore
313+
# local cache prepare
314+
self._init_local_cert_cache()
308315

316+
# api instance prepare
317+
import volcenginesdkcore
318+
self._api_instance_enabled = True
319+
if ak is None or sk is None:
320+
self._api_instance_enabled = False
309321
configuration = volcenginesdkcore.Configuration()
310322
configuration.ak = ak
311323
configuration.sk = sk
312324
configuration.region = region
313325
configuration.schema = "https"
314-
315326
volcenginesdkcore.Configuration.set_default(configuration)
316327
self.api_instance = volcenginesdkark.ARKApi()
317328

329+
# global cert path prepare
318330
self.cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
319331

332+
# ark client prepare
333+
self.client = Ark(
334+
base_url=base_url,
335+
api_key=api_key,
336+
ak=ak, sk=sk,
337+
)
338+
self._e2e_uri = "/e2e/get/certificate"
339+
self._x_session_token = {'X-Session-Token': self._e2e_uri}
340+
341+
def _load_cert_by_cert_path(self) -> str:
342+
with open(self.cert_path, 'r') as f:
343+
cert_pem = f.read()
344+
return cert_pem
345+
320346
def _load_cert_by_ak_sk(self, ep: str) -> str:
321347
get_endpoint_certificate_request = volcenginesdkark.GetEndpointCertificateRequest(
322348
id=ep
@@ -329,13 +355,52 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
329355

330356
return resp.pca_instance_certificate
331357

358+
def _sync_load_cert_by_auth(self, ep: str) -> str:
359+
try: # try to make request with session header (used for header statistic)
360+
resp = self.client.post(self._e2e_uri, options={"headers": self._x_session_token},
361+
body={"model": ep}, cast_to=self.CertificateResponse)
362+
except Exception as e:
363+
raise ArkAPIError("Getting Certificate failed: %s\n" % e)
364+
if 'error' in resp:
365+
raise ArkAPIError("Getting Certificate failed: %s\n" % resp['error'])
366+
return resp['Certificate']
367+
368+
def _save_cert_to_file(self, ep: str, cert_pem: str):
369+
cert_file_path = os.path.join(self._cert_storage_path, f"{ep}.pem")
370+
with open(cert_file_path, 'w') as f:
371+
f.write(cert_pem)
372+
373+
def _load_cert_locally(self, ep: str) -> str | None:
374+
cert_file_path = os.path.join(self._cert_storage_path, f"{ep}.pem")
375+
if os.path.exists(cert_file_path):
376+
last_modified_time = os.path.getmtime(cert_file_path)
377+
current_time = time.time()
378+
time_difference = current_time - last_modified_time
379+
if time_difference <= self._cert_expiration_seconds:
380+
with open(cert_file_path, 'r') as f:
381+
return f.read()
382+
else:
383+
os.remove(cert_file_path)
384+
return None
385+
386+
def _init_local_cert_cache(self):
387+
self._cert_storage_path = "/tmp/ark/certificates"
388+
self._cert_expiration_seconds = 14 * 24 * 60 * 60 # 14 days
389+
390+
if not os.path.exists(self._cert_storage_path):
391+
os.makedirs(self._cert_storage_path)
392+
332393
def get(self, ep: str) -> key_agreement_client:
333394
if ep not in self._certificate_manager:
334-
if self.cert_path is not None:
335-
with open(self.cert_path, 'r') as f:
336-
cert_pem = f.read()
337-
else:
338-
cert_pem = self._load_cert_by_ak_sk(ep)
395+
cert_pem = self._load_cert_locally(ep)
396+
if cert_pem is None:
397+
if self.cert_path is not None:
398+
cert_pem = self._load_cert_by_cert_path()
399+
elif self._api_instance_enabled:
400+
cert_pem = self._load_cert_by_ak_sk(ep)
401+
else:
402+
cert_pem = self._sync_load_cert_by_auth(ep)
403+
self._save_cert_to_file(ep, cert_pem)
339404
self._certificate_manager[ep] = key_agreement_client(
340405
certificate_pem_string=cert_pem
341406
)

volcenginesdkarkruntime/_utils/_key_agreement.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import base64
4-
4+
from typing import Tuple
55

66
def aes_gcm_encrypt_bytes(key: bytes, iv: bytes, plain_bytes: bytes, associated_data: bytes = b"") -> bytes:
77
# aes_gcm_encrypt_bytes encrypt message using AES-GCM
@@ -85,7 +85,7 @@ def __init__(self, certificate_pem_string: str) -> None:
8585
self._public_key = ec.EllipticCurvePublicNumbers(
8686
cert_pub.x, cert_pub.y, self._curve).public_key()
8787

88-
def encrypt_string(self, plaintext: str) -> tuple[bytes, bytes, str, str]:
88+
def encrypt_string(self, plaintext: str) -> Tuple[bytes, bytes, str, str]:
8989
"""encrypt_string encrypt plaintext with ECIES DH protocol
9090
"""
9191
key, nonce, token = self.generate_ecies_key_pair()
@@ -106,7 +106,7 @@ def decrypt_string_with_key(self, key: bytes, nonce: bytes, ciphertext: str) ->
106106
# Decrypt message using AES-GCM
107107
return aes_gcm_decrypt_base64_string(key, nonce, ciphertext)
108108

109-
def generate_ecies_key_pair(self) -> tuple[bytes, bytes, str]:
109+
def generate_ecies_key_pair(self) -> Tuple[bytes, bytes, str]:
110110
"""generate_ecies_key_pair generate ECIES key pair
111111
"""
112112
from cryptography.hazmat.primitives import hashes

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 27 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,31 @@
3333

3434
__all__ = ["Completions", "AsyncCompletions"]
3535

36+
37+
def _process_messages(messages: Iterable[ChatCompletionMessageParam],
38+
f: Callable[[str], str]):
39+
for message in messages:
40+
if message.get("content", None) is not None:
41+
current_content = message.get("content")
42+
if isinstance(current_content, str):
43+
message["content"] = f(current_content)
44+
elif isinstance(current_content, Iterable):
45+
for part in current_content:
46+
if part.get("type", None) == "text":
47+
part["text"] = f(part["text"])
48+
elif part.get("type", None) == "image_url":
49+
if part["image_url"]["url"].startswith('data:'):
50+
part["image_url"]["url"] = f(part["image_url"]["url"])
51+
else:
52+
warnings.warn("encryption is not supported for image url, "
53+
"please use base64 image if you want encryption")
54+
else:
55+
raise TypeError("encryption is not supported for content type {}".
56+
format(type(part)))
57+
else:
58+
raise TypeError("encryption is not supported for content type {}".
59+
format(type(message.get('content'))))
60+
3661
class Completions(SyncAPIResource):
3762
@cached_property
3863
def with_raw_response(self) -> CompletionsWithRawResponse:
@@ -42,38 +67,12 @@ def with_raw_response(self) -> CompletionsWithRawResponse:
4267
def with_streaming_response(self) -> CompletionsWithStreamingResponse:
4368
return CompletionsWithStreamingResponse(self)
4469

45-
def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
46-
f: Callable[[str], str]):
47-
for message in messages:
48-
if message.get("content", None) is not None:
49-
current_content = message.get("content")
50-
if isinstance(current_content, str):
51-
message["content"] = f(current_content)
52-
elif isinstance(current_content, Iterable):
53-
for part in current_content:
54-
if part.get("type", None) == "text":
55-
part["text"] = f(part["text"])
56-
elif part.get("type", None) == "image_url":
57-
if part["image_url"]["url"].startswith('data:'):
58-
part["image_url"]["url"] = f(part["image_url"]["url"])
59-
else:
60-
warnings.warn("encryption is not supported for image url, "
61-
"please use base64 image if you want encryption")
62-
else:
63-
raise TypeError("encryption is not supported for content type {}".
64-
format(type(part)))
65-
else:
66-
raise TypeError("encryption is not supported for content type {}".
67-
format(type(message.get('content'))))
68-
6970
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers
7071
) -> tuple[bytes, bytes]:
7172
client = self._client._get_endpoint_certificate(model)
7273
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
7374
extra_headers['X-Session-Token'] = session_token
74-
self._process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key,
75-
_crypto_nonce,
76-
x))
75+
_process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x))
7776
return _crypto_key, _crypto_nonce
7877

7978
def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: Stream[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
@@ -180,38 +179,12 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
180179
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
181180
return AsyncCompletionsWithStreamingResponse(self)
182181

183-
def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
184-
f: Callable[[str], str]):
185-
for message in messages:
186-
if message.get("content", None) is not None:
187-
current_content = message.get("content")
188-
if isinstance(current_content, str):
189-
message["content"] = f(current_content)
190-
elif isinstance(current_content, Iterable):
191-
for part in current_content:
192-
if part.get("type", None) == "text":
193-
part["text"] = f(part["text"])
194-
elif part.get("type", None) == "image_url":
195-
if part["image_url"]["url"].startswith('data:'):
196-
part["image_url"]["url"] = f(part["image_url"]["url"])
197-
else:
198-
warnings.warn("encryption is not supported for image url, "
199-
"please use base64 image if you want encryption")
200-
else:
201-
raise TypeError("encryption is not supported for content type {}".
202-
format(type(part)))
203-
else:
204-
raise TypeError("encryption is not supported for content type {}".
205-
format(type(message.get('content'))))
206-
207182
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers
208183
) -> tuple[bytes, bytes]:
209184
client = self._client._get_endpoint_certificate(model)
210185
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
211186
extra_headers['X-Session-Token'] = session_token
212-
self._process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key,
213-
_crypto_nonce,
214-
x))
187+
_process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x))
215188
return _crypto_key, _crypto_nonce
216189

217190
async def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: AsyncStream[ChatCompletionChunk]

0 commit comments

Comments
 (0)