Skip to content

Commit 0b630a8

Browse files
author
liuhuiqi.7
committed
feat(ark e2e): support async
Change-Id: If3cae5914537c4346aff1aa81ef995801f4a96bc
1 parent f3a7047 commit 0b630a8

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def __init__(
178178

179179
self._default_stream_cls = Stream
180180
self._sts_token_manager: StsTokenManager | None = None
181+
self._certificate_manager: E2ECertificateManager | None = None
181182

182183
self.chat = resources.AsyncChat(self)
183184
self.bot_chat = resources.AsyncBotChat(self)
@@ -192,6 +193,15 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
192193
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
193194
return self._sts_token_manager.get(endpoint_id)
194195

196+
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
197+
if self._certificate_manager is None:
198+
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
199+
if (self.ak is None or self.sk is None) and cert_path is None:
200+
raise ArkAPIError("must set (ak and sk) or (E2E_CERTIFICATE_PATH) \
201+
before get endpoint token.")
202+
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region)
203+
return self._certificate_manager.get(endpoint_id)
204+
195205
@property
196206
def auth_headers(self) -> dict[str, str]:
197207
api_key = self.api_key

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,23 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
4747
if isinstance(message.get("content"), str):
4848
message["content"] = f(message.get("content"))
4949
elif isinstance(message.get("content"), Iterable):
50-
content = message.get("content")
51-
for i, c in enumerate(content):
52-
if not isinstance(c, Dict):
53-
raise TypeError("content type {} is not supported end-to-end encryption".
54-
format(type(c)))
55-
for key in c.keys():
56-
if key == 'type':
57-
continue
58-
if isinstance(c[key], str):
59-
content[i][key] = f(c[key])
60-
if isinstance(c[key], Dict):
61-
for k in c[key].keys():
62-
if isinstance(c[key][k], str):
63-
content[i][key][k] = f(c[key][k])
64-
message["content"] = content
50+
raise TypeError("content type {} is not supported end-to-end encryption".
51+
format(type(message.get('content'))))
52+
# content = message.get("content")
53+
# for i, c in enumerate(content):
54+
# if not isinstance(c, Dict):
55+
# raise TypeError("content type {} is not supported end-to-end encryption".
56+
# format(type(c)))
57+
# for key in c.keys():
58+
# if key == 'type':
59+
# continue
60+
# if isinstance(c[key], str):
61+
# content[i][key] = f(c[key])
62+
# if isinstance(c[key], Dict):
63+
# for k in c[key].keys():
64+
# if isinstance(c[key][k], str):
65+
# content[i][key][k] = f(c[key][k])
66+
# message["content"] = content
6567
else:
6668
raise TypeError("content type {} is not supported end-to-end encryption".
6769
format(type(message.get('content'))))
@@ -160,6 +162,28 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
160162
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
161163
return AsyncCompletionsWithStreamingResponse(self)
162164

165+
def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
166+
f: Callable[[str], str]):
167+
for message in messages:
168+
if message.get("content", None) is not None:
169+
if isinstance(message.get("content"), str):
170+
message["content"] = f(message.get("content"))
171+
else:
172+
raise TypeError("content type {} is not supported end-to-end encryption".
173+
format(type(message.get('content'))))
174+
175+
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers):
176+
client = self._client._get_endpoint_certificate(model)
177+
self._ka_client = client
178+
self._crypto_key, self._crypto_nonce, session_token = client.generate_ecies_key_pair()
179+
extra_headers['X-Session-Token'] = session_token
180+
self._process_messages(messages, lambda x: client.encrypt_string_with_key(self._crypto_key,
181+
self._crypto_nonce,
182+
x))
183+
184+
def decrypt(self, ciphertext: str) -> str:
185+
return self._ka_client.decrypt_string_with_key(self._crypto_key, self._crypto_nonce, ciphertext)
186+
163187
@async_with_sts_token
164188
async def create(
165189
self,
@@ -188,7 +212,13 @@ async def create(
188212
extra_query: Query | None = None,
189213
extra_body: Body | None = None,
190214
timeout: float | httpx.Timeout | None = None,
215+
is_encrypt: bool | None = None,
191216
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
217+
if is_encrypt:
218+
if extra_headers is None:
219+
extra_headers = dict()
220+
self._encrypt(model, messages, extra_headers)
221+
192222
return await self._post(
193223
"/chat/completions",
194224
body={

0 commit comments

Comments
 (0)