diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 7cb87eb8f3b..e742ce32f94 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -266,6 +266,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder() .finishReason(ollamaResponse.doneReason()) + .metadata("thinking", ollamaResponse.message().thinking()) .build(); } @@ -505,7 +506,8 @@ else if (message.getMessageType() == MessageType.TOOL) { OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) .messages(ollamaMessages) - .options(requestOptions); + .options(requestOptions) + .think(requestOptions.getThink()); if (requestOptions.getFormat() != null) { requestBuilder.format(requestOptions.getFormat()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 4679b6e2539..d716b2b3c26 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -23,6 +23,7 @@ * * @author Siarhei Blashuk * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ public enum OllamaModel implements ChatModelDescription { @@ -51,6 +52,16 @@ public enum OllamaModel implements ChatModelDescription { */ QWEN3_4B("qwen3:4b"), + /** + * Qwen3 1.7b + */ + QWEN_3_1_7_B("qwen3:1.7b"), + + /** + * Qwen3 0.6b + */ + QWEN_3_06B("qwen3:0.6b"), + /** * 
QwQ is the reasoning model of the Qwen series. */ diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 84c0752654b..db6988dafe8 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Sun Yuhan * @since 0.8.0 * @see Ollama @@ -353,6 +354,14 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonProperty("truncate") private Boolean truncate; + /** + * Whether the model should think before responding, if supported. + * If this value is not specified, it defaults to {@code null}, and Ollama will return + * the thought process within the {@code content} field of the response, wrapped in {@code <thinking>} tags. 
+ */ + @JsonProperty("think") + private Boolean think; + @JsonIgnore private Boolean internalToolExecutionEnabled; @@ -400,6 +409,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .format(fromOptions.getFormat()) .keepAlive(fromOptions.getKeepAlive()) .truncate(fromOptions.getTruncate()) + .think(fromOptions.getThink()) .useNUMA(fromOptions.getUseNUMA()) .numCtx(fromOptions.getNumCtx()) .numBatch(fromOptions.getNumBatch()) @@ -827,6 +837,15 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } + @Override + public Boolean getThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -927,7 +946,8 @@ public boolean equals(Object o) { && Objects.equals(this.repeatPenalty, that.repeatPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) - && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.think, that.think) && Objects.equals(this.mirostat, that.mirostat) + && Objects.equals(this.mirostatTau, that.mirostatTau) && Objects.equals(this.mirostatEta, that.mirostatEta) && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -937,13 +957,13 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, - this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, - this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, - this.presencePenalty, this.frequencyPenalty, 
this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.think, this.useNUMA, + this.numCtx, this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, + this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, + this.topK, this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, + this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, + this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolContext); } @Deprecated @@ -976,6 +996,11 @@ public Builder truncate(Boolean truncate) { return this; } + public Builder think(Boolean think) { + this.options.think = think; + return this; + } + /** * @deprecated Not supported in Ollama anymore. */ diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java new file mode 100644 index 00000000000..b610780a695 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama; + +import io.micrometer.observation.tck.TestObservationRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link OllamaChatModel} asserting AI metadata. 
+ * + * @author Sun Yuhan + */ +@SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class) +class OllamaChatModelMetadataTests extends BaseOllamaIT { + + private static final String MODEL = OllamaModel.QWEN_3_06B.getName(); + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + OllamaChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void ollamaThinkingMetadataCaptured() { + var options = OllamaOptions.builder().model(MODEL).think(true).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + assertThat(chatGenerationMetadata.containsKey("thinking")).isTrue(); /* a bare assertThat(boolean) verifies nothing; isTrue() makes the check effective */ + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() { + var options = OllamaOptions.builder().model(MODEL).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() { + var options = OllamaOptions.builder().model(MODEL).think(false).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + 
chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OllamaApi ollamaApi() { + return initializeOllama(MODEL); + } + + @Bean + public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { + return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + } + + } + +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index b31ba5365f8..471ec5290e6 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,10 +33,12 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov * @author Thomas Vitale + * @author Sun Yuhan */ public class OllamaApiIT extends BaseOllamaIT { @@ -146,4 +148,88 @@ public void think() { assertThat(response.message().thinking()).isNotEmpty(); } + @Test + public void chatWithThinking() { + var request = ChatRequest.builder(THINKING_MODEL) + .stream(true) + .think(true) + .messages(List.of(Message.builder(Role.USER) + .content("What is the capital of Bulgaria and what is the size? 
" + "What it the national anthem?") + .build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Sofia"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + + @Test + public void streamChatWithThinking() { + var request = ChatRequest.builder(THINKING_MODEL) + .stream(true) + .think(true) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("solar"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + + @Test + public void streamChatWithoutThinking() { + var request = ChatRequest.builder(THINKING_MODEL) + .stream(true) + .think(false) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = 
getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().content()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Earth"); + + assertThat(responses.stream().filter(r -> r.message() != null).allMatch(r -> r.message().thinking() == null)) + .isTrue(); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 9f051ac0597..1184cecbd9c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions { @Nullable Double getTopP(); + /** + * Returns the think flag to use for the chat. + * @return the think flag to use for the chat, or {@code null} if not specified + */ + @Nullable + default Boolean getThink() { + return null; + } + /** * Returns a copy of this {@link ChatOptions}. * @return a copy of this {@link ChatOptions} @@ -158,6 +167,13 @@ interface Builder { */ Builder topP(Double topP); + /** + * Builds with the think flag to use for the chat. + * @param think Whether to enable thinking mode + * @return the builder. + */ + Builder think(Boolean think); + /** * Build the {@link ChatOptions}. * @return the Chat options. 
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 1af33bf3467..21f0bf56da8 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -41,6 +41,8 @@ public class DefaultChatOptions implements ChatOptions { private Double topP; + private Boolean think; + @Override public String getModel() { return this.model; @@ -113,6 +115,15 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Boolean getThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -125,6 +136,7 @@ public T copy() { copy.setTemperature(this.getTemperature()); copy.setTopK(this.getTopK()); copy.setTopP(this.getTopP()); + copy.setThink(this.getThink()); return (T) copy; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java index 47ba5840109..a317c8c8106 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -73,6 +73,11 @@ public DefaultChatOptionsBuilder topP(Double topP) { return this; } + public DefaultChatOptionsBuilder think(Boolean think) { + this.options.setThink(think); + return this; + } + public ChatOptions build() { return this.options.copy(); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 
e088994139b..9a3fc275ad5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -70,6 +70,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Double topP; + @Nullable + private Boolean think; + @Override public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); @@ -198,6 +201,16 @@ public void setTopP(@Nullable Double topP) { this.topP = topP; } + @Override + @Nullable + public Boolean getThink() { + return this.think; + } + + public void setThink(@Nullable Boolean think) { + this.think = think; + } + @Override @SuppressWarnings("unchecked") public T copy() { @@ -325,6 +338,12 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) { return this; } + @Override + public ToolCallingChatOptions.Builder think(Boolean think) { + this.options.setThink(think); + return this; + } + @Override public ToolCallingChatOptions build() { return this.options; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..9cbdbe80c86 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -219,6 +219,9 @@ interface Builder extends ChatOptions.Builder { @Override Builder topP(@Nullable Double topP); + @Override + Builder think(@Nullable Boolean think); + @Override ToolCallingChatOptions build(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java index 995141f2cf6..8f3c34086e0 100644 --- 
a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java @@ -53,6 +53,7 @@ void shouldBuildWithAllOptions() { .topP(1.0) .topK(40) .stopSequences(List.of("stop1", "stop2")) + .think(true) .build(); assertThat(options.getModel()).isEqualTo("gpt-4"); @@ -60,6 +61,7 @@ void shouldBuildWithAllOptions() { assertThat(options.getTemperature()).isEqualTo(0.7); assertThat(options.getTopP()).isEqualTo(1.0); assertThat(options.getTopK()).isEqualTo(40); + assertThat(options.getThink()).isEqualTo(true); assertThat(options.getStopSequences()).containsExactly("stop1", "stop2"); } @@ -82,6 +84,7 @@ void shouldCopyOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .build(); @@ -107,6 +110,7 @@ void shouldUpcastToChatOptions() { .temperature(0.7) .topP(1.0) .topK(40) + .think(true) .stopSequences(List.of("stop1", "stop2")) .toolNames(Set.of("function1", "function2")) .toolCallbacks(List.of(callback)) @@ -121,6 +125,7 @@ void shouldUpcastToChatOptions() { assertThat(chatOptions.getTemperature()).isEqualTo(0.7); assertThat(chatOptions.getTopP()).isEqualTo(1.0); assertThat(chatOptions.getTopK()).isEqualTo(40); + assertThat(chatOptions.getThink()).isEqualTo(true); assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2"); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index e4f0aa812c5..60411bc9748 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -188,6 +188,7 @@ void builderShouldCreateOptionsWithAllProperties() { 
.stopSequences(List.of("stop")) .topK(3) .topP(0.9) + .think(true) .build(); assertThat(options).satisfies(o -> { @@ -203,6 +204,7 @@ void builderShouldCreateOptionsWithAllProperties() { assertThat(o.getStopSequences()).containsExactly("stop"); assertThat(o.getTopK()).isEqualTo(3); assertThat(o.getTopP()).isEqualTo(0.9); + assertThat(o.getThink()).isEqualTo(true); }); }