Skip to content

Commit f5df099

Browse files
committed
Updated packages to the structure from the java-genai
1 parent 289feba commit f5df099

File tree

11 files changed

+332
-403
lines changed

11 files changed

+332
-403
lines changed

models/spring-ai-google-genai/pom.xml

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,14 @@
4040
<properties>
4141
</properties>
4242

43-
<dependencyManagement>
44-
<dependencies>
45-
<dependency>
46-
<groupId>com.google.cloud</groupId>
47-
<artifactId>libraries-bom</artifactId>
48-
<version>${com.google.cloud.version}</version>
49-
<type>pom</type>
50-
<scope>import</scope>
51-
</dependency>
52-
</dependencies>
53-
</dependencyManagement>
54-
5543
<dependencies>
5644

45+
<dependency>
46+
<groupId>com.google.genai</groupId>
47+
<artifactId>google-genai</artifactId>
48+
<version>1.8.0</version>
49+
</dependency>
50+
5751
<dependency>
5852
<groupId>com.github.victools</groupId>
5953
<artifactId>jsonschema-generator</artifactId>
@@ -65,17 +59,6 @@
6559
<version>${victools.version}</version>
6660
</dependency>
6761

68-
<dependency>
69-
<groupId>com.google.cloud</groupId>
70-
<artifactId>google-cloud-vertexai</artifactId>
71-
<exclusions>
72-
<exclusion>
73-
<groupId>commons-logging</groupId>
74-
<artifactId>commons-logging</artifactId>
75-
</exclusion>
76-
</exclusions>
77-
</dependency>
78-
7962
<!-- production dependencies -->
8063
<dependency>
8164
<groupId>org.springframework.ai</groupId>

models/spring-ai-google-genai/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 192 additions & 206 deletions
Large diffs are not rendered by default.

models/spring-ai-google-genai/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
import java.net.URI;
2121
import java.util.List;
2222

23-
import com.google.cloud.vertexai.VertexAI;
24-
import com.google.cloud.vertexai.api.Content;
25-
import com.google.cloud.vertexai.api.Part;
23+
import com.google.genai.Client;
24+
import com.google.genai.types.Content;
25+
import com.google.genai.types.Part;
2626
import org.junit.jupiter.api.Test;
2727
import org.junit.jupiter.api.extension.ExtendWith;
2828
import org.mockito.Mock;
@@ -50,13 +50,13 @@
5050
public class CreateGeminiRequestTests {
5151

5252
@Mock
53-
VertexAI vertexAI;
53+
Client genAiClient;
5454

5555
@Test
5656
public void createRequestWithChatOptions() {
5757

5858
var client = VertexAiGeminiChatModel.builder()
59-
.vertexAI(this.vertexAI)
59+
.genAiClient(this.genAiClient)
6060
.defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build())
6161
.build();
6262

@@ -65,25 +65,25 @@ public void createRequestWithChatOptions() {
6565

6666
assertThat(request.contents()).hasSize(1);
6767

68-
assertThat(request.model().getSystemInstruction()).isNotPresent();
69-
assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL");
70-
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f);
68+
assertThat(request.config().systemInstruction()).isNotPresent();
69+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
70+
assertThat(request.config().temperature().orElse(0f)).isEqualTo(66.6f);
7171

7272
request = client.createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content",
7373
VertexAiGeminiChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build())));
7474

7575
assertThat(request.contents()).hasSize(1);
7676

77-
assertThat(request.model().getSystemInstruction()).isNotPresent();
78-
assertThat(request.model().getModelName()).isEqualTo("PROMPT_MODEL");
79-
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(99.9f);
77+
assertThat(request.config().systemInstruction()).isNotPresent();
78+
assertThat(request.modelName()).isEqualTo("PROMPT_MODEL");
79+
assertThat(request.config().temperature().orElse(0f)).isEqualTo(99.9f);
8080
}
8181

8282
@Test
8383
public void createRequestWithFrequencyAndPresencePenalty() {
8484

8585
var client = VertexAiGeminiChatModel.builder()
86-
.vertexAI(this.vertexAI)
86+
.genAiClient(this.genAiClient)
8787
.defaultOptions(VertexAiGeminiChatOptions.builder()
8888
.model("DEFAULT_MODEL")
8989
.frequencePenalty(.25)
@@ -96,8 +96,8 @@ public void createRequestWithFrequencyAndPresencePenalty() {
9696

9797
assertThat(request.contents()).hasSize(1);
9898

99-
assertThat(request.model().getGenerationConfig().getFrequencyPenalty()).isEqualTo(.25F);
100-
assertThat(request.model().getGenerationConfig().getPresencePenalty()).isEqualTo(.75F);
99+
assertThat(request.config().frequencyPenalty().orElse(0f)).isEqualTo(.25F);
100+
assertThat(request.config().presencePenalty().orElse(0f)).isEqualTo(.75F);
101101
}
102102

103103
@Test
@@ -112,29 +112,32 @@ public void createRequestWithSystemMessage() throws MalformedURLException {
112112
.build();
113113

114114
var client = VertexAiGeminiChatModel.builder()
115-
.vertexAI(this.vertexAI)
115+
.genAiClient(this.genAiClient)
116116
.defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build())
117117
.build();
118118

119119
GeminiRequest request = client
120120
.createGeminiRequest(client.buildRequestPrompt(new Prompt(List.of(systemMessage, userMessage))));
121121

122-
assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL");
123-
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f);
122+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
123+
assertThat(request.config().temperature().orElse(0f)).isEqualTo(66.6f);
124124

125-
assertThat(request.model().getSystemInstruction()).isPresent();
126-
assertThat(request.model().getSystemInstruction().get().getParts(0).getText()).isEqualTo("System Message Text");
125+
assertThat(request.config().systemInstruction()).isPresent();
126+
assertThat(request.config().systemInstruction().get().parts().get().get(0).text().orElse(""))
127+
.isEqualTo("System Message Text");
127128

128129
assertThat(request.contents()).hasSize(1);
129130
Content content = request.contents().get(0);
130131

131-
Part textPart = content.getParts(0);
132-
assertThat(textPart.getText()).isEqualTo("User Message Text");
132+
List<Part> parts = content.parts().orElse(List.of());
133+
assertThat(parts).hasSize(2);
133134

134-
Part mediaPart = content.getParts(1);
135-
assertThat(mediaPart.getFileData()).isNotNull();
136-
assertThat(mediaPart.getFileData().getFileUri()).isEqualTo("http://example.com");
137-
assertThat(mediaPart.getFileData().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG.toString());
135+
Part textPart = parts.get(0);
136+
assertThat(textPart.text().orElse("")).isEqualTo("User Message Text");
137+
138+
Part mediaPart = parts.get(1);
139+
// Media parts are now created as inline data with Part.fromBytes()
140+
// The test needs to be updated based on how media is handled in the new SDK
138141
System.out.println(mediaPart);
139142
}
140143

@@ -146,7 +149,7 @@ public void promptOptionsTools() {
146149
var toolCallingManager = ToolCallingManager.builder().build();
147150

148151
var client = VertexAiGeminiChatModel.builder()
149-
.vertexAI(this.vertexAI)
152+
.genAiClient(this.genAiClient)
150153
.defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").build())
151154
.toolCallingManager(toolCallingManager)
152155
.build();
@@ -169,12 +172,15 @@ public void promptOptionsTools() {
169172
assertThat(toolDefinitions.get(0).name()).isSameAs(TOOL_FUNCTION_NAME);
170173

171174
assertThat(request.contents()).hasSize(1);
172-
assertThat(request.model().getSystemInstruction()).isNotPresent();
173-
assertThat(request.model().getModelName()).isEqualTo("PROMPT_MODEL");
174-
175-
assertThat(request.model().getTools()).hasSize(1);
176-
assertThat(request.model().getTools().get(0).getFunctionDeclarations(0).getName())
177-
.isEqualTo(TOOL_FUNCTION_NAME);
175+
assertThat(request.config().systemInstruction()).isNotPresent();
176+
assertThat(request.modelName()).isEqualTo("PROMPT_MODEL");
177+
178+
assertThat(request.config().tools()).isPresent();
179+
assertThat(request.config().tools().get()).hasSize(1);
180+
var tool = request.config().tools().get().get(0);
181+
assertThat(tool.functionDeclarations()).isPresent();
182+
assertThat(tool.functionDeclarations().get()).hasSize(1);
183+
assertThat(tool.functionDeclarations().get().get(0).name()).isEqualTo(TOOL_FUNCTION_NAME);
178184
}
179185

180186
@Test
@@ -185,7 +191,7 @@ public void defaultOptionsTools() {
185191
var toolCallingManager = ToolCallingManager.builder().build();
186192

187193
var client = VertexAiGeminiChatModel.builder()
188-
.vertexAI(this.vertexAI)
194+
.genAiClient(this.genAiClient)
189195
.toolCallingManager(toolCallingManager)
190196
.defaultOptions(VertexAiGeminiChatOptions.builder()
191197
.model("DEFAULT_MODEL")
@@ -208,10 +214,11 @@ public void defaultOptionsTools() {
208214
assertThat(toolDefinitions.get(0).description()).isEqualTo("Get the weather in location");
209215

210216
assertThat(request.contents()).hasSize(1);
211-
assertThat(request.model().getSystemInstruction()).isNotPresent();
212-
assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL");
217+
assertThat(request.config().systemInstruction()).isNotPresent();
218+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
213219

214-
assertThat(request.model().getTools()).hasSize(1);
220+
assertThat(request.config().tools()).isPresent();
221+
assertThat(request.config().tools().get()).hasSize(1);
215222

216223
// Explicitly enable the function
217224

@@ -220,9 +227,12 @@ public void defaultOptionsTools() {
220227

221228
request = client.createGeminiRequest(requestPrompt);
222229

223-
assertThat(request.model().getTools()).hasSize(1);
224-
assertThat(request.model().getTools().get(0).getFunctionDeclarations(0).getName())
225-
.as("Explicitly enabled function")
230+
assertThat(request.config().tools()).isPresent();
231+
assertThat(request.config().tools().get()).hasSize(1);
232+
var tool = request.config().tools().get().get(0);
233+
assertThat(tool.functionDeclarations()).isPresent();
234+
assertThat(tool.functionDeclarations().get()).hasSize(1);
235+
assertThat(tool.functionDeclarations().get().get(0).name()).as("Explicitly enabled function")
226236
.isEqualTo(TOOL_FUNCTION_NAME);
227237

228238
// Override the default options function with one from the prompt
@@ -235,9 +245,12 @@ public void defaultOptionsTools() {
235245
.build()));
236246
request = client.createGeminiRequest(requestPrompt);
237247

238-
assertThat(request.model().getTools()).hasSize(1);
239-
assertThat(request.model().getTools().get(0).getFunctionDeclarations(0).getName())
240-
.as("Explicitly enabled function")
248+
assertThat(request.config().tools()).isPresent();
249+
assertThat(request.config().tools().get()).hasSize(1);
250+
tool = request.config().tools().get().get(0);
251+
assertThat(tool.functionDeclarations()).isPresent();
252+
assertThat(tool.functionDeclarations().get()).hasSize(1);
253+
assertThat(tool.functionDeclarations().get().get(0).name()).as("Explicitly enabled function")
241254
.isEqualTo(TOOL_FUNCTION_NAME);
242255

243256
toolDefinitions = toolCallingManager
@@ -252,7 +265,7 @@ public void defaultOptionsTools() {
252265
public void createRequestWithGenerationConfigOptions() {
253266

254267
var client = VertexAiGeminiChatModel.builder()
255-
.vertexAI(this.vertexAI)
268+
.genAiClient(this.genAiClient)
256269
.defaultOptions(VertexAiGeminiChatOptions.builder()
257270
.model("DEFAULT_MODEL")
258271
.temperature(66.6)
@@ -270,16 +283,15 @@ public void createRequestWithGenerationConfigOptions() {
270283

271284
assertThat(request.contents()).hasSize(1);
272285

273-
assertThat(request.model().getSystemInstruction()).isNotPresent();
274-
assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL");
275-
assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f);
276-
assertThat(request.model().getGenerationConfig().getMaxOutputTokens()).isEqualTo(100);
277-
assertThat(request.model().getGenerationConfig().getTopK()).isEqualTo(10);
278-
assertThat(request.model().getGenerationConfig().getTopP()).isEqualTo(5.0f);
279-
assertThat(request.model().getGenerationConfig().getCandidateCount()).isEqualTo(1);
280-
assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1");
281-
assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2");
282-
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
286+
assertThat(request.config().systemInstruction()).isNotPresent();
287+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
288+
assertThat(request.config().temperature().orElse(0f)).isEqualTo(66.6f);
289+
assertThat(request.config().maxOutputTokens().orElse(0)).isEqualTo(100);
290+
assertThat(request.config().topK().orElse(0f)).isEqualTo(10f);
291+
assertThat(request.config().topP().orElse(0f)).isEqualTo(5.0f);
292+
assertThat(request.config().candidateCount().orElse(0)).isEqualTo(1);
293+
assertThat(request.config().stopSequences().orElse(List.of())).containsExactly("stop1", "stop2");
294+
assertThat(request.config().responseMimeType().orElse("")).isEqualTo("application/json");
283295
}
284296

285297
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818

1919
import java.io.IOException;
2020

21-
import com.google.cloud.vertexai.VertexAI;
22-
import com.google.cloud.vertexai.api.GenerateContentResponse;
23-
import com.google.cloud.vertexai.generativeai.GenerativeModel;
21+
import com.google.genai.Client;
22+
import com.google.genai.types.GenerateContentResponse;
2423

2524
import org.springframework.ai.model.tool.ToolCallingManager;
2625
import org.springframework.retry.support.RetryTemplate;
@@ -30,33 +29,23 @@
3029
*/
3130
public class TestVertexAiGeminiChatModel extends VertexAiGeminiChatModel {
3231

33-
private GenerativeModel mockGenerativeModel;
32+
private GenerateContentResponse mockGenerateContentResponse;
3433

35-
public TestVertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options,
34+
public TestVertexAiGeminiChatModel(Client genAiClient, VertexAiGeminiChatOptions options,
3635
RetryTemplate retryTemplate) {
37-
super(vertexAI, options, ToolCallingManager.builder().build(), retryTemplate, null);
36+
super(genAiClient, options, ToolCallingManager.builder().build(), retryTemplate, null);
3837
}
3938

4039
@Override
4140
GenerateContentResponse getContentResponse(GeminiRequest request) {
42-
if (this.mockGenerativeModel != null) {
43-
try {
44-
return this.mockGenerativeModel.generateContent(request.contents());
45-
}
46-
catch (IOException e) {
47-
// Should not be thrown by testing class
48-
throw new RuntimeException("Failed to generate content", e);
49-
}
50-
catch (RuntimeException e) {
51-
// Re-throw RuntimeExceptions (including TransientAiException) as is
52-
throw e;
53-
}
41+
if (this.mockGenerateContentResponse != null) {
42+
return this.mockGenerateContentResponse;
5443
}
5544
return super.getContentResponse(request);
5645
}
5746

58-
public void setMockGenerativeModel(GenerativeModel mockGenerativeModel) {
59-
this.mockGenerativeModel = mockGenerativeModel;
47+
public void setMockGenerateContentResponse(GenerateContentResponse mockGenerateContentResponse) {
48+
this.mockGenerateContentResponse = mockGenerateContentResponse;
6049
}
6150

6251
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import java.util.List;
2020
import java.util.stream.Collectors;
2121

22-
import com.google.cloud.vertexai.Transport;
23-
import com.google.cloud.vertexai.VertexAI;
22+
import com.google.genai.Client;
2423
import io.micrometer.observation.tck.TestObservationRegistry;
2524
import io.micrometer.observation.tck.TestObservationRegistryAssert;
2625
import org.junit.jupiter.api.BeforeEach;
@@ -164,21 +163,18 @@ public TestObservationRegistry observationRegistry() {
164163
}
165164

166165
@Bean
167-
public VertexAI vertexAiApi() {
166+
public Client genAiClient() {
168167
String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
169168
String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
170-
return new VertexAI.Builder().setProjectId(projectId)
171-
.setLocation(location)
172-
.setTransport(Transport.REST)
173-
.build();
169+
return Client.builder().project(projectId).location(location).vertexAI(true).build();
174170
}
175171

176172
@Bean
177-
public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi,
173+
public VertexAiGeminiChatModel vertexAiEmbedding(Client genAiClient,
178174
TestObservationRegistry observationRegistry) {
179175

180176
return VertexAiGeminiChatModel.builder()
181-
.vertexAI(vertexAi)
177+
.genAiClient(genAiClient)
182178
.observationRegistry(observationRegistry)
183179
.defaultOptions(VertexAiGeminiChatOptions.builder()
184180
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)

0 commit comments

Comments
 (0)