Skip to content

Commit 41b020f

Browse files
author
liuhuiqi.7
committed
feat(ark e2e): fix bug of n-choices
Change-Id: I25311bafb45a25eac245f04dc5f0a8fbde3d3807
1 parent 4563888 commit 41b020f

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
4646
f: Callable[[str], str]):
4747
for message in messages:
4848
if message.get("content", None) is not None:
49-
if isinstance(message.get("content"), str):
50-
message["content"] = f(message.get("content"))
51-
elif isinstance(message.get("content"), Iterable):
49+
current_content = message.get("content")
50+
if isinstance(current_content, str):
51+
message["content"] = f(current_content)
52+
elif isinstance(current_content, Iterable):
5253
raise TypeError("content type {} is not supported end-to-end encryption".
5354
format(type(message.get('content'))))
5455
else:
@@ -68,22 +69,20 @@ def _encrypt(self, model: str, messages: Iterable[ChatCompletionMessageParam], e
6869
def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: Stream[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
6970
for chunk in resp:
7071
if chunk.choices is not None:
71-
choice = chunk.choices[0]
72-
if choice.delta is not None:
73-
if choice.delta.content is not None:
74-
choice.delta.content = aes_gcm_decrypt_base64_string(key, nonce, choice.delta.content)
75-
chunk.choices[0] = choice
72+
for index, choice in enumerate(chunk.choices):
73+
if choice.delta is not None and choice.delta.content is not None:
74+
choice.delta.content = aes_gcm_decrypt_base64_string(key, nonce, choice.delta.content)
75+
chunk.choices[index] = choice
7676
yield chunk
7777

7878
def _decrypt(self, key: bytes, nonce: bytes, resp: ChatCompletion | Stream[ChatCompletionChunk]
7979
) -> ChatCompletion | Stream[ChatCompletionChunk]:
8080
if isinstance(resp, ChatCompletion):
8181
if resp.choices is not None:
82-
if len(resp.choices) > 0:
83-
choice = resp.choices[0]
82+
for index, choice in enumerate(resp.choices):
8483
if choice.message is not None and choice.message.content is not None:
8584
choice.message.content = aes_gcm_decrypt_base64_string(key, nonce, choice.message.content)
86-
resp.choices[0] = choice
85+
resp.choices[index] = choice
8786
return resp
8887
else:
8988
return Stream._make_stream_from_iterator(self._decrypt_chunk(key, nonce, resp))
@@ -175,8 +174,9 @@ def _process_messages(self, messages: Iterable[ChatCompletionMessageParam],
175174
f: Callable[[str], str]):
176175
for message in messages:
177176
if message.get("content", None) is not None:
178-
if isinstance(message.get("content"), str):
179-
message["content"] = f(message.get("content"))
177+
current_content = message.get("content")
178+
if isinstance(current_content, str):
179+
message["content"] = f(current_content)
180180
else:
181181
raise TypeError("content type {} is not supported end-to-end encryption".
182182
format(type(message.get('content'))))
@@ -195,22 +195,20 @@ async def _decrypt_chunk(self, key: bytes, nonce: bytes, resp: AsyncStream[ChatC
195195
) -> AsyncIterator[ChatCompletionChunk]:
196196
async for chunk in resp:
197197
if chunk.choices is not None:
198-
if len(chunk.choices) > 0:
199-
choice = chunk.choices[0]
198+
for index, choice in enumerate(chunk.choices):
200199
if choice.delta is not None and choice.delta.content is not None:
201200
choice.delta.content = aes_gcm_decrypt_base64_string(key, nonce, choice.delta.content)
202-
chunk.choices[0] = choice
201+
chunk.choices[index] = choice
203202
yield chunk
204203

205204
async def _decrypt(self, key: bytes, nonce: bytes, resp: ChatCompletion | AsyncStream[ChatCompletionChunk]
206205
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
207206
if isinstance(resp, ChatCompletion):
208207
if resp.choices is not None:
209-
choice = resp.choices[0]
210-
if choice.message is not None:
211-
if choice.message.content is not None:
208+
for index, choice in enumerate(resp.choices):
209+
if choice.message is not None and choice.message.content is not None:
212210
choice.message.content = aes_gcm_decrypt_base64_string(key, nonce, choice.message.content)
213-
resp.choices[0] = choice
211+
resp.choices[index] = choice
214212
return resp
215213
else:
216214
return AsyncStream._make_stream_from_iterator(self._decrypt_chunk(key, nonce, resp))

0 commit comments

Comments
 (0)