Skip to content

Commit c045fd7

Browse files
fix npe in compatible streaming api with vertex ai gemini
Signed-off-by: jonghoon park <[email protected]>
1 parent e10fbde commit c045fd7

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
304304
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
305305
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
306306
try {
307-
@SuppressWarnings("null")
308-
String id = chatCompletion2.id();
307+
// If an id is not provided, set to "NO_ID" (for compatible APIs).
308+
String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id();
309309

310310
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
311311
if (choice.message().role() != null) {
312312
roleMap.putIfAbsent(id, choice.message().role().name());
313313
}
314314
Map<String, Object> metadata = Map.of(
315-
"id", chatCompletion2.id(),
315+
"id", id,
316316
"role", roleMap.getOrDefault(id, ""),
317317
"index", choice.index(),
318318
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.proxy;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
import org.springframework.ai.chat.messages.AssistantMessage;
22+
import org.springframework.ai.chat.messages.Message;
23+
import org.springframework.ai.chat.messages.UserMessage;
24+
import org.springframework.ai.chat.model.ChatResponse;
25+
import org.springframework.ai.chat.model.Generation;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
28+
import org.springframework.ai.model.SimpleApiKey;
29+
import org.springframework.ai.model.tool.ToolCallingManager;
30+
import org.springframework.ai.openai.OpenAiChatModel;
31+
import org.springframework.ai.openai.OpenAiChatOptions;
32+
import org.springframework.ai.openai.api.OpenAiApi;
33+
import org.springframework.beans.factory.annotation.Autowired;
34+
import org.springframework.beans.factory.annotation.Value;
35+
import org.springframework.boot.SpringBootConfiguration;
36+
import org.springframework.boot.test.context.SpringBootTest;
37+
import org.springframework.context.annotation.Bean;
38+
import org.springframework.core.io.Resource;
39+
import reactor.core.publisher.Flux;
40+
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.stream.Collectors;
44+
45+
import static org.assertj.core.api.Assertions.assertThat;
46+
47+
@SpringBootTest(classes = VertexAIGeminiWithOpenAiChatModelIT.Config.class)
48+
@EnabledIfEnvironmentVariable(named = "GEMINI_API_KEY", matches = ".+")
49+
class VertexAIGeminiWithOpenAiChatModelIT {
50+
51+
private static final String VERTEX_AI_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com";
52+
53+
private static final String VERTEX_AI_GEMINI_DEFAULT_MODEL = "gemini-2.0-flash";
54+
55+
@Value("classpath:/prompts/system-message.st")
56+
private Resource systemResource;
57+
58+
@Autowired
59+
private OpenAiChatModel chatModel;
60+
61+
@Test
62+
void roleTest() {
63+
UserMessage userMessage = new UserMessage(
64+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
65+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
66+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
67+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
68+
ChatResponse response = this.chatModel.call(prompt);
69+
assertThat(response.getResults()).hasSize(1);
70+
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
71+
}
72+
73+
@Test
74+
void streamRoleTest() {
75+
UserMessage userMessage = new UserMessage(
76+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
77+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
78+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
79+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
80+
Flux<ChatResponse> flux = this.chatModel.stream(prompt);
81+
82+
List<ChatResponse> responses = flux.collectList().block();
83+
assertThat(responses.size()).isGreaterThan(1);
84+
85+
String stitchedResponseContent = responses.stream()
86+
.map(ChatResponse::getResults)
87+
.flatMap(List::stream)
88+
.map(Generation::getOutput)
89+
.map(AssistantMessage::getText)
90+
.collect(Collectors.joining());
91+
92+
assertThat(stitchedResponseContent).contains("Blackbeard");
93+
}
94+
95+
@SpringBootConfiguration
96+
static class Config {
97+
98+
@Bean
99+
public OpenAiApi chatCompletionApi() {
100+
return OpenAiApi.builder()
101+
.baseUrl(VERTEX_AI_GEMINI_BASE_URL)
102+
.completionsPath("/v1beta/openai/chat/completions")
103+
.apiKey(new SimpleApiKey(System.getenv("GEMINI_API_KEY")))
104+
.build();
105+
}
106+
107+
@Bean
108+
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
109+
return OpenAiChatModel.builder()
110+
.openAiApi(openAiApi)
111+
.toolCallingManager(ToolCallingManager.builder().build())
112+
.defaultOptions(OpenAiChatOptions.builder().model(VERTEX_AI_GEMINI_DEFAULT_MODEL).build())
113+
.build();
114+
}
115+
116+
}
117+
118+
}

0 commit comments

Comments
 (0)