@@ -75,14 +75,32 @@ def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], e
7575 self ._crypto_nonce ,
7676 x ))
7777
78- def _decrypt (self , completion : ChatCompletion ):
79- if completion .choices is not None :
80- for choice in completion .choices :
78+ def _decrypt (self , completion : ChatCompletion | Stream [ChatCompletionChunk ]):
79+ if isinstance (completion , ChatCompletion ):
80+ if completion .choices is not None :
81+ choice = completion .choices [0 ]
8182 if choice .message .content is not None :
8283 if isinstance (choice .message .content , str ):
8384 choice .message .content = self ._ka_client .decrypt_string_with_key (self ._crypto_key ,
84- self ._crypto_nonce ,
85- choice .message .content )
85+ self ._crypto_nonce ,
86+ choice .message .content )
87+ else :
88+ raise TypeError ("content type {} is not supported end-to-end encryption" .
89+ format (type (choice .message .content )))
90+ completion .choices [0 ] = choice
91+ elif isinstance (completion , Stream ):
92+ for chunk in completion :
93+ if chunk .choices :
94+ choice = chunk .choices [0 ]
95+ if choice .delta .content is not None :
96+ if isinstance (choice .delta .content , str ):
97+ choice .delta .content = self ._ka_client .decrypt_string_with_key (self ._crypto_key ,
98+ self ._crypto_nonce ,
99+ choice .delta .content )
100+ else :
101+ raise TypeError ("content type {} is not supported end-to-end encryption" .
102+ format (type (choice .delta .content )))
103+ chunk .choices [0 ] = choice
86104
87105 @with_sts_token
88106 def create (
@@ -115,6 +133,8 @@ def create(
115133 is_encrypt : bool | None = None ,
116134 ) -> ChatCompletion | Stream [ChatCompletionChunk ]:
117135 if is_encrypt :
136+ if extra_headers is None :
137+ extra_headers = dict ()
118138 self ._encrypt (model , messages , extra_headers )
119139
120140 resp = self ._post (
@@ -152,8 +172,8 @@ def create(
152172 stream_cls = Stream [ChatCompletionChunk ],
153173 )
154174
155- # if is_encrypt:
156- # return self._decrypt(resp)
175+ if is_encrypt :
176+ self ._decrypt (resp )
157177 return resp
158178
159179
0 commit comments