Skip to content

Commit 4beb1ec

Browse files
committed
Introduce thought configuration to quarkus-langchain4j-ai-gemini
Closes: #1694
1 parent 34b8f97 commit 4beb1ec

File tree

10 files changed

+102
-10
lines changed

10 files changed

+102
-10
lines changed

model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiChatLanguageModel.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class AiGeminiChatLanguageModel extends GeminiChatLanguageModel {
2525

2626
private AiGeminiChatLanguageModel(Builder builder) {
2727
super(builder.modelId, builder.temperature, builder.maxOutputTokens, builder.topK, builder.topP, builder.responseFormat,
28-
builder.listeners);
28+
builder.listeners, builder.thinkingBudget, builder.includeThoughts);
2929

3030
this.apiMetadata = AiGeminiRestApi.ApiMetadata
3131
.builder()
@@ -81,6 +81,8 @@ public static final class Builder {
8181
private Boolean logRequests = false;
8282
private Boolean logResponses = false;
8383
private List<ChatModelListener> listeners = Collections.emptyList();
84+
private Long thinkingBudget;
85+
private boolean includeThoughts = false;
8486

8587
public Builder configName(String configName) {
8688
this.configName = configName;
@@ -147,6 +149,16 @@ public Builder listeners(List<ChatModelListener> listeners) {
147149
return this;
148150
}
149151

152+
public Builder thinkingBudget(Long thinkingBudget) {
153+
this.thinkingBudget = thinkingBudget;
154+
return this;
155+
}
156+
157+
public Builder includeThoughts(Boolean includeThoughts) {
158+
this.includeThoughts = includeThoughts;
159+
return this;
160+
}
161+
150162
public AiGeminiChatLanguageModel build() {
151163
return new AiGeminiChatLanguageModel(this);
152164
}

model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRecorder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import dev.langchain4j.model.chat.listener.ChatModelListener;
1717
import dev.langchain4j.model.embedding.DisabledEmbeddingModel;
1818
import dev.langchain4j.model.embedding.EmbeddingModel;
19+
import io.quarkiverse.langchain4j.ai.runtime.gemini.config.ChatModelConfig;
1920
import io.quarkiverse.langchain4j.ai.runtime.gemini.config.LangChain4jAiGeminiConfig;
2021
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
2122
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
@@ -109,6 +110,13 @@ public Function<SyntheticCreationalContext<ChatModel>, ChatModel> chatModel(Stri
109110
if (chatModelConfig.timeout().isPresent()) {
110111
builder.timeout(chatModelConfig.timeout().get());
111112
}
113+
ChatModelConfig.ThinkingConfig thinkingConfig = chatModelConfig.thinking();
114+
if (thinkingConfig.includeThoughts()) {
115+
builder.includeThoughts(thinkingConfig.includeThoughts());
116+
}
117+
if (thinkingConfig.thinkingBudget().isPresent()) {
118+
builder.thinkingBudget(thinkingConfig.thinkingBudget().getAsLong());
119+
}
112120

113121
// TODO: add the rest of the properties
114122

model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/config/ChatModelConfig.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import java.util.Optional;
55
import java.util.OptionalDouble;
66
import java.util.OptionalInt;
7+
import java.util.OptionalLong;
78

9+
import io.quarkiverse.langchain4j.gemini.common.ThinkingConfig;
810
import io.quarkus.runtime.annotations.ConfigDocDefault;
911
import io.quarkus.runtime.annotations.ConfigGroup;
1012
import io.smallrye.config.WithDefault;
@@ -93,4 +95,33 @@ public interface ChatModelConfig {
9395
@WithDefault("${quarkus.langchain4j.ai.gemini.timeout}")
9496
Optional<Duration> timeout();
9597

98+
/**
99+
* Thought-related configuration.
100+
*/
101+
ThinkingConfig thinking();
102+
103+
interface ThinkingConfig {
104+
105+
/**
106+
* Controls whether thought summaries are enabled.
107+
* Thought summaries are synthesized versions of the model's raw thoughts and offer insights into the model's internal
108+
* reasoning process.
109+
*/
110+
@WithDefault("false")
111+
boolean includeThoughts();
112+
113+
/**
114+
* The thinkingBudget parameter guides the model on the number of thinking tokens to use when generating a response.
115+
* A higher token count generally allows for more detailed reasoning, which can be beneficial for tackling more complex
116+
* tasks.
117+
* If latency is more important, use a lower budget or disable thinking by setting thinkingBudget to 0.
118+
* Setting the thinkingBudget to -1 turns on dynamic thinking, meaning the model will adjust the budget based on the
119+
* complexity of the request.
120+
* <p>
121+
* The thinkingBudget is only supported in Gemini 2.5 Flash, 2.5 Pro, and 2.5 Flash-Lite. Depending on the prompt, the
122+
* model might overflow or underflow the token budget.
123+
* See <a href="https://ai.google.dev/gemini-api/docs/thinking#set-budget">Gemini API docs</a> for more details.
124+
*/
125+
OptionalLong thinkingBudget();
126+
}
96127
}

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/BaseGeminiChatModel.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@ public class BaseGeminiChatModel {
1717
protected final Double topP;
1818
protected final ResponseFormat responseFormat;
1919
protected final List<ChatModelListener> listeners;
20+
protected final Long thinkingBudget;
21+
protected final boolean includeThoughts;
2022

2123
public BaseGeminiChatModel(String modelId, Double temperature, Integer maxOutputTokens, Integer topK, Double topP,
22-
ResponseFormat responseFormat, List<ChatModelListener> listeners) {
24+
ResponseFormat responseFormat, List<ChatModelListener> listeners, Long thinkingBudget,
25+
boolean includeThoughts) {
2326
this.modelId = modelId;
2427
this.temperature = temperature;
2528
this.maxOutputTokens = maxOutputTokens;
2629
this.topK = topK;
2730
this.topP = topP;
2831
this.responseFormat = responseFormat;
2932
this.listeners = listeners;
33+
this.thinkingBudget = thinkingBudget;
34+
this.includeThoughts = includeThoughts;
3035
}
3136
}

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/GeminiChatLanguageModel.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
public abstract class GeminiChatLanguageModel extends BaseGeminiChatModel implements ChatModel {
3737

3838
public GeminiChatLanguageModel(String modelId, Double temperature, Integer maxOutputTokens, Integer topK,
39-
Double topP, ResponseFormat responseFormat, List<ChatModelListener> listeners) {
40-
super(modelId, temperature, maxOutputTokens, topK, topP, responseFormat, listeners);
39+
Double topP, ResponseFormat responseFormat, List<ChatModelListener> listeners, Long thinkingBudget,
40+
boolean includeThoughts) {
41+
super(modelId, temperature, maxOutputTokens, topK, topP, responseFormat, listeners, thinkingBudget, includeThoughts);
4142
}
4243

4344
@Override
@@ -56,7 +57,7 @@ public Set<Capability> supportedCapabilities() {
5657
public ChatResponse chat(ChatRequest chatRequest) {
5758
ChatRequestParameters requestParameters = chatRequest.parameters();
5859
ResponseFormat effectiveResponseFormat = getOrDefault(requestParameters.responseFormat(), responseFormat);
59-
GenerationConfig generationConfig = GenerationConfig.builder()
60+
GenerationConfig.Builder generationConfigBuilder = GenerationConfig.builder()
6061
.maxOutputTokens(getOrDefault(requestParameters.maxOutputTokens(), this.maxOutputTokens))
6162
.responseMimeType(computeMimeType(effectiveResponseFormat))
6263
.responseSchema(effectiveResponseFormat != null
@@ -65,8 +66,11 @@ public ChatResponse chat(ChatRequest chatRequest) {
6566
.stopSequences(requestParameters.stopSequences())
6667
.temperature(getOrDefault(requestParameters.temperature(), this.temperature))
6768
.topK(getOrDefault(requestParameters.topK(), this.topK))
68-
.topP(getOrDefault(requestParameters.topP(), this.topP))
69-
.build();
69+
.topP(getOrDefault(requestParameters.topP(), this.topP));
70+
if (includeThoughts) {
71+
generationConfigBuilder.thinkingConfig(new ThinkingConfig(thinkingBudget, includeThoughts));
72+
}
73+
GenerationConfig generationConfig = generationConfigBuilder.build();
7074
GenerateContentRequest request = ContentMapper.map(chatRequest.messages(), chatRequest.toolSpecifications(),
7175
generationConfig);
7276

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/GeminiStreamingChatLanguageModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public abstract class GeminiStreamingChatLanguageModel extends BaseGeminiChatMod
3636

3737
public GeminiStreamingChatLanguageModel(String modelId, Double temperature, Integer maxOutputTokens, Integer topK,
3838
Double topP, ResponseFormat responseFormat, List<ChatModelListener> listeners) {
39-
super(modelId, temperature, maxOutputTokens, topK, topP, responseFormat, listeners);
39+
super(modelId, temperature, maxOutputTokens, topK, topP, responseFormat, listeners, null, false);
4040
}
4141

4242
@Override

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/GenerateContentResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public record Content(List<Part> parts) {
1111

1212
}
1313

14-
public record Part(String text, FunctionCall functionCall) {
14+
public record Part(String text, FunctionCall functionCall, Boolean thought) {
1515

1616
}
1717

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/GenerationConfig.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public class GenerationConfig {
1111
private final String responseMimeType;
1212
private final Schema responseSchema;
1313
private final List<String> stopSequences;
14+
private final ThinkingConfig thinkingConfig;
1415

1516
public GenerationConfig(Builder builder) {
1617
this.temperature = builder.temperature;
@@ -20,6 +21,7 @@ public GenerationConfig(Builder builder) {
2021
this.responseMimeType = builder.responseMimeType;
2122
this.responseSchema = builder.responseSchema;
2223
this.stopSequences = builder.stopSequences;
24+
this.thinkingConfig = builder.thinkingConfig;
2325
}
2426

2527
public Double getTemperature() {
@@ -50,6 +52,10 @@ public List<String> getStopSequences() {
5052
return stopSequences;
5153
}
5254

55+
public ThinkingConfig getThinkingConfig() {
56+
return thinkingConfig;
57+
}
58+
5359
public static Builder builder() {
5460
return new Builder();
5561
}
@@ -63,6 +69,7 @@ public static final class Builder {
6369
private String responseMimeType;
6470
private Schema responseSchema;
6571
private List<String> stopSequences;
72+
private ThinkingConfig thinkingConfig;
6673

6774
public Builder temperature(Double temperature) {
6875
this.temperature = temperature;
@@ -99,6 +106,11 @@ public Builder stopSequences(List<String> stopSequences) {
99106
return this;
100107
}
101108

109+
public Builder thinkingConfig(ThinkingConfig thinkingConfig) {
110+
this.thinkingConfig = thinkingConfig;
111+
return this;
112+
}
113+
102114
public GenerationConfig build() {
103115
return new GenerationConfig(this);
104116
}
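
model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/ThinkingConfig.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change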
@@ -0,0 +1,20 @@
1+
package io.quarkiverse.langchain4j.gemini.common;
2+
3+
public class ThinkingConfig {
4+
5+
private final Long thinkingBudget;
6+
private final Boolean includeThoughts;
7+
8+
public ThinkingConfig(Long thinkingBudget, Boolean includeThoughts) {
9+
this.thinkingBudget = thinkingBudget;
10+
this.includeThoughts = includeThoughts;
11+
}
12+
13+
public Long getThinkingBudget() {
14+
return thinkingBudget;
15+
}
16+
17+
public Boolean getIncludeThoughts() {
18+
return includeThoughts;
19+
}
20+
}

model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class VertexAiGeminiChatLanguageModel extends GeminiChatLanguageModel {
2525

2626
private VertexAiGeminiChatLanguageModel(Builder builder) {
2727
super(builder.modelId, builder.temperature, builder.maxOutputTokens, builder.topK, builder.topP, builder.responseFormat,
28-
builder.listeners);
28+
builder.listeners, null, false);
2929

3030
this.apiMetadata = VertxAiGeminiRestApi.ApiMetadata
3131
.builder()

0 commit comments

Comments
 (0)