Skip to content

Commit b188adc

Browse files
committed
feat: Propagate the thinking returned by Ollama back to ChatGenerationMetadata, and added corresponding unit tests.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent f77e08a commit b188adc

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
249249
if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
250250
generationMetadata = ChatGenerationMetadata.builder()
251251
.finishReason(ollamaResponse.doneReason())
252+
.metadata("thinking", ollamaResponse.message().thinking())
252253
.build();
253254
}
254255

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.ollama;
18+
19+
import io.micrometer.observation.tck.TestObservationRegistry;
20+
import org.junit.jupiter.api.BeforeEach;
21+
import org.junit.jupiter.api.Test;
22+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
23+
import org.springframework.ai.chat.model.ChatResponse;
24+
import org.springframework.ai.chat.prompt.Prompt;
25+
import org.springframework.ai.ollama.api.OllamaApi;
26+
import org.springframework.ai.ollama.api.OllamaModel;
27+
import org.springframework.ai.ollama.api.OllamaOptions;
28+
import org.springframework.beans.factory.annotation.Autowired;
29+
import org.springframework.boot.SpringBootConfiguration;
30+
import org.springframework.boot.test.context.SpringBootTest;
31+
import org.springframework.context.annotation.Bean;
32+
33+
import static org.assertj.core.api.Assertions.assertThat;
34+
35+
/**
36+
* Unit Tests for {@link OllamaChatModel} asserting AI metadata.
37+
*
38+
* @author Sun Yuhan
39+
*/
40+
@SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class)
41+
class OllamaChatModelMetadataTests extends BaseOllamaIT {
42+
43+
private static final String MODEL = OllamaModel.QWEN_3_06B.getName();
44+
45+
@Autowired
46+
TestObservationRegistry observationRegistry;
47+
48+
@Autowired
49+
OllamaChatModel chatModel;
50+
51+
@BeforeEach
52+
void beforeEach() {
53+
this.observationRegistry.clear();
54+
}
55+
56+
@Test
57+
void ollamaThinkingMetadataCaptured() {
58+
var options = OllamaOptions.builder().model(MODEL).think(true).build();
59+
60+
Prompt prompt = new Prompt("Why is the sky blue?", options);
61+
62+
ChatResponse chatResponse = this.chatModel.call(prompt);
63+
assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty();
64+
65+
chatResponse.getResults().forEach(generation -> {
66+
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
67+
assertThat(chatGenerationMetadata).isNotNull();
68+
var thinking = chatGenerationMetadata.get("thinking");
69+
assertThat(thinking).isNotNull();
70+
});
71+
}
72+
73+
@Test
74+
void ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() {
75+
var options = OllamaOptions.builder().model(MODEL).build();
76+
77+
Prompt prompt = new Prompt("Why is the sky blue?", options);
78+
79+
ChatResponse chatResponse = this.chatModel.call(prompt);
80+
assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty();
81+
82+
chatResponse.getResults().forEach(generation -> {
83+
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
84+
assertThat(chatGenerationMetadata).isNotNull();
85+
var thinking = chatGenerationMetadata.get("thinking");
86+
assertThat(thinking).isNull();
87+
});
88+
}
89+
90+
@Test
91+
void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() {
92+
var options = OllamaOptions.builder().model(MODEL).think(false).build();
93+
94+
Prompt prompt = new Prompt("Why is the sky blue?", options);
95+
96+
ChatResponse chatResponse = this.chatModel.call(prompt);
97+
assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty();
98+
99+
chatResponse.getResults().forEach(generation -> {
100+
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
101+
assertThat(chatGenerationMetadata).isNotNull();
102+
var thinking = chatGenerationMetadata.get("thinking");
103+
assertThat(thinking).isNull();
104+
});
105+
}
106+
107+
@SpringBootConfiguration
108+
static class Config {
109+
110+
@Bean
111+
public TestObservationRegistry observationRegistry() {
112+
return TestObservationRegistry.create();
113+
}
114+
115+
@Bean
116+
public OllamaApi ollamaApi() {
117+
return initializeOllama(MODEL);
118+
}
119+
120+
@Bean
121+
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
122+
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
123+
}
124+
125+
}
126+
127+
}

0 commit comments

Comments
 (0)