Skip to content

Commit 552a346

Browse files
committed
feat: Add support for the think flag at the ChatMode level.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent b188adc commit 552a346

File tree

9 files changed

+66
-3
lines changed

9 files changed

+66
-3
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ else if (message instanceof ToolResponseMessage toolMessage) {
462462
.stream(stream)
463463
.messages(ollamaMessages)
464464
.options(requestOptions)
465-
.think(requestOptions.getThink());
465+
.think(requestOptions.isThink());
466466

467467
if (requestOptions.getFormat() != null) {
468468
requestBuilder.format(requestOptions.getFormat());

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
374374
.format(fromOptions.getFormat())
375375
.keepAlive(fromOptions.getKeepAlive())
376376
.truncate(fromOptions.getTruncate())
377-
.think(fromOptions.getThink())
377+
.think(fromOptions.isThink())
378378
.useNUMA(fromOptions.getUseNUMA())
379379
.numCtx(fromOptions.getNumCtx())
380380
.numBatch(fromOptions.getNumBatch())
@@ -714,7 +714,8 @@ public void setTruncate(Boolean truncate) {
714714
this.truncate = truncate;
715715
}
716716

717-
public Boolean getThink() {
717+
@Override
718+
public Boolean isThink() {
718719
return this.think;
719720
}
720721

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions {
8383
@Nullable
8484
Double getTopP();
8585

86+
/**
87+
* Returns the think flag to use for the chat.
88+
* @return the think flag to use for the chat
89+
*/
90+
@Nullable
91+
default Boolean isThink() {
92+
return false;
93+
}
94+
8695
/**
8796
* Returns a copy of this {@link ChatOptions}.
8897
* @return a copy of this {@link ChatOptions}
@@ -158,6 +167,13 @@ interface Builder {
158167
*/
159168
Builder topP(Double topP);
160169

170+
/**
171+
* Builds with the think to use for the chat.
172+
* @param think Whether to enable thinking mode
173+
* @return the builder.
174+
*/
175+
Builder think(Boolean think);
176+
161177
/**
162178
* Build the {@link ChatOptions}.
163179
* @return the Chat options.

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ public class DefaultChatOptions implements ChatOptions {
4141

4242
private Double topP;
4343

44+
private Boolean think;
45+
4446
@Override
4547
public String getModel() {
4648
return this.model;
@@ -113,6 +115,15 @@ public void setTopP(Double topP) {
113115
this.topP = topP;
114116
}
115117

118+
@Override
119+
public Boolean isThink() {
120+
return this.think;
121+
}
122+
123+
public void setThink(Boolean think) {
124+
this.think = think;
125+
}
126+
116127
@Override
117128
@SuppressWarnings("unchecked")
118129
public <T extends ChatOptions> T copy() {
@@ -125,6 +136,7 @@ public <T extends ChatOptions> T copy() {
125136
copy.setTemperature(this.getTemperature());
126137
copy.setTopK(this.getTopK());
127138
copy.setTopP(this.getTopP());
139+
copy.setThink(this.isThink());
128140
return (T) copy;
129141
}
130142

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ public DefaultChatOptionsBuilder topP(Double topP) {
7373
return this;
7474
}
7575

76+
public DefaultChatOptionsBuilder think(Boolean think) {
77+
this.options.setThink(think);
78+
return this;
79+
}
80+
7681
public ChatOptions build() {
7782
return this.options.copy();
7883
}

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
7070
@Nullable
7171
private Double topP;
7272

73+
@Nullable
74+
private Boolean think;
75+
7376
@Override
7477
public List<ToolCallback> getToolCallbacks() {
7578
return List.copyOf(this.toolCallbacks);
@@ -198,6 +201,16 @@ public void setTopP(@Nullable Double topP) {
198201
this.topP = topP;
199202
}
200203

204+
@Override
205+
@Nullable
206+
public Boolean isThink() {
207+
return this.think;
208+
}
209+
210+
public void setThink(@Nullable Boolean think) {
211+
this.think = think;
212+
}
213+
201214
@Override
202215
@SuppressWarnings("unchecked")
203216
public <T extends ChatOptions> T copy() {
@@ -325,6 +338,12 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) {
325338
return this;
326339
}
327340

341+
@Override
342+
public ToolCallingChatOptions.Builder think(Boolean think) {
343+
this.options.setThink(think);
344+
return this;
345+
}
346+
328347
@Override
329348
public ToolCallingChatOptions build() {
330349
return this.options;

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ interface Builder extends ChatOptions.Builder {
219219
@Override
220220
Builder topP(@Nullable Double topP);
221221

222+
@Override
223+
Builder think(@Nullable Boolean think);
224+
222225
@Override
223226
ToolCallingChatOptions build();
224227

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ void shouldBuildWithAllOptions() {
5353
.topP(1.0)
5454
.topK(40)
5555
.stopSequences(List.of("stop1", "stop2"))
56+
.think(true)
5657
.build();
5758

5859
assertThat(options.getModel()).isEqualTo("gpt-4");
5960
assertThat(options.getMaxTokens()).isEqualTo(100);
6061
assertThat(options.getTemperature()).isEqualTo(0.7);
6162
assertThat(options.getTopP()).isEqualTo(1.0);
6263
assertThat(options.getTopK()).isEqualTo(40);
64+
assertThat(options.isThink()).isEqualTo(true);
6365
assertThat(options.getStopSequences()).containsExactly("stop1", "stop2");
6466
}
6567

@@ -82,6 +84,7 @@ void shouldCopyOptions() {
8284
.temperature(0.7)
8385
.topP(1.0)
8486
.topK(40)
87+
.think(true)
8588
.stopSequences(List.of("stop1", "stop2"))
8689
.build();
8790

@@ -107,6 +110,7 @@ void shouldUpcastToChatOptions() {
107110
.temperature(0.7)
108111
.topP(1.0)
109112
.topK(40)
113+
.think(true)
110114
.stopSequences(List.of("stop1", "stop2"))
111115
.toolNames(Set.of("function1", "function2"))
112116
.toolCallbacks(List.of(callback))
@@ -121,6 +125,7 @@ void shouldUpcastToChatOptions() {
121125
assertThat(chatOptions.getTemperature()).isEqualTo(0.7);
122126
assertThat(chatOptions.getTopP()).isEqualTo(1.0);
123127
assertThat(chatOptions.getTopK()).isEqualTo(40);
128+
assertThat(chatOptions.isThink()).isEqualTo(true);
124129
assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2");
125130
}
126131

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void builderShouldCreateOptionsWithAllProperties() {
188188
.stopSequences(List.of("stop"))
189189
.topK(3)
190190
.topP(0.9)
191+
.think(true)
191192
.build();
192193

193194
assertThat(options).satisfies(o -> {
@@ -203,6 +204,7 @@ void builderShouldCreateOptionsWithAllProperties() {
203204
assertThat(o.getStopSequences()).containsExactly("stop");
204205
assertThat(o.getTopK()).isEqualTo(3);
205206
assertThat(o.getTopP()).isEqualTo(0.9);
207+
assertThat(o.isThink()).isEqualTo(true);
206208
});
207209
}
208210

0 commit comments

Comments
 (0)