Skip to content

Commit a277085

Browse files
author
liuhuiqi.7
committed
feat(ark e2e): support async/sync
Change-Id: I7e6fc6175b337d8cec3fca9b6d988d5dc6706361
1 parent 0b630a8 commit a277085

File tree

3 files changed

+105
-54
lines changed

3 files changed

+105
-54
lines changed

volcenginesdkarkruntime/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
99
CLIENT_REQUEST_HEADER = "X-Client-Request-Id"
1010
SERVER_REQUEST_HEADER = "X-Request-Id"
11+
ARK_E2E_ENCRYPTION_HEADER = "x-is-encrypted"
1112

1213
# default timeout is 10 minutes; the connect timeout is 1 minute
1314
DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=60.0)

volcenginesdkarkruntime/_streaming.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import inspect
66
from types import TracebackType
7-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
7+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast, Optional
88
from typing_extensions import (
99
Self,
1010
Protocol,
@@ -40,12 +40,20 @@ def __init__(
4040
cast_to: type[_T],
4141
response: httpx.Response,
4242
client: Ark,
43+
iterator: Optional[Iterator[_T]] | None = None,
4344
) -> None:
44-
self.response = response
45-
self._cast_to = cast_to
46-
self._client = client
47-
self._decoder = client._make_sse_decoder()
48-
self._iterator = self.__stream__()
45+
if iterator is not None:
46+
self._iterator = iterator
47+
else:
48+
self.response = response
49+
self._cast_to = cast_to
50+
self._client = client
51+
self._decoder = client._make_sse_decoder()
52+
self._iterator = self.__stream__()
53+
54+
@classmethod
55+
def _make_stream_from_iterator(cls, iterator: Iterator[_T]) -> Stream[_T]:
56+
return Stream(cast_to=None, response=None, client=None, iterator=iterator)
4957

5058
def __next__(self) -> _T:
5159
return self._iterator.__next__()
@@ -148,12 +156,20 @@ def __init__(
148156
cast_to: type[_T],
149157
response: httpx.Response,
150158
client: AsyncArk,
159+
iterator: Optional[AsyncIterator[_T]] | None = None,
151160
) -> None:
152-
self.response = response
153-
self._cast_to = cast_to
154-
self._client = client
155-
self._decoder = client._make_sse_decoder()
156-
self._iterator = self.__stream__()
161+
if iterator is not None:
162+
self._iterator = iterator
163+
else:
164+
self.response = response
165+
self._cast_to = cast_to
166+
self._client = client
167+
self._decoder = client._make_sse_decoder()
168+
self._iterator = self.__stream__()
169+
170+
@classmethod
171+
def _make_stream_from_iterator(cls, iterator: Iterator[_T]) -> Stream[_T]:
172+
return AsyncStream(cast_to=None, response=None, client=None, iterator=iterator)
157173

158174
async def __anext__(self) -> _T:
159175
return await self._iterator.__anext__()

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import Dict, List, Union, Iterable, Optional, Callable
3+
from typing import Dict, List, Union, Iterable, Optional, Callable, Iterator, AsyncIterator
44

55
import httpx
66
from typing_extensions import Literal
77

88
from ..._types import Body, Query, Headers
99
from ..._utils._utils import with_sts_token, async_with_sts_token
10+
from ..._utils._key_agreement import aes_gcm_decrypt_base64_string
1011
from ..._base_client import make_request_options
1112
from ..._resource import SyncAPIResource, AsyncAPIResource
1213
from ..._compat import cached_property
@@ -27,6 +28,7 @@
2728
ChatCompletionToolParam,
2829
ChatCompletionToolChoiceOptionParam
2930
)
31+
from ..._constants import ARK_E2E_ENCRYPTION_HEADER
3032

3133
__all__ = ["Completions", "AsyncCompletions"]
3234

@@ -49,36 +51,42 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
4951
elif isinstance(message.get("content"), Iterable):
5052
raise TypeError("content type {} is not supported end-to-end encryption".
5153
format(type(message.get('content'))))
52-
# content = message.get("content")
53-
# for i, c in enumerate(content):
54-
# if not isinstance(c, Dict):
55-
# raise TypeError("content type {} is not supported end-to-end encryption".
56-
# format(type(c)))
57-
# for key in c.keys():
58-
# if key == 'type':
59-
# continue
60-
# if isinstance(c[key], str):
61-
# content[i][key] = f(c[key])
62-
# if isinstance(c[key], Dict):
63-
# for k in c[key].keys():
64-
# if isinstance(c[key][k], str):
65-
# content[i][key][k] = f(c[key][k])
66-
# message["content"] = content
6754
else:
6855
raise TypeError("content type {} is not supported end-to-end encryption".
6956
format(type(message.get('content'))))
7057

71-
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers):
58+
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers
59+
) -> tuple[bytes, bytes]:
7260
client = self._client._get_endpoint_certificate(model)
73-
self._ka_client = client
74-
self._crypto_key, self._crypto_nonce, session_token = client.generate_ecies_key_pair()
61+
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
7562
extra_headers['X-Session-Token'] = session_token
76-
self._process_messages(messages, lambda x: client.encrypt_string_with_key(self._crypto_key,
77-
self._crypto_nonce,
63+
self._process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key,
64+
_crypto_nonce,
7865
x))
79-
80-
def decrypt(self, ciphertext: str) -> str:
81-
return self._ka_client.decrypt_string_with_key(self._crypto_key, self._crypto_nonce, ciphertext)
66+
return _crypto_key, _crypto_nonce
67+
68+
def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: Stream[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
69+
for chunk in resp:
70+
if chunk.choices is not None:
71+
choice = chunk.choices[0]
72+
if choice.delta is not None:
73+
if choice.delta.content is not None:
74+
choice.delta.content = aes_gcm_decrypt_base64_string(key, nonce, choice.delta.content)
75+
chunk.choices[0] = choice
76+
yield chunk
77+
78+
def _decrypt(self, key: bytes, nonce: bytes, resp: ChatCompletion | Stream[ChatCompletionChunk]
             ) -> ChatCompletion | Stream[ChatCompletionChunk]:
    """Decrypt the content of a chat-completion response.

    Args:
        key: Key forwarded to ``aes_gcm_decrypt_base64_string``.
        nonce: Nonce forwarded to ``aes_gcm_decrypt_base64_string``.
        resp: Either a complete ``ChatCompletion`` or a stream of chunks.

    Returns:
        For a ``ChatCompletion``: the same object, with the first choice's
        message content decrypted in place when present. Otherwise: a new
        ``Stream`` wrapping ``self._decrypt_chunk`` over ``resp``.
    """
    if not isinstance(resp, ChatCompletion):
        return Stream._make_stream_from_iterator(self._decrypt_chunk(key, nonce, resp))

    # Truthiness covers both ``None`` and an empty list; indexing an empty
    # ``choices`` list would raise IndexError.
    if resp.choices:
        # Only the first choice is decrypted; the message is mutated in
        # place, so nothing needs to be written back into ``choices``.
        message = resp.choices[0].message
        if message is not None and message.content is not None:
            message.content = aes_gcm_decrypt_base64_string(key, nonce, message.content)
    return resp
8290

8391
@with_sts_token
8492
def create(
@@ -108,12 +116,11 @@ def create(
108116
extra_query: Query | None = None,
109117
extra_body: Body | None = None,
110118
timeout: float | httpx.Timeout | None = None,
111-
is_encrypt: bool | None = None,
112119
) -> ChatCompletion | Stream[ChatCompletionChunk]:
113-
if is_encrypt:
114-
if extra_headers is None:
115-
extra_headers = dict()
116-
self._encrypt(model, messages, extra_headers)
120+
is_encrypt = False
121+
if extra_headers is not None and extra_headers[ARK_E2E_ENCRYPTION_HEADER] == 'true':
122+
is_encrypt = True
123+
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
117124

118125
resp = self._post(
119126
"/chat/completions",
@@ -150,6 +157,8 @@ def create(
150157
stream_cls=Stream[ChatCompletionChunk],
151158
)
152159

160+
if is_encrypt:
161+
resp = self._decrypt(e2e_key, e2e_nonce, resp)
153162
return resp
154163

155164

@@ -172,17 +181,39 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
172181
raise TypeError("content type {} is not supported end-to-end encryption".
173182
format(type(message.get('content'))))
174183

175-
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers):
184+
def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers
185+
) -> tuple[bytes, bytes]:
176186
client = self._client._get_endpoint_certificate(model)
177-
self._ka_client = client
178-
self._crypto_key, self._crypto_nonce, session_token = client.generate_ecies_key_pair()
187+
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
179188
extra_headers['X-Session-Token'] = session_token
180-
self._process_messages(messages, lambda x: client.encrypt_string_with_key(self._crypto_key,
181-
self._crypto_nonce,
189+
self._process_messages(messages, lambda x: client.encrypt_string_with_key(_crypto_key,
190+
_crypto_nonce,
182191
x))
183-
184-
def decrypt(self, ciphertext: str) -> str:
185-
return self._ka_client.decrypt_string_with_key(self._crypto_key, self._crypto_nonce, ciphertext)
192+
return _crypto_key, _crypto_nonce
193+
194+
async def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: AsyncStream[ChatCompletionChunk]
195+
) -> AsyncIterator[ChatCompletionChunk]:
196+
async for chunk in resp:
197+
if chunk.choices is not None:
198+
choice = chunk.choices[0]
199+
if choice.delta is not None:
200+
if choice.delta.content is not None:
201+
choice.delta.content = aes_gcm_decrypt_base64_string(key, nonce, choice.delta.content)
202+
chunk.choices[0] = choice
203+
yield chunk
204+
205+
async def _decrypt(self, key: bytes, nonce: bytes, resp: ChatCompletion | AsyncStream[ChatCompletionChunk]
                   ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
    """Decrypt the content of a chat-completion response (async client).

    Args:
        key: Key forwarded to ``aes_gcm_decrypt_base64_string``.
        nonce: Nonce forwarded to ``aes_gcm_decrypt_base64_string``.
        resp: Either a complete ``ChatCompletion`` or an async stream of chunks.

    Returns:
        For a ``ChatCompletion``: the same object, with the first choice's
        message content decrypted in place when present. Otherwise: a new
        ``AsyncStream`` wrapping ``self._decrypt_chunk`` over ``resp``.
    """
    if not isinstance(resp, ChatCompletion):
        return AsyncStream._make_stream_from_iterator(self._decrypt_chunk(key, nonce, resp))

    # Truthiness covers both ``None`` and an empty list; indexing an empty
    # ``choices`` list would raise IndexError.
    if resp.choices:
        # Only the first choice is decrypted; the message is mutated in
        # place, so nothing needs to be written back into ``choices``.
        message = resp.choices[0].message
        if message is not None and message.content is not None:
            message.content = aes_gcm_decrypt_base64_string(key, nonce, message.content)
    # ``resp`` is already a materialised ChatCompletion here; awaiting it
    # would raise TypeError, so return it directly.
    return resp
186217

187218
@async_with_sts_token
188219
async def create(
@@ -212,14 +243,13 @@ async def create(
212243
extra_query: Query | None = None,
213244
extra_body: Body | None = None,
214245
timeout: float | httpx.Timeout | None = None,
215-
is_encrypt: bool | None = None,
216246
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
217-
if is_encrypt:
218-
if extra_headers is None:
219-
extra_headers = dict()
220-
self._encrypt(model, messages, extra_headers)
247+
is_encrypt = False
248+
if extra_headers is not None and extra_headers[ARK_E2E_ENCRYPTION_HEADER] == 'true':
249+
is_encrypt = True
250+
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
221251

222-
return await self._post(
252+
resp = await self._post(
223253
"/chat/completions",
224254
body={
225255
"messages": messages,
@@ -254,6 +284,10 @@ async def create(
254284
stream_cls=AsyncStream[ChatCompletionChunk],
255285
)
256286

287+
if is_encrypt:
288+
resp = await self._decrypt(e2e_key, e2e_nonce, resp)
289+
return resp
290+
257291

258292
class CompletionsWithRawResponse:
259293
def __init__(self, completions: Completions) -> None:

0 commit comments

Comments
 (0)