Skip to content

Commit 601f0e2

Browse files
committed
fix: handle invalid JSON chunk in OpenAiChatModel
Signed-off-by: Minu Kim <[email protected]>
1 parent 14b91dc commit 601f0e2

File tree

3 files changed

+64
-81
lines changed

3 files changed

+64
-81
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
329329
previousChatResponse);
330330
return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
331331
}
332+
332333
catch (Exception e) {
333334
logger.error("Error processing chat completion", e);
334335
return new ChatResponse(List.of());
@@ -492,14 +493,20 @@ private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usa
492493
* @return the ChatCompletion
493494
*/
494495
private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
495-
List<Choice> choices = chunk.choices()
496-
.stream()
497-
.map(chunkChoice -> new Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(),
498-
chunkChoice.logprobs()))
499-
.toList();
500-
501-
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
502-
chunk.systemFingerprint(), "chat.completion", chunk.usage());
496+
try {
497+
List<Choice> choices = chunk.choices()
498+
.stream()
499+
.map(chunkChoice -> new Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(),
500+
chunkChoice.logprobs()))
501+
.toList();
502+
503+
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
504+
chunk.serviceTier(), chunk.systemFingerprint(), "chat.completion", chunk.usage());
505+
}
506+
catch (Exception e) {
507+
logger.warn("Invalid JSON chunk received, skipping. Raw chunk: {}", chunk, e);
508+
throw new RuntimeException("Failed to parse ChatCompletionChunk", e);
509+
}
503510
}
504511

505512
private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.mockito.Mockito;
21+
import reactor.core.publisher.Flux;
22+
23+
import org.springframework.ai.chat.prompt.Prompt;
24+
import org.springframework.ai.openai.api.OpenAiApi;
25+
26+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
27+
import static org.mockito.ArgumentMatchers.any;
28+
import static org.mockito.Mockito.when;
29+
30+
public class OpenAiChatModelStreamingTest {
31+
32+
@Test
33+
void shouldThrowExceptionOnInvalidJsonChunk() {
34+
OpenAiApi mockApi = Mockito.mock(OpenAiApi.class);
35+
36+
OpenAiApi.ChatCompletionChunk invalidChunk = new OpenAiApi.ChatCompletionChunk("invalid-id", null,
37+
System.currentTimeMillis() / 1000L, "gpt-test-model", null, null, null, null);
38+
39+
when(mockApi.chatCompletionStream(any(), any())).thenReturn(Flux.just(invalidChunk));
40+
41+
OpenAiChatOptions options = OpenAiChatOptions.builder().model("gpt-test-model").build();
42+
OpenAiChatModel model = OpenAiChatModel.builder().openAiApi(mockApi).defaultOptions(options).build();
43+
44+
assertThatThrownBy(() -> model.stream(new Prompt("Hello")).collectList().block())
45+
.isInstanceOf(RuntimeException.class)
46+
.hasMessageContaining("Failed to parse ChatCompletionChunk");
47+
}
48+
49+
}

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolOnFinishPredicate.java

Lines changed: 0 additions & 73 deletions
This file was deleted.

0 commit comments

Comments
 (0)