Skip to content

Commit 08c8149

Browse files
author
BitsAdmin
committed
Merge branch 'feat/arkruntime-seedance-sdk' into 'integration_2025-10-27_1073654796290'
feat: [development task] ark (1769068) See merge request iaasng/volcengine-python-sdk!879
2 parents f2ef679 + 2fcc0b8 commit 08c8149

File tree

11 files changed

+376
-425
lines changed

11 files changed

+376
-425
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
143143
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
144144
return self._sts_token_manager.get(endpoint_id)
145145

146-
def _get_endpoint_certificate(self, endpoint_id: str) -> tuple[key_agreement_client, str, str]:
146+
def _get_endpoint_certificate(self, endpoint_id: str) -> Tuple[key_agreement_client, str, str, float]:
147147
if self._certificate_manager is None:
148148
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
149149
if (
@@ -279,7 +279,7 @@ def _get_bot_sts_token(self, bot_id: str):
279279
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
280280
return self._sts_token_manager.get(bot_id, resource_type="bot")
281281

282-
def _get_endpoint_certificate(self, endpoint_id: str) -> tuple[key_agreement_client, str, str]:
282+
def _get_endpoint_certificate(self, endpoint_id: str) -> Tuple[key_agreement_client, str, str, float]:
283283
if self._certificate_manager is None:
284284
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
285285
if (
@@ -429,7 +429,7 @@ def __init__(
429429
base_url: str | URL = BASE_URL,
430430
api_key: str | None = None,
431431
):
432-
self._certificate_manager: Dict[str, Tuple[key_agreement_client, str, str]] = {}
432+
self._certificate_manager: Dict[str, Tuple[key_agreement_client, str, str, float]] = {}
433433

434434
# local cache prepare
435435
self._init_local_cert_cache()
@@ -522,7 +522,7 @@ def _load_cert_locally(self, ep: str) -> str | None:
522522
cert_pem = None
523523
with open(cert_file_path, "r") as f:
524524
cert_pem = f.read()
525-
ring, key = get_cert_info(cert_pem)
525+
ring, key, _ = get_cert_info(cert_pem)
526526
# check cert is complement with AICC/PCA
527527
if (ring == "" or key == "") and not self._aicc_enabled:
528528
return cert_pem
@@ -545,7 +545,7 @@ def _init_local_cert_cache(self):
545545
"failed to create certificate directory %s: %s\n" % (self._cert_storage_path, e)
546546
)
547547

548-
def get(self, ep: str) -> tuple[key_agreement_client, str, str]:
548+
def get(self, ep: str) -> Tuple[key_agreement_client, str, str, float]:
549549
if ep not in self._certificate_manager:
550550
cert_pem = self._load_cert_locally(ep)
551551
if cert_pem is None:
@@ -556,12 +556,13 @@ def get(self, ep: str) -> tuple[key_agreement_client, str, str]:
556556
else:
557557
cert_pem = self._load_cert_by_ak_sk(ep)
558558
self._save_cert_to_file(ep, cert_pem)
559-
ring, key = get_cert_info(cert_pem)
559+
ring, key, exp_time = get_cert_info(cert_pem)
560560
self._certificate_manager[ep] = (
561561
key_agreement_client(
562562
certificate_pem_string=cert_pem
563563
),
564564
ring,
565565
key,
566+
exp_time,
566567
)
567568
return self._certificate_manager[ep]

volcenginesdkarkruntime/_utils/_key_agreement.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Tuple
1818

1919

20-
def get_cert_info(cert_pem: str) -> Tuple[str, str]:
20+
def get_cert_info(cert_pem: str) -> Tuple[str, str, float]:
2121
import re
2222
from cryptography import x509
2323
from cryptography.hazmat.backends import default_backend
@@ -27,10 +27,10 @@ def get_cert_info(cert_pem: str) -> Tuple[str, str]:
2727
dns = cert.extensions.get_extension_for_class(
2828
x509.SubjectAlternativeName).value.get_values_for_type(x509.DNSName)
2929
if dns and len(dns) > 1 and re.match(r"^ring\..*$", dns[0]) and re.match(r"^key\..*$", dns[1]):
30-
return dns[0][5:], dns[1][4:]
30+
return dns[0][5:], dns[1][4:], cert.not_valid_after_utc.timestamp()
3131
except Exception:
3232
pass
33-
return "", ""
33+
return "", "", cert.not_valid_after_utc.timestamp()
3434

3535

3636
def aes_gcm_encrypt_bytes(
@@ -92,6 +92,20 @@ def aes_gcm_decrypt_base64_string(key: bytes, nonce: bytes, ciphertext: str) ->
9292
base64_pattern = r'(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{4})'
9393

9494

95+
def decrypt_corner_case(key: bytes, nonce: bytes, data: str) -> str:
96+
"""decrypt_corner_case Decrypt corner case data"""
97+
if len(data) < 24:
98+
return ''
99+
for i in range(20, len(data), 4):
100+
try:
101+
decrypted = aes_gcm_decrypt_base64_string(key, nonce, data[:i+4])
102+
if i+4 == len(data):
103+
return decrypted
104+
return decrypted + decrypt_corner_case(key, nonce, data[i+4:])
105+
except Exception:
106+
pass
107+
108+
95109
def aes_gcm_decrypt_base64_list(key: bytes, nonce: bytes, ciphertext: str) -> str:
96110
# Decrypt
97111
base64_array = re.findall(base64_pattern, ciphertext)
@@ -100,17 +114,7 @@ def aes_gcm_decrypt_base64_list(key: bytes, nonce: bytes, ciphertext: str) -> st
100114
try:
101115
result.append(aes_gcm_decrypt_base64_string(key, nonce, b64))
102116
except Exception:
103-
for i in range(20, len(b64), 4):
104-
try:
105-
decrypted = aes_gcm_decrypt_base64_string(
106-
key, nonce, b64[:i+4])
107-
result.append(decrypted)
108-
decrypted = aes_gcm_decrypt_base64_string(
109-
key, nonce, b64[i+4:])
110-
result.append(decrypted)
111-
break
112-
except Exception:
113-
pass
117+
result.append(decrypt_corner_case(key, nonce, b64))
114118
return ''.join(result)
115119

116120

volcenginesdkarkruntime/resources/batch_chat/completions.py

Lines changed: 10 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
from typing_extensions import Literal
2121

2222
import httpx
23-
import warnings
2423

2524
from ..._exceptions import ArkAPITimeoutError, ArkAPIConnectionError, ArkAPIStatusError
2625
from ..._types import Body, Query, Headers
27-
from ..._utils import with_sts_token, async_with_sts_token, deepcopy_minimal
28-
from ..._utils._key_agreement import aes_gcm_decrypt_base64_string
26+
from ..._utils import with_sts_token, async_with_sts_token
27+
from ..encryption import with_e2e_encryption, async_with_e2e_encryption
2928
from ..._base_client import make_request_options
3029
from ..._resource import SyncAPIResource, AsyncAPIResource
3130
from ..._compat import cached_property
@@ -50,43 +49,10 @@
5049
__all__ = ["Completions", "AsyncCompletions"]
5150

5251

53-
def _process_messages(
54-
messages: Iterable[ChatCompletionMessageParam], f: Callable[[str], str]
55-
):
56-
for message in messages:
57-
if message.get("content", None) is not None:
58-
current_content = message.get("content")
59-
if isinstance(current_content, str):
60-
message["content"] = f(current_content)
61-
elif isinstance(current_content, Iterable):
62-
for part in current_content:
63-
if part.get("type", None) == "text":
64-
part["text"] = f(part["text"])
65-
elif part.get("type", None) == "image_url":
66-
if part["image_url"]["url"].startswith("data:"):
67-
part["image_url"]["url"] = f(part["image_url"]["url"])
68-
else:
69-
warnings.warn(
70-
"encryption is not supported for image url, "
71-
"please use base64 image if you want encryption"
72-
)
73-
else:
74-
raise TypeError(
75-
"encryption is not supported for content type {}".format(
76-
type(part)
77-
)
78-
)
79-
else:
80-
raise TypeError(
81-
"encryption is not supported for content type {}".format(
82-
type(message.get("content"))
83-
)
84-
)
85-
86-
8752
def _calculate_retry_timeout(retry_times) -> float:
8853
nbRetries = min(retry_times, MAX_RETRY_DELAY / INITIAL_RETRY_DELAY)
89-
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2, nbRetries), MAX_RETRY_DELAY)
54+
sleep_seconds = min(INITIAL_RETRY_DELAY *
55+
pow(2, nbRetries), MAX_RETRY_DELAY)
9056
# Apply some jitter, plus-or-minus half a second.
9157
jitter = 1 - 0.25 * random()
9258
timeout = sleep_seconds * jitter
@@ -126,58 +92,8 @@ class Completions(SyncAPIResource):
12692
def with_raw_response(self) -> CompletionsWithRawResponse:
12793
return CompletionsWithRawResponse(self)
12894

129-
def _process_messages(
130-
self, messages: Iterable[ChatCompletionMessageParam], f: Callable[[str], str]
131-
):
132-
for message in messages:
133-
if message.get("content", None) is not None:
134-
current_content = message.get("content")
135-
if isinstance(current_content, str):
136-
message["content"] = f(current_content)
137-
elif isinstance(current_content, Iterable):
138-
raise TypeError(
139-
"content type {} is not supported end-to-end encryption".format(
140-
type(message.get("content"))
141-
)
142-
)
143-
else:
144-
raise TypeError(
145-
"content type {} is not supported end-to-end encryption".format(
146-
type(message.get("content"))
147-
)
148-
)
149-
150-
def _encrypt(
151-
self,
152-
model: str,
153-
messages: Iterable[ChatCompletionMessageParam],
154-
extra_headers: Headers,
155-
) -> tuple[bytes, bytes]:
156-
client = self._client._get_endpoint_certificate(model)
157-
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
158-
extra_headers["X-Session-Token"] = session_token
159-
_process_messages(
160-
messages,
161-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
162-
)
163-
return _crypto_key, _crypto_nonce
164-
165-
def _decrypt(
166-
self, key: bytes, nonce: bytes, resp: ChatCompletion
167-
) -> ChatCompletion:
168-
if resp.choices is not None:
169-
for index, choice in enumerate(resp.choices):
170-
if (
171-
choice.message is not None and choice.finish_reason != 'content_filter'
172-
and choice.message.content is not None
173-
):
174-
choice.message.content = aes_gcm_decrypt_base64_string(
175-
key, nonce, choice.message.content
176-
)
177-
resp.choices[index] = choice
178-
return resp
179-
18095
@with_sts_token
96+
@with_e2e_encryption
18197
def create(
18298
self,
18399
*,
@@ -208,14 +124,6 @@ def create(
208124
extra_body: Body | None = None,
209125
timeout: float | httpx.Timeout | None = None,
210126
) -> ChatCompletion:
211-
is_encrypt = False
212-
if (
213-
extra_headers is not None
214-
and extra_headers.get(ARK_E2E_ENCRYPTION_HEADER, None) == "true"
215-
):
216-
is_encrypt = True
217-
messages = deepcopy_minimal(messages)
218-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
219127
retryTimes = 0
220128
last_time = self._get_request_last_time(timeout)
221129
model_breaker = self._client.get_model_breaker(model)
@@ -273,8 +181,6 @@ def create(
273181
continue
274182
else:
275183
raise err
276-
if is_encrypt:
277-
resp = self._decrypt(e2e_key, e2e_nonce, resp)
278184
return resp
279185

280186
def _get_request_last_time(self, timeout):
@@ -289,7 +195,8 @@ def _get_request_last_time(self, timeout):
289195
timeoutSeconds = timeout
290196
else:
291197
raise TypeError(
292-
"timeout type {} is not supported".format(type(self._client.timeout))
198+
"timeout type {} is not supported".format(
199+
type(self._client.timeout))
293200
)
294201
return datetime.now() + timedelta(seconds=timeoutSeconds)
295202

@@ -299,37 +206,8 @@ class AsyncCompletions(AsyncAPIResource):
299206
def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
300207
return AsyncCompletionsWithRawResponse(self)
301208

302-
def _encrypt(
303-
self,
304-
model: str,
305-
messages: Iterable[ChatCompletionMessageParam],
306-
extra_headers: Headers,
307-
) -> tuple[bytes, bytes]:
308-
client = self._client._get_endpoint_certificate(model)
309-
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
310-
extra_headers["X-Session-Token"] = session_token
311-
_process_messages(
312-
messages,
313-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
314-
)
315-
return _crypto_key, _crypto_nonce
316-
317-
async def _decrypt(
318-
self, key: bytes, nonce: bytes, resp: ChatCompletion
319-
) -> ChatCompletion:
320-
if resp.choices is not None:
321-
for index, choice in enumerate(resp.choices):
322-
if (
323-
choice.message is not None and choice.finish_reason != 'content_filter'
324-
and choice.message.content is not None
325-
):
326-
choice.message.content = aes_gcm_decrypt_base64_string(
327-
key, nonce, choice.message.content
328-
)
329-
resp.choices[index] = choice
330-
return resp
331-
332209
@async_with_sts_token
210+
@async_with_e2e_encryption
333211
async def create(
334212
self,
335213
*,
@@ -360,14 +238,6 @@ async def create(
360238
extra_body: Body | None = None,
361239
timeout: float | httpx.Timeout | None = None,
362240
) -> ChatCompletion:
363-
is_encrypt = False
364-
if (
365-
extra_headers is not None
366-
and extra_headers.get(ARK_E2E_ENCRYPTION_HEADER, None) == "true"
367-
):
368-
is_encrypt = True
369-
messages = deepcopy_minimal(messages)
370-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
371241

372242
retryTimes = 0
373243
last_time = self._get_request_last_time(timeout)
@@ -426,8 +296,6 @@ async def create(
426296
continue
427297
else:
428298
raise err
429-
if is_encrypt:
430-
resp = await self._decrypt(e2e_key, e2e_nonce, resp)
431299
return resp
432300

433301
def _get_request_last_time(self, timeout):
@@ -442,7 +310,8 @@ def _get_request_last_time(self, timeout):
442310
timeoutSeconds = timeout
443311
else:
444312
raise TypeError(
445-
"timeout type {} is not supported".format(type(self._client.timeout))
313+
"timeout type {} is not supported".format(
314+
type(self._client.timeout))
446315
)
447316
return datetime.now() + timedelta(seconds=timeoutSeconds)
448317

0 commit comments

Comments
 (0)