2020from typing_extensions import Literal
2121
2222import httpx
23- import warnings
2423
2524from ..._exceptions import ArkAPITimeoutError , ArkAPIConnectionError , ArkAPIStatusError
2625from ..._types import Body , Query , Headers
27- from ..._utils import with_sts_token , async_with_sts_token , deepcopy_minimal
28- from ... _utils . _key_agreement import aes_gcm_decrypt_base64_string
26+ from ..._utils import with_sts_token , async_with_sts_token
27+ from ..encryption import with_e2e_encryption , async_with_e2e_encryption
2928from ..._base_client import make_request_options
3029from ..._resource import SyncAPIResource , AsyncAPIResource
3130from ..._compat import cached_property
5049__all__ = ["Completions" , "AsyncCompletions" ]
5150
5251
53- def _process_messages (
54- messages : Iterable [ChatCompletionMessageParam ], f : Callable [[str ], str ]
55- ):
56- for message in messages :
57- if message .get ("content" , None ) is not None :
58- current_content = message .get ("content" )
59- if isinstance (current_content , str ):
60- message ["content" ] = f (current_content )
61- elif isinstance (current_content , Iterable ):
62- for part in current_content :
63- if part .get ("type" , None ) == "text" :
64- part ["text" ] = f (part ["text" ])
65- elif part .get ("type" , None ) == "image_url" :
66- if part ["image_url" ]["url" ].startswith ("data:" ):
67- part ["image_url" ]["url" ] = f (part ["image_url" ]["url" ])
68- else :
69- warnings .warn (
70- "encryption is not supported for image url, "
71- "please use base64 image if you want encryption"
72- )
73- else :
74- raise TypeError (
75- "encryption is not supported for content type {}" .format (
76- type (part )
77- )
78- )
79- else :
80- raise TypeError (
81- "encryption is not supported for content type {}" .format (
82- type (message .get ("content" ))
83- )
84- )
85-
86-
8752def _calculate_retry_timeout (retry_times ) -> float :
8853 nbRetries = min (retry_times , MAX_RETRY_DELAY / INITIAL_RETRY_DELAY )
89- sleep_seconds = min (INITIAL_RETRY_DELAY * pow (2 , nbRetries ), MAX_RETRY_DELAY )
54+ sleep_seconds = min (INITIAL_RETRY_DELAY *
55+ pow (2 , nbRetries ), MAX_RETRY_DELAY )
9056 # Apply some jitter, plus-or-minus half a second.
9157 jitter = 1 - 0.25 * random ()
9258 timeout = sleep_seconds * jitter
@@ -126,58 +92,8 @@ class Completions(SyncAPIResource):
12692 def with_raw_response (self ) -> CompletionsWithRawResponse :
12793 return CompletionsWithRawResponse (self )
12894
129- def _process_messages (
130- self , messages : Iterable [ChatCompletionMessageParam ], f : Callable [[str ], str ]
131- ):
132- for message in messages :
133- if message .get ("content" , None ) is not None :
134- current_content = message .get ("content" )
135- if isinstance (current_content , str ):
136- message ["content" ] = f (current_content )
137- elif isinstance (current_content , Iterable ):
138- raise TypeError (
139- "content type {} is not supported end-to-end encryption" .format (
140- type (message .get ("content" ))
141- )
142- )
143- else :
144- raise TypeError (
145- "content type {} is not supported end-to-end encryption" .format (
146- type (message .get ("content" ))
147- )
148- )
149-
150- def _encrypt (
151- self ,
152- model : str ,
153- messages : Iterable [ChatCompletionMessageParam ],
154- extra_headers : Headers ,
155- ) -> tuple [bytes , bytes ]:
156- client = self ._client ._get_endpoint_certificate (model )
157- _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
158- extra_headers ["X-Session-Token" ] = session_token
159- _process_messages (
160- messages ,
161- lambda x : client .encrypt_string_with_key (_crypto_key , _crypto_nonce , x ),
162- )
163- return _crypto_key , _crypto_nonce
164-
165- def _decrypt (
166- self , key : bytes , nonce : bytes , resp : ChatCompletion
167- ) -> ChatCompletion :
168- if resp .choices is not None :
169- for index , choice in enumerate (resp .choices ):
170- if (
171- choice .message is not None and choice .finish_reason != 'content_filter'
172- and choice .message .content is not None
173- ):
174- choice .message .content = aes_gcm_decrypt_base64_string (
175- key , nonce , choice .message .content
176- )
177- resp .choices [index ] = choice
178- return resp
179-
18095 @with_sts_token
96+ @with_e2e_encryption
18197 def create (
18298 self ,
18399 * ,
@@ -208,14 +124,6 @@ def create(
208124 extra_body : Body | None = None ,
209125 timeout : float | httpx .Timeout | None = None ,
210126 ) -> ChatCompletion :
211- is_encrypt = False
212- if (
213- extra_headers is not None
214- and extra_headers .get (ARK_E2E_ENCRYPTION_HEADER , None ) == "true"
215- ):
216- is_encrypt = True
217- messages = deepcopy_minimal (messages )
218- e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
219127 retryTimes = 0
220128 last_time = self ._get_request_last_time (timeout )
221129 model_breaker = self ._client .get_model_breaker (model )
@@ -273,8 +181,6 @@ def create(
273181 continue
274182 else :
275183 raise err
276- if is_encrypt :
277- resp = self ._decrypt (e2e_key , e2e_nonce , resp )
278184 return resp
279185
280186 def _get_request_last_time (self , timeout ):
@@ -289,7 +195,8 @@ def _get_request_last_time(self, timeout):
289195 timeoutSeconds = timeout
290196 else :
291197 raise TypeError (
292- "timeout type {} is not supported" .format (type (self ._client .timeout ))
198+ "timeout type {} is not supported" .format (
199+ type (self ._client .timeout ))
293200 )
294201 return datetime .now () + timedelta (seconds = timeoutSeconds )
295202
@@ -299,37 +206,8 @@ class AsyncCompletions(AsyncAPIResource):
299206 def with_raw_response (self ) -> AsyncCompletionsWithRawResponse :
300207 return AsyncCompletionsWithRawResponse (self )
301208
302- def _encrypt (
303- self ,
304- model : str ,
305- messages : Iterable [ChatCompletionMessageParam ],
306- extra_headers : Headers ,
307- ) -> tuple [bytes , bytes ]:
308- client = self ._client ._get_endpoint_certificate (model )
309- _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
310- extra_headers ["X-Session-Token" ] = session_token
311- _process_messages (
312- messages ,
313- lambda x : client .encrypt_string_with_key (_crypto_key , _crypto_nonce , x ),
314- )
315- return _crypto_key , _crypto_nonce
316-
317- async def _decrypt (
318- self , key : bytes , nonce : bytes , resp : ChatCompletion
319- ) -> ChatCompletion :
320- if resp .choices is not None :
321- for index , choice in enumerate (resp .choices ):
322- if (
323- choice .message is not None and choice .finish_reason != 'content_filter'
324- and choice .message .content is not None
325- ):
326- choice .message .content = aes_gcm_decrypt_base64_string (
327- key , nonce , choice .message .content
328- )
329- resp .choices [index ] = choice
330- return resp
331-
332209 @async_with_sts_token
210+ @async_with_e2e_encryption
333211 async def create (
334212 self ,
335213 * ,
@@ -360,14 +238,6 @@ async def create(
360238 extra_body : Body | None = None ,
361239 timeout : float | httpx .Timeout | None = None ,
362240 ) -> ChatCompletion :
363- is_encrypt = False
364- if (
365- extra_headers is not None
366- and extra_headers .get (ARK_E2E_ENCRYPTION_HEADER , None ) == "true"
367- ):
368- is_encrypt = True
369- messages = deepcopy_minimal (messages )
370- e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
371241
372242 retryTimes = 0
373243 last_time = self ._get_request_last_time (timeout )
@@ -426,8 +296,6 @@ async def create(
426296 continue
427297 else :
428298 raise err
429- if is_encrypt :
430- resp = await self ._decrypt (e2e_key , e2e_nonce , resp )
431299 return resp
432300
433301 def _get_request_last_time (self , timeout ):
@@ -442,7 +310,8 @@ def _get_request_last_time(self, timeout):
442310 timeoutSeconds = timeout
443311 else :
444312 raise TypeError (
445- "timeout type {} is not supported" .format (type (self ._client .timeout ))
313+ "timeout type {} is not supported" .format (
314+ type (self ._client .timeout ))
446315 )
447316 return datetime .now () + timedelta (seconds = timeoutSeconds )
448317
0 commit comments