Skip to content

Commit 524f12c

Browse files
committed
feat: Added support for the "think" field in Ollama
This commit: (1) adds the `think` field to Ollama's `ChatRequest`; (2) adds the `thinking` field to Ollama's `Message`; (3) adds the `think` property to `OllamaOptions`, allowing users to specify whether thinking is enabled or disabled. Signed-off-by: Sun Yuhan <[email protected]>
1 parent a03f7d5 commit 524f12c

File tree

7 files changed

+114
-14
lines changed

7 files changed

+114
-14
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ else if (message instanceof ToolResponseMessage toolMessage) {
460460
OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel())
461461
.stream(stream)
462462
.messages(ollamaMessages)
463-
.options(requestOptions);
463+
.options(requestOptions)
464+
.think(requestOptions.getThink());
464465

465466
if (requestOptions.getFormat() != null) {
466467
requestBuilder.format(requestOptions.getFormat());

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* @author Christian Tzolov
5252
* @author Thomas Vitale
5353
* @author Jonghoon Park
54+
* @author Sun Yuhan
5455
* @since 0.8.0
5556
*/
5657
// @formatter:off
@@ -251,6 +252,7 @@ public Flux<ProgressResponse> pullModel(PullModelRequest pullModelRequest) {
251252
*
252253
* @param role The role of the message of type {@link Role}.
253254
* @param content The content of the message.
255+
* @param thinking The model's thinking (reasoning) output, if any.
254256
* @param images The list of base64-encoded images to send with the message.
255257
* Requires multimodal models such as llava or bakllava.
256258
* @param toolCalls The relevant tool call.
@@ -260,6 +262,7 @@ public Flux<ProgressResponse> pullModel(PullModelRequest pullModelRequest) {
260262
public record Message(
261263
@JsonProperty("role") Role role,
262264
@JsonProperty("content") String content,
265+
@JsonProperty("thinking") String thinking,
263266
@JsonProperty("images") List<String> images,
264267
@JsonProperty("tool_calls") List<ToolCall> toolCalls) {
265268

@@ -321,6 +324,7 @@ public static class Builder {
321324

322325
private final Role role;
323326
private String content;
327+
private String thinking;
324328
private List<String> images;
325329
private List<ToolCall> toolCalls;
326330

@@ -333,6 +337,11 @@ public Builder content(String content) {
333337
return this;
334338
}
335339

340+
public Builder thinking(String thinking) {
341+
this.thinking = thinking;
342+
return this;
343+
}
344+
336345
public Builder images(List<String> images) {
337346
this.images = images;
338347
return this;
@@ -344,7 +353,7 @@ public Builder toolCalls(List<ToolCall> toolCalls) {
344353
}
345354

346355
public Message build() {
347-
return new Message(this.role, this.content, this.images, this.toolCalls);
356+
return new Message(this.role, this.content, this.thinking, this.images, this.toolCalls);
348357
}
349358
}
350359
}
@@ -359,6 +368,7 @@ public Message build() {
359368
* @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m).
360369
* @param tools List of tools the model has access to.
361370
* @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it.
371+
* @param think Whether the model should think before responding, if the model supports it.
362372
* You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map.
363373
*
364374
* @see <a href=
@@ -375,7 +385,8 @@ public record ChatRequest(
375385
@JsonProperty("format") Object format,
376386
@JsonProperty("keep_alive") String keepAlive,
377387
@JsonProperty("tools") List<Tool> tools,
378-
@JsonProperty("options") Map<String, Object> options
388+
@JsonProperty("options") Map<String, Object> options,
389+
@JsonProperty("think") Boolean think
379390
) {
380391

381392
public static Builder builder(String model) {
@@ -448,6 +459,7 @@ public static class Builder {
448459
private String keepAlive;
449460
private List<Tool> tools = List.of();
450461
private Map<String, Object> options = Map.of();
462+
private boolean think;
451463

452464
public Builder(String model) {
453465
Assert.notNull(model, "The model can not be null.");
@@ -492,8 +504,13 @@ public Builder options(OllamaOptions options) {
492504
return this;
493505
}
494506

507+
public Builder think(boolean think) {
508+
this.think = think;
509+
return this;
510+
}
511+
495512
public ChatRequest build() {
496-
return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options);
513+
return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options, this.think);
497514
}
498515
}
499516
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
/**
2727
* @author Christian Tzolov
28+
* @author Sun Yuhan
2829
* @since 1.0.0
2930
*/
3031
public final class OllamaApiHelper {
@@ -81,12 +82,18 @@ public static ChatResponse merge(ChatResponse previous, ChatResponse current) {
8182
private static OllamaApi.Message merge(OllamaApi.Message previous, OllamaApi.Message current) {
8283

8384
String content = mergeContent(previous, current);
85+
String thinking = mergeThinking(previous, current);
8486
OllamaApi.Message.Role role = (current.role() != null ? current.role() : previous.role());
8587
role = (role != null ? role : OllamaApi.Message.Role.ASSISTANT);
8688
List<String> images = mergeImages(previous, current);
8789
List<OllamaApi.Message.ToolCall> toolCalls = mergeToolCall(previous, current);
8890

89-
return OllamaApi.Message.builder(role).content(content).images(images).toolCalls(toolCalls).build();
91+
return OllamaApi.Message.builder(role)
92+
.content(content)
93+
.thinking(thinking)
94+
.images(images)
95+
.toolCalls(toolCalls)
96+
.build();
9097
}
9198

9299
private static Instant merge(Instant previous, Instant current) {
@@ -134,6 +141,17 @@ private static String mergeContent(OllamaApi.Message previous, OllamaApi.Message
134141
return previous.content() + current.content();
135142
}
136143

144+
private static String mergeThinking(OllamaApi.Message previous, OllamaApi.Message current) {
145+
if (previous == null || previous.thinking() == null) {
146+
return (current != null ? current.thinking() : null);
147+
}
148+
if (current == null || current.thinking() == null) {
149+
return (previous != null ? previous.thinking() : null);
150+
}
151+
152+
return previous.thinking() + current.thinking();
153+
}
154+
137155
private static List<OllamaApi.Message.ToolCall> mergeToolCall(OllamaApi.Message previous,
138156
OllamaApi.Message current) {
139157
if (previous == null) {

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
*
2424
* @author Siarhei Blashuk
2525
* @author Thomas Vitale
26+
* @author Sun Yuhan
2627
* @since 1.0.0
2728
*/
2829
public enum OllamaModel implements ChatModelDescription {
@@ -32,6 +33,16 @@ public enum OllamaModel implements ChatModelDescription {
3233
*/
3334
QWEN_2_5_7B("qwen2.5"),
3435

36+
/**
37+
* Qwen3
38+
*/
39+
QWEN_3_8B("qwen3"),
40+
41+
/**
42+
* Qwen3 4b
43+
*/
44+
QWEN_3_4B("qwen3:4b"),
45+
3546
/**
3647
* QwQ is the reasoning model of the Qwen series.
3748
*/

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
* @author Christian Tzolov
4545
* @author Thomas Vitale
4646
* @author Ilayaperumal Gopinathan
47+
* @author Sun Yuhan
4748
* @since 0.8.0
4849
* @see <a href=
4950
* "https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values">Ollama
@@ -318,6 +319,14 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions {
318319
@JsonProperty("truncate")
319320
private Boolean truncate;
320321

322+
/**
323+
* The model should think before responding, if supported.
324+
* If this value is not specified, it defaults to null, and Ollama will return
325+
* the thought process within the `content` field of the response, wrapped in `&lt;thinking&gt;` tags.
326+
*/
327+
@JsonProperty("think")
328+
private Boolean think;
329+
321330
@JsonIgnore
322331
private Boolean internalToolExecutionEnabled;
323332

@@ -365,6 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
365374
.format(fromOptions.getFormat())
366375
.keepAlive(fromOptions.getKeepAlive())
367376
.truncate(fromOptions.getTruncate())
377+
.think(fromOptions.getThink())
368378
.useNUMA(fromOptions.getUseNUMA())
369379
.numCtx(fromOptions.getNumCtx())
370380
.numBatch(fromOptions.getNumBatch())
@@ -704,6 +714,14 @@ public void setTruncate(Boolean truncate) {
704714
this.truncate = truncate;
705715
}
706716

717+
public Boolean getThink() {
718+
return this.think;
719+
}
720+
721+
public void setThink(Boolean think) {
722+
this.think = think;
723+
}
724+
707725
@Override
708726
@JsonIgnore
709727
public List<ToolCallback> getToolCallbacks() {
@@ -804,7 +822,8 @@ public boolean equals(Object o) {
804822
&& Objects.equals(this.repeatPenalty, that.repeatPenalty)
805823
&& Objects.equals(this.presencePenalty, that.presencePenalty)
806824
&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
807-
&& Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau)
825+
&& Objects.equals(this.think, that.think) && Objects.equals(this.mirostat, that.mirostat)
826+
&& Objects.equals(this.mirostatTau, that.mirostatTau)
808827
&& Objects.equals(this.mirostatEta, that.mirostatEta)
809828
&& Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop)
810829
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
@@ -814,13 +833,13 @@ public boolean equals(Object o) {
814833

815834
@Override
816835
public int hashCode() {
817-
return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx,
818-
this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly,
819-
this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK,
820-
this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
821-
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
822-
this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
823-
this.toolContext);
836+
return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.think, this.useNUMA,
837+
this.numCtx, this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll,
838+
this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict,
839+
this.topK, this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature,
840+
this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau,
841+
this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames,
842+
this.internalToolExecutionEnabled, this.toolContext);
824843
}
825844

826845
public static class Builder {
@@ -852,6 +871,11 @@ public Builder truncate(Boolean truncate) {
852871
return this;
853872
}
854873

874+
public Builder think(Boolean think) {
875+
this.options.think = think;
876+
return this;
877+
}
878+
855879
public Builder useNUMA(Boolean useNUMA) {
856880
this.options.useNUMA = useNUMA;
857881
return this;

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class OllamaImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0");
2727

2828
private OllamaImage() {
2929

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
3434

3535
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.junit.jupiter.api.Assertions.assertNull;
3637

3738
/**
3839
* @author Christian Tzolov
@@ -114,4 +115,32 @@ public void embedText() {
114115
assertThat(response.totalDuration()).isGreaterThan(1);
115116
}
116117

118+
@Test
119+
public void chatWithThinking() {
120+
var request = ChatRequest.builder(MODEL)
121+
.stream(true)
122+
.think(true)
123+
.messages(List.of(Message.builder(Role.USER)
124+
.content("What is the capital of Bulgaria and what is the size? " + "What it the national anthem?")
125+
.build()))
126+
.options(OllamaOptions.builder().temperature(0.9).build().toMap())
127+
.build();
128+
129+
Flux<ChatResponse> response = getOllamaApi().streamingChat(request);
130+
131+
List<ChatResponse> responses = response.collectList().block();
132+
System.out.println(responses);
133+
134+
assertThat(responses).isNotNull();
135+
assertThat(responses.stream()
136+
.filter(r -> r.message() != null)
137+
.map(r -> r.message().thinking())
138+
.collect(Collectors.joining(System.lineSeparator()))).contains("Sofia");
139+
140+
ChatResponse lastResponse = responses.get(responses.size() - 1);
141+
assertThat(lastResponse.message().content()).isEmpty();
142+
assertNull(lastResponse.message().thinking());
143+
assertThat(lastResponse.done()).isTrue();
144+
}
145+
117146
}

0 commit comments

Comments
 (0)