|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import Dict, List, Union, Iterable, Optional |
| 3 | +from typing import Dict, List, Union, Iterable, Optional, Callable |
4 | 4 |
|
5 | 5 | import httpx |
6 | 6 | from typing_extensions import Literal |
@@ -40,15 +40,37 @@ def with_raw_response(self) -> CompletionsWithRawResponse: |
40 | 40 | def with_streaming_response(self) -> CompletionsWithStreamingResponse: |
41 | 41 | return CompletionsWithStreamingResponse(self) |
42 | 42 |
|
43 | | - def encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers): |
| 43 | + def _process_messages(self, messages: Iterable[ChatCompletionMessageParam], |
| 44 | + f: Callable[[str], str]): |
| 45 | + for message in messages: |
| 46 | + if message.get("content", None) is not None: |
| 47 | + if isinstance(message.get("content"), str): |
| 48 | + message["content"] = f(message.get("content")) |
| 49 | + elif isinstance(message.get("content"), Iterable): |
| 50 | + content = message.get("content") |
| 51 | + for i, c in enumerate(content): |
| 52 | + if not isinstance(c, Dict): |
| 53 | + raise TypeError("content type {} is not supported end-to-end encryption". |
| 54 | + format(type(c))) |
| 55 | + for key in c.keys(): |
| 56 | + if isinstance(c[key], str): |
| 57 | + content[i][key] = f(c[key]) |
| 58 | + if isinstance(c[key], Dict): |
| 59 | + for k in c[key].keys(): |
| 60 | + if isinstance(c[key][k], str): |
| 61 | + content[i][key][k] = f(c[key][k]) |
| 62 | + message["content"] = content |
| 63 | + else: |
| 64 | + raise TypeError("content type {} is not supported end-to-end encryption". |
| 65 | + format(type(message.get('content')))) |
| 66 | + |
| 67 | + def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], extra_headers: Headers): |
44 | 68 | client = self._client._get_endpoint_certificate(model) |
45 | 69 | self._crypto_key, self._crypto_nonce, session_token = client.generate_ecies_key_pair() |
46 | 70 | extra_headers['X-Session-Token'] = session_token |
47 | | - for message in messages: |
48 | | - if message.get("content", None) is not None: |
49 | | - message["content"] = client.encrypt_string_with_key(self._crypto_key, |
50 | | - self._crypto_nonce, |
51 | | - message.get("content")) |
| 71 | + self._process_messages(messages, lambda x: client.encrypt_string_with_key(self._crypto_key, |
| 72 | + self._crypto_nonce, |
| 73 | + x)) |
52 | 74 |
|
53 | 75 | @with_sts_token |
54 | 76 | def create( |
@@ -81,7 +103,8 @@ def create( |
81 | 103 | is_encrypt: bool | None = None, |
82 | 104 | ) -> ChatCompletion | Stream[ChatCompletionChunk]: |
83 | 105 | if is_encrypt: |
84 | | - self.encrypt(model, messages, extra_headers) |
| 106 | + self._encrypt(model, messages, extra_headers) |
| 107 | + print(messages) |
85 | 108 |
|
86 | 109 | return self._post( |
87 | 110 | "/chat/completions", |
|
0 commit comments