Skip to content

Commit ee41d32

Browse files
author
liuhuiqi.7
committed
feat(e2ee): add aicc option
feat(e2ee): add aicc option in requests feat(tks): get cname from cert feat(tks): x-Encrypt-Info feat(e2ee): fix lint feat(cert): info from dns feat(cert): info from dns feat(tks client): add dns to cert feat(cert): check local cert feat(aicc): exception feat(cert): check with re feat(aicc): skip moderation feat(skip): rm header
1 parent 43ed87d commit ee41d32

File tree

4 files changed

+111
-26
lines changed

4 files changed

+111
-26
lines changed

volcenginesdkark/models/get_endpoint_certificate_request.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,28 @@ class GetEndpointCertificateRequest(object):
3333
and the value is json key in definition.
3434
"""
3535
swagger_types = {
36-
'id': 'str'
36+
'id': 'str',
37+
'type': 'str'
3738
}
3839

3940
attribute_map = {
40-
'id': 'Id'
41+
'id': 'Id',
42+
'type': 'Type'
4143
}
4244

43-
def __init__(self, id=None, _configuration=None): # noqa: E501
45+
def __init__(self, id=None, type=None, _configuration=None): # noqa: E501
4446
"""GetEndpointCertificateRequest - a model defined in Swagger""" # noqa: E501
4547
if _configuration is None:
4648
_configuration = Configuration()
4749
self._configuration = _configuration
4850

4951
self._id = None
52+
self._type = None
5053
self.discriminator = None
5154

5255
self.id = id
56+
if type is not None:
57+
self.type = type
5358

5459
@property
5560
def id(self):
@@ -74,6 +79,27 @@ def id(self, id):
7479

7580
self._id = id
7681

82+
@property
83+
def type(self):
84+
"""Gets the type of this GetEndpointCertificateRequest. # noqa: E501
85+
86+
87+
:return: The type of this GetEndpointCertificateRequest. # noqa: E501
88+
:rtype: str
89+
"""
90+
return self._type
91+
92+
@type.setter
93+
def type(self, type):
94+
"""Sets the type of this GetEndpointCertificateRequest.
95+
96+
97+
:param type: The type of this GetEndpointCertificateRequest. # noqa: E501
98+
:type: str
99+
"""
100+
101+
self._type = type
102+
77103
def to_dict(self):
78104
"""Returns the model properties as a dict"""
79105
result = {}
@@ -121,4 +147,4 @@ def __ne__(self, other):
121147
if not isinstance(other, GetEndpointCertificateRequest):
122148
return True
123149

124-
return self.to_dict() != other.to_dict()
150+
return self.to_dict() != other.to_dict()

volcenginesdkarkruntime/_client.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from ._streaming import Stream
4444

45-
from ._utils._key_agreement import key_agreement_client
45+
from ._utils._key_agreement import key_agreement_client, get_cert_info
4646
from ._utils._model_breaker import ModelBreaker
4747

4848
__all__ = ["Ark", "AsyncArk"]
@@ -143,7 +143,7 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
143143
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
144144
return self._sts_token_manager.get(endpoint_id)
145145

146-
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
146+
def _get_endpoint_certificate(self, endpoint_id: str) -> tuple[key_agreement_client, str, str]:
147147
if self._certificate_manager is None:
148148
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
149149
if (
@@ -279,7 +279,7 @@ def _get_bot_sts_token(self, bot_id: str):
279279
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
280280
return self._sts_token_manager.get(bot_id, resource_type="bot")
281281

282-
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
282+
def _get_endpoint_certificate(self, endpoint_id: str) -> tuple[key_agreement_client, str, str]:
283283
if self._certificate_manager is None:
284284
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
285285
if (
@@ -429,7 +429,7 @@ def __init__(
429429
base_url: str | URL = BASE_URL,
430430
api_key: str | None = None,
431431
):
432-
self._certificate_manager: Dict[str, key_agreement_client] = {}
432+
self._certificate_manager: Dict[str, Tuple[key_agreement_client, str, str]] = {}
433433

434434
# local cache prepare
435435
self._init_local_cert_cache()
@@ -460,6 +460,10 @@ def __init__(
460460
self._e2e_uri = "/e2e/get/certificate"
461461
self._x_session_token = {"X-Session-Token": self._e2e_uri}
462462

463+
self._aicc_enabled = False
464+
if os.environ.get("VOLC_ARK_ENCRYPTION") == "AICC":
465+
self._aicc_enabled = True
466+
463467
def _load_cert_by_cert_path(self) -> str:
464468
with open(self.cert_path, "r") as f:
465469
cert_pem = f.read()
@@ -469,6 +473,10 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
469473
get_endpoint_certificate_request = (
470474
volcenginesdkark.GetEndpointCertificateRequest(id=ep)
471475
)
476+
if self._aicc_enabled:
477+
get_endpoint_certificate_request = (
478+
volcenginesdkark.GetEndpointCertificateRequest(id=ep, type="AICCv0.1")
479+
)
472480
try:
473481
resp: volcenginesdkark.GetEndpointCertificateResponse = (
474482
self.api_instance.get_endpoint_certificate(
@@ -484,10 +492,13 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
484492

485493
def _sync_load_cert_by_auth(self, ep: str) -> str:
486494
try: # try to make request with session header (used for header statistic)
495+
req_body = {"model": ep}
496+
if self._aicc_enabled:
497+
req_body["type"] = "AICCv0.1"
487498
resp = self.client.post(
488499
self._e2e_uri,
489500
options={"headers": self._x_session_token},
490-
body={"model": ep},
501+
body=req_body,
491502
cast_to=CertificateResponse,
492503
)
493504
except Exception as e:
@@ -508,10 +519,16 @@ def _load_cert_locally(self, ep: str) -> str | None:
508519
current_time = time.time()
509520
time_difference = current_time - last_modified_time
510521
if time_difference <= self._cert_expiration_seconds:
522+
cert_pem = None
511523
with open(cert_file_path, "r") as f:
512-
return f.read()
513-
else:
514-
os.remove(cert_file_path)
524+
cert_pem = f.read()
525+
ring, key = get_cert_info(cert_pem)
526+
# check cert is complement with AICC/PCA
527+
if (ring == "" or key == "") and not self._aicc_enabled:
528+
return cert_pem
529+
if ring != "" and key != "" and self._aicc_enabled:
530+
return cert_pem
531+
os.remove(cert_file_path)
515532
return None
516533

517534
def _init_local_cert_cache(self):
@@ -528,7 +545,7 @@ def _init_local_cert_cache(self):
528545
"failed to create certificate directory %s: %s\n" % (self._cert_storage_path, e)
529546
)
530547

531-
def get(self, ep: str) -> key_agreement_client:
548+
def get(self, ep: str) -> tuple[key_agreement_client, str, str]:
532549
if ep not in self._certificate_manager:
533550
cert_pem = self._load_cert_locally(ep)
534551
if cert_pem is None:
@@ -539,7 +556,12 @@ def get(self, ep: str) -> key_agreement_client:
539556
else:
540557
cert_pem = self._load_cert_by_ak_sk(ep)
541558
self._save_cert_to_file(ep, cert_pem)
542-
self._certificate_manager[ep] = key_agreement_client(
543-
certificate_pem_string=cert_pem
559+
ring, key = get_cert_info(cert_pem)
560+
self._certificate_manager[ep] = (
561+
key_agreement_client(
562+
certificate_pem_string=cert_pem
563+
),
564+
ring,
565+
key,
544566
)
545567
return self._certificate_manager[ep]

volcenginesdkarkruntime/_utils/_key_agreement.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@
1717
from typing import Tuple
1818

1919

20+
def get_cert_info(cert_pem: str) -> Tuple[str, str]:
21+
import re
22+
from cryptography import x509
23+
from cryptography.hazmat.backends import default_backend
24+
25+
cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend())
26+
try:
27+
dns = cert.extensions.get_extension_for_class(
28+
x509.SubjectAlternativeName).value.get_values_for_type(x509.DNSName)
29+
if dns and len(dns) > 1 and re.match(r"^ring\..*$", dns[0]) and re.match(r"^key\..*$", dns[1]):
30+
return dns[0].strip("ring."), dns[1].strip("key.")
31+
except Exception:
32+
pass
33+
return "", ""
34+
35+
2036
def aes_gcm_encrypt_bytes(
2137
key: bytes, iv: bytes, plain_bytes: bytes, associated_data: bytes = b""
2238
) -> bytes:

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
AsyncIterator,
2424
)
2525

26+
import os
27+
import json
2628
import httpx
2729
import warnings
2830
from typing_extensions import Literal
@@ -69,7 +71,8 @@ def _process_messages(
6971
part["text"] = f(part["text"])
7072
elif part.get("type", None) == "image_url":
7173
if part["image_url"]["url"].startswith("data:"):
72-
part["image_url"]["url"] = f(part["image_url"]["url"])
74+
part["image_url"]["url"] = f(
75+
part["image_url"]["url"])
7376
else:
7477
warnings.warn(
7578
"encryption is not supported for image url, "
@@ -103,15 +106,16 @@ def _encrypt(
103106
model: str,
104107
messages: Iterable[ChatCompletionMessageParam],
105108
extra_headers: Headers,
106-
) -> tuple[bytes, bytes]:
107-
client = self._client._get_endpoint_certificate(model)
109+
) -> tuple[bytes, bytes, str, str]:
110+
client, ring_id, key_id = self._client._get_endpoint_certificate(model)
108111
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
109112
extra_headers["X-Session-Token"] = session_token
110113
_process_messages(
111114
messages,
112-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
115+
lambda x: client.encrypt_string_with_key(
116+
_crypto_key, _crypto_nonce, x),
113117
)
114-
return _crypto_key, _crypto_nonce
118+
return _crypto_key, _crypto_nonce, ring_id, key_id
115119

116120
def _decrypt_chunk(
117121
self, key: bytes, nonce: bytes, resp: Stream[ChatCompletionChunk]
@@ -197,7 +201,15 @@ def create(
197201
):
198202
is_encrypt = True
199203
messages = deepcopy_minimal(messages)
200-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
204+
e2e_key, e2e_nonce, ring_id, key_id = self._encrypt(
205+
model, messages, extra_headers)
206+
if os.environ.get("VOLC_ARK_ENCRYPTION") == "AICC":
207+
info = {
208+
'Version': 'AICCv0.1',
209+
'RingID': ring_id,
210+
'KeyID': key_id,
211+
}
212+
extra_headers["X-Encrypt-Info"] = json.dumps(info)
201213

202214
resp = self._post(
203215
"/chat/completions",
@@ -257,15 +269,16 @@ def _encrypt(
257269
model: str,
258270
messages: Iterable[ChatCompletionMessageParam],
259271
extra_headers: Headers,
260-
) -> tuple[bytes, bytes]:
261-
client = self._client._get_endpoint_certificate(model)
272+
) -> tuple[bytes, bytes, str, str]:
273+
client, ring_id, key_id = self._client._get_endpoint_certificate(model)
262274
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
263275
extra_headers["X-Session-Token"] = session_token
264276
_process_messages(
265277
messages,
266-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
278+
lambda x: client.encrypt_string_with_key(
279+
_crypto_key, _crypto_nonce, x),
267280
)
268-
return _crypto_key, _crypto_nonce
281+
return _crypto_key, _crypto_nonce, ring_id, key_id
269282

270283
async def _decrypt_chunk(
271284
self, key: bytes, nonce: bytes, resp: AsyncStream[ChatCompletionChunk]
@@ -351,7 +364,15 @@ async def create(
351364
):
352365
is_encrypt = True
353366
messages = deepcopy_minimal(messages)
354-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
367+
e2e_key, e2e_nonce, ring_id, key_id = self._encrypt(
368+
model, messages, extra_headers)
369+
if os.environ.get("VOLC_ARK_ENCRYPTION") == "AICC":
370+
info = {
371+
'Version': 'AICCv0.1',
372+
'RingID': ring_id,
373+
'KeyID': key_id,
374+
}
375+
extra_headers["X-Encrypt-Info"] = json.dumps(info)
355376

356377
resp = await self._post(
357378
"/chat/completions",

0 commit comments

Comments
 (0)