Skip to content

Commit f4cc4be

Browse files
author
liuhuiqi.7
committed
feat(e2e api key): support ark api key agent
Change-Id: I95a02f8043931d075a45993ee5cc1672c4b6ef48
1 parent cd17025 commit f4cc4be

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
sk = os.environ.get("VOLC_SECRETKEY")
7272
if api_key is None:
7373
api_key = os.environ.get("ARK_API_KEY")
74+
self._base_url = base_url
7475
self.ak = ak
7576
self.sk = sk
7677
self.api_key = api_key
@@ -107,10 +108,10 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
107108
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
108109
if self._certificate_manager is None:
109110
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
110-
if (self.ak is None or self.sk is None) and cert_path is None:
111+
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
111112
raise ArkAPIError("must set (ak and sk) or (E2E_CERTIFICATE_PATH) \
112-
before get endpoint token.")
113-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region)
113+
or (api_key) before get endpoint token.")
114+
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self, self._base_url, self.api_key)
114115
return self._certificate_manager.get(endpoint_id)
115116

116117
def _get_bot_sts_token(self, bot_id: str):
@@ -164,6 +165,7 @@ def __init__(
164165
sk = os.environ.get("VOLC_SECRETKEY")
165166
if api_key is None:
166167
api_key = os.environ.get("ARK_API_KEY")
168+
self._base_url = base_url
167169
self.ak = ak
168170
self.sk = sk
169171
self.api_key = api_key
@@ -200,10 +202,10 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
200202
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
201203
if self._certificate_manager is None:
202204
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
203-
if (self.ak is None or self.sk is None) and cert_path is None:
205+
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
204206
raise ArkAPIError("must set (ak and sk) or (E2E_CERTIFICATE_PATH) \
205-
before get endpoint token.")
206-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region)
207+
or (api_key) before get endpoint token.")
208+
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self, self._base_url, self.api_key)
207209
return self._certificate_manager.get(endpoint_id)
208210

209211
@property
@@ -301,11 +303,18 @@ def _load_api_key(self, ep: str, duration_seconds: int,
301303

302304
class E2ECertificateManager(object):
303305

304-
def __init__(self, ak: str, sk: str, region: str):
306+
class CertificateResponse():
307+
Certificate: str
308+
"""The certificate content."""
309+
310+
def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL, api_key: str | None = None):
305311
self._certificate_manager: Dict[str, key_agreement_client] = {}
306312

307313
import volcenginesdkcore
308314

315+
self._api_instance_enabled = True
316+
if ak is None or sk is None:
317+
self._api_instance_enabled = False
309318
configuration = volcenginesdkcore.Configuration()
310319
configuration.ak = ak
311320
configuration.sk = sk
@@ -317,6 +326,17 @@ def __init__(self, ak: str, sk: str, region: str):
317326

318327
self.cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
319328

329+
self.client = Ark(
330+
base_url=base_url,
331+
api_key=api_key
332+
)
333+
self._e2e_uri = "/e2e/get/certificate"
334+
335+
def _load_cert_by_cert_path(self) -> str:
336+
with open(self.cert_path, 'r') as f:
337+
cert_pem = f.read()
338+
return cert_pem
339+
320340
def _load_cert_by_ak_sk(self, ep: str) -> str:
321341
get_endpoint_certificate_request = volcenginesdkark.GetEndpointCertificateRequest(
322342
id=ep
@@ -329,13 +349,21 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
329349

330350
return resp.pca_instance_certificate
331351

352+
def _sync_load_cert_by_auth(self, ep: str) -> str:
353+
try:
354+
resp = self.client.post(self._e2e_uri, body={"model": ep}, cast_to=self.CertificateResponse)
355+
except Exception as e:
356+
raise ArkAPIError("Getting Certificate failed: %s\n" % e)
357+
return resp['Certificate']
358+
332359
def get(self, ep: str) -> key_agreement_client:
333360
if ep not in self._certificate_manager:
334361
if self.cert_path is not None:
335-
with open(self.cert_path, 'r') as f:
336-
cert_pem = f.read()
337-
else:
362+
cert_pem = self._load_cert_by_cert_path()
363+
elif self._api_instance_enabled:
338364
cert_pem = self._load_cert_by_ak_sk(ep)
365+
else:
366+
cert_pem = self._sync_load_cert_by_auth(ep)
339367
self._certificate_manager[ep] = key_agreement_client(
340368
certificate_pem_string=cert_pem
341369
)

volcenginesdkarkruntime/_utils/_key_agreement.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import base64
4-
4+
from typing import Tuple
55

66
def aes_gcm_encrypt_bytes(key: bytes, iv: bytes, plain_bytes: bytes, associated_data: bytes = b"") -> bytes:
77
# aes_gcm_encrypt_bytes encrypt message using AES-GCM
@@ -85,7 +85,7 @@ def __init__(self, certificate_pem_string: str) -> None:
8585
self._public_key = ec.EllipticCurvePublicNumbers(
8686
cert_pub.x, cert_pub.y, self._curve).public_key()
8787

88-
def encrypt_string(self, plaintext: str) -> tuple[bytes, bytes, str, str]:
88+
def encrypt_string(self, plaintext: str) -> Tuple[bytes, bytes, str, str]:
8989
"""encrypt_string encrypt plaintext with ECIES DH protocol
9090
"""
9191
key, nonce, token = self.generate_ecies_key_pair()
@@ -106,7 +106,7 @@ def decrypt_string_with_key(self, key: bytes, nonce: bytes, ciphertext: str) ->
106106
# Decrypt message using AES-GCM
107107
return aes_gcm_decrypt_base64_string(key, nonce, ciphertext)
108108

109-
def generate_ecies_key_pair(self) -> tuple[bytes, bytes, str]:
109+
def generate_ecies_key_pair(self) -> Tuple[bytes, bytes, str]:
110110
"""generate_ecies_key_pair generate ECIES key pair
111111
"""
112112
from cryptography.hazmat.primitives import hashes

0 commit comments

Comments
 (0)