Skip to content

Commit 1c5f73c

Browse files
committed
Added ThinkingBudget support
Signed-off-by: ddobrin <[email protected]>
1 parent 5a4b438 commit 1c5f73c

File tree

4 files changed

+200
-2
lines changed

4 files changed

+200
-2
lines changed

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import com.google.genai.types.Part;
3939
import com.google.genai.types.SafetySetting;
4040
import com.google.genai.types.Schema;
41+
import com.google.genai.types.ThinkingConfig;
4142
import com.google.genai.types.Tool;
4243
import com.google.genai.types.FinishReason;
4344
import io.micrometer.observation.Observation;
@@ -672,6 +673,13 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
672673
if (requestOptions.getPresencePenalty() != null) {
673674
configBuilder.presencePenalty(requestOptions.getPresencePenalty().floatValue());
674675
}
676+
if (requestOptions.getThinkingBudget() != null) {
677+
configBuilder.thinkingConfig(
678+
ThinkingConfig.builder()
679+
.thinkingBudget(requestOptions.getThinkingBudget())
680+
.build()
681+
);
682+
}
675683

676684
// Add safety settings
677685
if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) {

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions {
107107
*/
108108
private @JsonProperty("presencePenalty") Double presencePenalty;
109109

110+
/**
111+
* Optional. Thinking budget for the thinking process.
112+
* This is part of the thinkingConfig in GenerationConfig.
113+
*/
114+
private @JsonProperty("thinkingBudget") Integer thinkingBudget;
115+
110116
/**
111117
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
112118
* completion requests.
@@ -163,6 +169,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti
163169
options.setSafetySettings(fromOptions.getSafetySettings());
164170
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
165171
options.setToolContext(fromOptions.getToolContext());
172+
options.setThinkingBudget(fromOptions.getThinkingBudget());
166173
return options;
167174
}
168175

@@ -300,6 +307,14 @@ public void setPresencePenalty(Double presencePenalty) {
300307
this.presencePenalty = presencePenalty;
301308
}
302309

310+
public Integer getThinkingBudget() {
311+
return this.thinkingBudget;
312+
}
313+
314+
public void setThinkingBudget(Integer thinkingBudget) {
315+
this.thinkingBudget = thinkingBudget;
316+
}
317+
303318
public Boolean getGoogleSearchRetrieval() {
304319
return this.googleSearchRetrieval;
305320
}
@@ -341,6 +356,7 @@ public boolean equals(Object o) {
341356
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
342357
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
343358
&& Objects.equals(this.presencePenalty, that.presencePenalty)
359+
&& Objects.equals(this.thinkingBudget, that.thinkingBudget)
344360
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
345361
&& Objects.equals(this.responseMimeType, that.responseMimeType)
346362
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
@@ -353,7 +369,7 @@ public boolean equals(Object o) {
353369
@Override
354370
public int hashCode() {
355371
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
356-
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
372+
this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model, this.responseMimeType,
357373
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
358374
this.internalToolExecutionEnabled, this.toolContext);
359375
}
@@ -362,7 +378,7 @@ public int hashCode() {
362378
public String toString() {
363379
return "GoogleGenAiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" + this.temperature
364380
+ ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty=" + this.frequencyPenalty
365-
+ ", presencePenalty=" + this.presencePenalty + ", candidateCount=" + this.candidateCount
381+
+ ", presencePenalty=" + this.presencePenalty + ", thinkingBudget=" + this.thinkingBudget + ", candidateCount=" + this.candidateCount
366382
+ ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='"
367383
+ this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames="
368384
+ this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval + ", safetySettings="
@@ -489,6 +505,11 @@ public Builder toolContext(Map<String, Object> toolContext) {
489505
return this;
490506
}
491507

508+
public Builder thinkingBudget(Integer thinkingBudget) {
509+
this.options.setThinkingBudget(thinkingBudget);
510+
return this;
511+
}
512+
492513
public GoogleGenAiChatOptions build() {
493514
return this.options;
494515
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.google.genai;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
23+
/**
24+
* Test for GoogleGenAiChatOptions
25+
*
26+
* @author Dan Dobrin
27+
*/
28+
public class GoogleGenAiChatOptionsTest {
29+
30+
@Test
31+
public void testThinkingBudgetGetterSetter() {
32+
GoogleGenAiChatOptions options = new GoogleGenAiChatOptions();
33+
34+
assertThat(options.getThinkingBudget()).isNull();
35+
36+
options.setThinkingBudget(12853);
37+
assertThat(options.getThinkingBudget()).isEqualTo(12853);
38+
39+
options.setThinkingBudget(null);
40+
assertThat(options.getThinkingBudget()).isNull();
41+
}
42+
43+
@Test
44+
public void testThinkingBudgetWithBuilder() {
45+
GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder()
46+
.model("test-model")
47+
.thinkingBudget(15000)
48+
.build();
49+
50+
assertThat(options.getModel()).isEqualTo("test-model");
51+
assertThat(options.getThinkingBudget()).isEqualTo(15000);
52+
}
53+
54+
@Test
55+
public void testFromOptionsWithThinkingBudget() {
56+
GoogleGenAiChatOptions original = GoogleGenAiChatOptions.builder()
57+
.model("test-model")
58+
.temperature(0.8)
59+
.thinkingBudget(20000)
60+
.build();
61+
62+
GoogleGenAiChatOptions copy = GoogleGenAiChatOptions.fromOptions(original);
63+
64+
assertThat(copy.getModel()).isEqualTo("test-model");
65+
assertThat(copy.getTemperature()).isEqualTo(0.8);
66+
assertThat(copy.getThinkingBudget()).isEqualTo(20000);
67+
assertThat(copy).isNotSameAs(original);
68+
}
69+
70+
@Test
71+
public void testCopyWithThinkingBudget() {
72+
GoogleGenAiChatOptions original = GoogleGenAiChatOptions.builder()
73+
.model("test-model")
74+
.thinkingBudget(30000)
75+
.build();
76+
77+
GoogleGenAiChatOptions copy = original.copy();
78+
79+
assertThat(copy.getModel()).isEqualTo("test-model");
80+
assertThat(copy.getThinkingBudget()).isEqualTo(30000);
81+
assertThat(copy).isNotSameAs(original);
82+
}
83+
84+
@Test
85+
public void testEqualsAndHashCodeWithThinkingBudget() {
86+
GoogleGenAiChatOptions options1 = GoogleGenAiChatOptions.builder()
87+
.model("test-model")
88+
.thinkingBudget(12853)
89+
.build();
90+
91+
GoogleGenAiChatOptions options2 = GoogleGenAiChatOptions.builder()
92+
.model("test-model")
93+
.thinkingBudget(12853)
94+
.build();
95+
96+
GoogleGenAiChatOptions options3 = GoogleGenAiChatOptions.builder()
97+
.model("test-model")
98+
.thinkingBudget(25000)
99+
.build();
100+
101+
assertThat(options1).isEqualTo(options2);
102+
assertThat(options1.hashCode()).isEqualTo(options2.hashCode());
103+
assertThat(options1).isNotEqualTo(options3);
104+
assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode());
105+
}
106+
107+
@Test
108+
public void testToStringWithThinkingBudget() {
109+
GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder()
110+
.model("test-model")
111+
.thinkingBudget(12853)
112+
.build();
113+
114+
String toString = options.toString();
115+
assertThat(toString).contains("thinkingBudget=12853");
116+
assertThat(toString).contains("test-model");
117+
}
118+
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/gemini/CreateGeminiRequestTests.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import org.springframework.ai.chat.messages.SystemMessage;
3232
import org.springframework.ai.chat.messages.UserMessage;
33+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
34+
import org.springframework.ai.chat.model.ChatResponse;
3335
import org.springframework.ai.chat.prompt.Prompt;
3436
import org.springframework.ai.content.Media;
3537
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -299,4 +301,53 @@ public void createRequestWithGenerationConfigOptions() {
299301
assertThat(request.config().responseMimeType().orElse("")).isEqualTo("application/json");
300302
}
301303

304+
@Test
305+
public void createRequestWithThinkingBudget() {
306+
307+
var client = GoogleGenAiChatModel.builder()
308+
.genAiClient(this.genAiClient)
309+
.defaultOptions(GoogleGenAiChatOptions.builder()
310+
.model("DEFAULT_MODEL")
311+
.thinkingBudget(12853)
312+
.build())
313+
.build();
314+
315+
GeminiRequest request = client
316+
.createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content")));
317+
318+
assertThat(request.contents()).hasSize(1);
319+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
320+
321+
// Verify thinkingConfig is present and contains thinkingBudget
322+
assertThat(request.config().thinkingConfig()).isPresent();
323+
assertThat(request.config().thinkingConfig().get().thinkingBudget()).isPresent();
324+
assertThat(request.config().thinkingConfig().get().thinkingBudget().get()).isEqualTo(12853);
325+
}
326+
327+
@Test
328+
public void createRequestWithThinkingBudgetOverride() {
329+
330+
var client = GoogleGenAiChatModel.builder()
331+
.genAiClient(this.genAiClient)
332+
.defaultOptions(GoogleGenAiChatOptions.builder()
333+
.model("DEFAULT_MODEL")
334+
.thinkingBudget(10000)
335+
.build())
336+
.build();
337+
338+
// Override default thinkingBudget with prompt-specific value
339+
GeminiRequest request = client.createGeminiRequest(client.buildRequestPrompt(
340+
new Prompt("Test message content",
341+
GoogleGenAiChatOptions.builder()
342+
.thinkingBudget(25000)
343+
.build())));
344+
345+
assertThat(request.contents()).hasSize(1);
346+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
347+
348+
// Verify prompt-specific thinkingBudget overrides default
349+
assertThat(request.config().thinkingConfig()).isPresent();
350+
assertThat(request.config().thinkingConfig().get().thinkingBudget()).isPresent();
351+
assertThat(request.config().thinkingConfig().get().thinkingBudget().get()).isEqualTo(25000);
352+
}
302353
}

0 commit comments

Comments
 (0)