11from __future__ import annotations
22
3- from typing import Dict , List , Union , Iterable , Optional , Callable
3+ from typing import Dict , List , Union , Iterable , Optional , Callable , Iterator , AsyncIterator
44
55import httpx
66from typing_extensions import Literal
77
88from ..._types import Body , Query , Headers
99from ..._utils ._utils import with_sts_token , async_with_sts_token
10+ from ..._utils ._key_agreement import aes_gcm_decrypt_base64_string
1011from ..._base_client import make_request_options
1112from ..._resource import SyncAPIResource , AsyncAPIResource
1213from ..._compat import cached_property
2728 ChatCompletionToolParam ,
2829 ChatCompletionToolChoiceOptionParam
2930)
31+ from ..._constants import ARK_E2E_ENCRYPTION_HEADER
3032
3133__all__ = ["Completions" , "AsyncCompletions" ]
3234
@@ -49,36 +51,42 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
4951 elif isinstance (message .get ("content" ), Iterable ):
5052 raise TypeError ("content type {} is not supported end-to-end encryption" .
5153 format (type (message .get ('content' ))))
52- # content = message.get("content")
53- # for i, c in enumerate(content):
54- # if not isinstance(c, Dict):
55- # raise TypeError("content type {} is not supported end-to-end encryption".
56- # format(type(c)))
57- # for key in c.keys():
58- # if key == 'type':
59- # continue
60- # if isinstance(c[key], str):
61- # content[i][key] = f(c[key])
62- # if isinstance(c[key], Dict):
63- # for k in c[key].keys():
64- # if isinstance(c[key][k], str):
65- # content[i][key][k] = f(c[key][k])
66- # message["content"] = content
6754 else :
6855 raise TypeError ("content type {} is not supported end-to-end encryption" .
6956 format (type (message .get ('content' ))))
7057
71- def _encrypt (self , model : str , messages : Iterable [ChatCompletionMessageParam ], extra_headers : Headers ):
58+ def _encrypt (self , model : str , messages : Iterable [ChatCompletionMessageParam ], extra_headers : Headers
59+ ) -> tuple [bytes , bytes ]:
7260 client = self ._client ._get_endpoint_certificate (model )
73- self ._ka_client = client
74- self ._crypto_key , self ._crypto_nonce , session_token = client .generate_ecies_key_pair ()
61+ _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
7562 extra_headers ['X-Session-Token' ] = session_token
76- self ._process_messages (messages , lambda x : client .encrypt_string_with_key (self . _crypto_key ,
77- self . _crypto_nonce ,
63+ self ._process_messages (messages , lambda x : client .encrypt_string_with_key (_crypto_key ,
64+ _crypto_nonce ,
7865 x ))
79-
80- def decrypt (self , ciphertext : str ) -> str :
81- return self ._ka_client .decrypt_string_with_key (self ._crypto_key , self ._crypto_nonce , ciphertext )
66+ return _crypto_key , _crypto_nonce
67+
68+ def _decrypt_chunk (self , key : bytes , nonce : bytes , resp : Stream [ChatCompletionChunk ]) -> Iterator [ChatCompletionChunk ]:
69+ for chunk in resp :
70+ if chunk .choices is not None :
71+ choice = chunk .choices [0 ]
72+ if choice .delta is not None :
73+ if choice .delta .content is not None :
74+ choice .delta .content = aes_gcm_decrypt_base64_string (key , nonce , choice .delta .content )
75+ chunk .choices [0 ] = choice
76+ yield chunk
77+
78+ def _decrypt (self , key : bytes , nonce : bytes , resp : ChatCompletion | Stream [ChatCompletionChunk ]
79+ ) -> ChatCompletion | Stream [ChatCompletionChunk ]:
80+ if isinstance (resp , ChatCompletion ):
81+ if resp .choices is not None :
82+ choice = resp .choices [0 ]
83+ if choice .message is not None :
84+ if choice .message .content is not None :
85+ choice .message .content = aes_gcm_decrypt_base64_string (key , nonce , choice .message .content )
86+ resp .choices [0 ] = choice
87+ return resp
88+ else :
89+ return Stream ._make_stream_from_iterator (self ._decrypt_chunk (key , nonce , resp ))
8290
8391 @with_sts_token
8492 def create (
@@ -108,12 +116,11 @@ def create(
108116 extra_query : Query | None = None ,
109117 extra_body : Body | None = None ,
110118 timeout : float | httpx .Timeout | None = None ,
111- is_encrypt : bool | None = None ,
112119 ) -> ChatCompletion | Stream [ChatCompletionChunk ]:
113- if is_encrypt :
114- if extra_headers is None :
115- extra_headers = dict ()
116- self ._encrypt (model , messages , extra_headers )
120+ is_encrypt = False
121+ if extra_headers is not None and extra_headers [ ARK_E2E_ENCRYPTION_HEADER ] == 'true' :
122+ is_encrypt = True
123+ e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
117124
118125 resp = self ._post (
119126 "/chat/completions" ,
@@ -150,6 +157,8 @@ def create(
150157 stream_cls = Stream [ChatCompletionChunk ],
151158 )
152159
160+ if is_encrypt :
161+ resp = self ._decrypt (e2e_key , e2e_nonce , resp )
153162 return resp
154163
155164
@@ -172,17 +181,39 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
172181 raise TypeError ("content type {} is not supported end-to-end encryption" .
173182 format (type (message .get ('content' ))))
174183
175- def _encrypt (self , model : str , messages : Iterable [ChatCompletionMessageParam ], extra_headers : Headers ):
184+ def _encrypt (self , model : str , messages : Iterable [ChatCompletionMessageParam ], extra_headers : Headers
185+ ) -> tuple [bytes , bytes ]:
176186 client = self ._client ._get_endpoint_certificate (model )
177- self ._ka_client = client
178- self ._crypto_key , self ._crypto_nonce , session_token = client .generate_ecies_key_pair ()
187+ _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
179188 extra_headers ['X-Session-Token' ] = session_token
180- self ._process_messages (messages , lambda x : client .encrypt_string_with_key (self . _crypto_key ,
181- self . _crypto_nonce ,
189+ self ._process_messages (messages , lambda x : client .encrypt_string_with_key (_crypto_key ,
190+ _crypto_nonce ,
182191 x ))
183-
184- def decrypt (self , ciphertext : str ) -> str :
185- return self ._ka_client .decrypt_string_with_key (self ._crypto_key , self ._crypto_nonce , ciphertext )
192+ return _crypto_key , _crypto_nonce
193+
194+ async def _decrypt_chunk (self , key : bytes , nonce : bytes , resp : AsyncStream [ChatCompletionChunk ]
195+ ) -> AsyncIterator [ChatCompletionChunk ]:
196+ async for chunk in resp :
197+ if chunk .choices is not None :
198+ choice = chunk .choices [0 ]
199+ if choice .delta is not None :
200+ if choice .delta .content is not None :
201+ choice .delta .content = aes_gcm_decrypt_base64_string (key , nonce , choice .delta .content )
202+ chunk .choices [0 ] = choice
203+ yield chunk
204+
205+ async def _decrypt (self , key : bytes , nonce : bytes , resp : ChatCompletion | AsyncStream [ChatCompletionChunk ]
206+ ) -> ChatCompletion | AsyncStream [ChatCompletionChunk ]:
207+ if isinstance (resp , ChatCompletion ):
208+ if resp .choices is not None :
209+ choice = resp .choices [0 ]
210+ if choice .message is not None :
211+ if choice .message .content is not None :
212+ choice .message .content = aes_gcm_decrypt_base64_string (key , nonce , choice .message .content )
213+ resp .choices [0 ] = choice
214+ return await resp
215+ else :
216+ return AsyncStream ._make_stream_from_iterator (self ._decrypt_chunk (key , nonce , resp ))
186217
187218 @async_with_sts_token
188219 async def create (
@@ -212,14 +243,13 @@ async def create(
212243 extra_query : Query | None = None ,
213244 extra_body : Body | None = None ,
214245 timeout : float | httpx .Timeout | None = None ,
215- is_encrypt : bool | None = None ,
216246 ) -> ChatCompletion | AsyncStream [ChatCompletionChunk ]:
217- if is_encrypt :
218- if extra_headers is None :
219- extra_headers = dict ()
220- self ._encrypt (model , messages , extra_headers )
247+ is_encrypt = False
248+ if extra_headers is not None and extra_headers [ ARK_E2E_ENCRYPTION_HEADER ] == 'true' :
249+ is_encrypt = True
250+ e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
221251
222- return await self ._post (
252+ resp = await self ._post (
223253 "/chat/completions" ,
224254 body = {
225255 "messages" : messages ,
@@ -254,6 +284,10 @@ async def create(
254284 stream_cls = AsyncStream [ChatCompletionChunk ],
255285 )
256286
287+ if is_encrypt :
288+ resp = await self ._decrypt (e2e_key , e2e_nonce , resp )
289+ return resp
290+
257291
258292class CompletionsWithRawResponse :
259293 def __init__ (self , completions : Completions ) -> None :
0 commit comments