Skip to content

Commit 237e5d3

Browse files
committed
Fixes #3849 Add Thinking Config to the Google Gen AI Module
Signed-off-by: ddobrin <[email protected]>
1 parent 5a1cafe commit 237e5d3

File tree

5 files changed

+296
-8
lines changed

5 files changed

+296
-8
lines changed

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import com.google.genai.types.Part;
3939
import com.google.genai.types.SafetySetting;
4040
import com.google.genai.types.Schema;
41+
import com.google.genai.types.ThinkingConfig;
4142
import com.google.genai.types.Tool;
4243
import com.google.genai.types.FinishReason;
4344
import io.micrometer.observation.Observation;
@@ -672,6 +673,10 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
672673
if (requestOptions.getPresencePenalty() != null) {
673674
configBuilder.presencePenalty(requestOptions.getPresencePenalty().floatValue());
674675
}
676+
if (requestOptions.getThinkingBudget() != null) {
677+
configBuilder
678+
.thinkingConfig(ThinkingConfig.builder().thinkingBudget(requestOptions.getThinkingBudget()).build());
679+
}
675680

676681
// Add safety settings
677682
if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) {

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions {
107107
*/
108108
private @JsonProperty("presencePenalty") Double presencePenalty;
109109

110+
/**
111+
* Optional. Thinking budget for the thinking process.
112+
* This is part of the thinkingConfig in GenerationConfig.
113+
*/
114+
private @JsonProperty("thinkingBudget") Integer thinkingBudget;
115+
110116
/**
111117
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
112118
* completion requests.
@@ -163,6 +169,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti
163169
options.setSafetySettings(fromOptions.getSafetySettings());
164170
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
165171
options.setToolContext(fromOptions.getToolContext());
172+
options.setThinkingBudget(fromOptions.getThinkingBudget());
166173
return options;
167174
}
168175

@@ -300,6 +307,14 @@ public void setPresencePenalty(Double presencePenalty) {
300307
this.presencePenalty = presencePenalty;
301308
}
302309

310+
public Integer getThinkingBudget() {
311+
return this.thinkingBudget;
312+
}
313+
314+
public void setThinkingBudget(Integer thinkingBudget) {
315+
this.thinkingBudget = thinkingBudget;
316+
}
317+
303318
public Boolean getGoogleSearchRetrieval() {
304319
return this.googleSearchRetrieval;
305320
}
@@ -341,6 +356,7 @@ public boolean equals(Object o) {
341356
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
342357
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
343358
&& Objects.equals(this.presencePenalty, that.presencePenalty)
359+
&& Objects.equals(this.thinkingBudget, that.thinkingBudget)
344360
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
345361
&& Objects.equals(this.responseMimeType, that.responseMimeType)
346362
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
@@ -353,20 +369,20 @@ public boolean equals(Object o) {
353369
@Override
354370
public int hashCode() {
355371
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
356-
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
357-
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
358-
this.internalToolExecutionEnabled, this.toolContext);
372+
this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model,
373+
this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
374+
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
359375
}
360376

361377
@Override
362378
public String toString() {
363379
return "GoogleGenAiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" + this.temperature
364380
+ ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty=" + this.frequencyPenalty
365-
+ ", presencePenalty=" + this.presencePenalty + ", candidateCount=" + this.candidateCount
366-
+ ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='"
367-
+ this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames="
368-
+ this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval + ", safetySettings="
369-
+ this.safetySettings + '}';
381+
+ ", presencePenalty=" + this.presencePenalty + ", thinkingBudget=" + this.thinkingBudget
382+
+ ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='"
383+
+ this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks="
384+
+ this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
385+
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}';
370386
}
371387

372388
@Override
@@ -489,6 +505,11 @@ public Builder toolContext(Map<String, Object> toolContext) {
489505
return this;
490506
}
491507

508+
public Builder thinkingBudget(Integer thinkingBudget) {
509+
this.options.setThinkingBudget(thinkingBudget);
510+
return this;
511+
}
512+
492513
public GoogleGenAiChatOptions build() {
493514
return this.options;
494515
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/CreateGeminiRequestTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import org.springframework.ai.chat.messages.SystemMessage;
3232
import org.springframework.ai.chat.messages.UserMessage;
33+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
34+
import org.springframework.ai.chat.model.ChatResponse;
3335
import org.springframework.ai.chat.prompt.Prompt;
3436
import org.springframework.ai.content.Media;
3537
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -299,4 +301,45 @@ public void createRequestWithGenerationConfigOptions() {
299301
assertThat(request.config().responseMimeType().orElse("")).isEqualTo("application/json");
300302
}
301303

304+
@Test
305+
public void createRequestWithThinkingBudget() {
306+
307+
var client = GoogleGenAiChatModel.builder()
308+
.genAiClient(this.genAiClient)
309+
.defaultOptions(GoogleGenAiChatOptions.builder().model("DEFAULT_MODEL").thinkingBudget(12853).build())
310+
.build();
311+
312+
GeminiRequest request = client
313+
.createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content")));
314+
315+
assertThat(request.contents()).hasSize(1);
316+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
317+
318+
// Verify thinkingConfig is present and contains thinkingBudget
319+
assertThat(request.config().thinkingConfig()).isPresent();
320+
assertThat(request.config().thinkingConfig().get().thinkingBudget()).isPresent();
321+
assertThat(request.config().thinkingConfig().get().thinkingBudget().get()).isEqualTo(12853);
322+
}
323+
324+
@Test
325+
public void createRequestWithThinkingBudgetOverride() {
326+
327+
var client = GoogleGenAiChatModel.builder()
328+
.genAiClient(this.genAiClient)
329+
.defaultOptions(GoogleGenAiChatOptions.builder().model("DEFAULT_MODEL").thinkingBudget(10000).build())
330+
.build();
331+
332+
// Override default thinkingBudget with prompt-specific value
333+
GeminiRequest request = client.createGeminiRequest(client.buildRequestPrompt(
334+
new Prompt("Test message content", GoogleGenAiChatOptions.builder().thinkingBudget(25000).build())));
335+
336+
assertThat(request.contents()).hasSize(1);
337+
assertThat(request.modelName()).isEqualTo("DEFAULT_MODEL");
338+
339+
// Verify prompt-specific thinkingBudget overrides default
340+
assertThat(request.config().thinkingConfig()).isPresent();
341+
assertThat(request.config().thinkingConfig().get().thinkingBudget()).isPresent();
342+
assertThat(request.config().thinkingConfig().get().thinkingBudget().get()).isEqualTo(25000);
343+
}
344+
302345
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import org.junit.jupiter.api.Test;
3030
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3131

32+
import org.slf4j.Logger;
33+
import org.slf4j.LoggerFactory;
3234
import org.springframework.ai.chat.client.ChatClient;
3335
import org.springframework.ai.chat.messages.AssistantMessage;
3436
import org.springframework.ai.chat.messages.Message;
@@ -65,6 +67,8 @@
6567
@EnabledIfEnvironmentVariable(named = "GOOGLE_CLOUD_LOCATION", matches = ".*")
6668
class GoogleGenAiChatModelIT {
6769

70+
private static final Logger logger = LoggerFactory.getLogger(GoogleGenAiChatModelIT.class);
71+
6872
@Autowired
6973
private GoogleGenAiChatModel chatModel;
7074

@@ -384,6 +388,102 @@ void jsonTextToolCallingTest() {
384388
assertThat(response).contains("2025-05-08T10:10:10+02:00");
385389
}
386390

391+
@Test
392+
void testThinkingBudgetGeminiProAutomaticDecisionByModel() {
393+
GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder()
394+
.genAiClient(genAiClient())
395+
.defaultOptions(GoogleGenAiChatOptions.builder().model(ChatModel.GEMINI_2_5_PRO).temperature(0.1).build())
396+
.build();
397+
398+
ChatClient chatClient = ChatClient.builder(chatModelWithThinkingBudget).build();
399+
400+
// Create a prompt that will trigger the tool call with a specific request that
401+
// should invoke the tool
402+
long start = System.currentTimeMillis();
403+
String response = chatClient.prompt()
404+
.user("Explain to me briefly how I can start a SpringAI project")
405+
.call()
406+
.content();
407+
408+
assertThat(response).isNotEmpty();
409+
logger.info("Response: {} in {} ms", response, System.currentTimeMillis() - start);
410+
}
411+
412+
@Test
413+
void testThinkingBudgetGeminiProMinBudget() {
414+
GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder()
415+
.genAiClient(genAiClient())
416+
.defaultOptions(GoogleGenAiChatOptions.builder()
417+
.model(ChatModel.GEMINI_2_5_PRO)
418+
.temperature(0.1)
419+
.thinkingBudget(128)
420+
.build())
421+
.build();
422+
423+
ChatClient chatClient = ChatClient.builder(chatModelWithThinkingBudget).build();
424+
425+
// Create a prompt that will trigger the tool call with a specific request that
426+
// should invoke the tool
427+
long start = System.currentTimeMillis();
428+
String response = chatClient.prompt()
429+
.user("Explain to me briefly how I can start a SpringAI project")
430+
.call()
431+
.content();
432+
433+
assertThat(response).isNotEmpty();
434+
logger.info("Response: {} in {} ms", response, System.currentTimeMillis() - start);
435+
}
436+
437+
@Test
438+
void testThinkingBudgetGeminiFlashDefaultBudget() {
439+
GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder()
440+
.genAiClient(genAiClient())
441+
.defaultOptions(GoogleGenAiChatOptions.builder()
442+
.model(ChatModel.GEMINI_2_5_FLASH)
443+
.temperature(0.1)
444+
.thinkingBudget(8192)
445+
.build())
446+
.build();
447+
448+
ChatClient chatClient = ChatClient.builder(chatModelWithThinkingBudget).build();
449+
450+
// Create a prompt that will trigger the tool call with a specific request that
451+
// should invoke the tool
452+
long start = System.currentTimeMillis();
453+
String response = chatClient.prompt()
454+
.user("Explain to me briefly how I can start a SpringAI project")
455+
.call()
456+
.content();
457+
458+
assertThat(response).isNotEmpty();
459+
logger.info("Response: {} in {} ms", response, System.currentTimeMillis() - start);
460+
}
461+
462+
@Test
463+
void testThinkingBudgetGeminiFlashThinkingTurnedOff() {
464+
GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder()
465+
.genAiClient(genAiClient())
466+
.defaultOptions(GoogleGenAiChatOptions.builder()
467+
.model(ChatModel.GEMINI_2_5_FLASH)
468+
.temperature(0.1)
469+
.thinkingBudget(0)
470+
.build())
471+
.build();
472+
473+
ChatClient chatClient = ChatClient.builder(chatModelWithThinkingBudget).build();
474+
475+
// Create a prompt that will trigger the tool call with a specific request that
476+
// should invoke the tool
477+
long start = System.currentTimeMillis();
478+
String response = chatClient.prompt()
479+
.user("Explain to me briefly how I can start a SpringAI project")
480+
.call()
481+
.content();
482+
483+
assertThat(response).isNotEmpty();
484+
logger.info("Response: {} in {} ms", response, System.currentTimeMillis() - start);
485+
}
486+
387487
/**
388488
* Tool class that returns a JSON array to test the jsonToStruct method's ability to
389489
* handle JSON arrays. This specifically tests the PR changes that improve the

0 commit comments

Comments
 (0)