Skip to content

Commit db67682

Browse files
sunyuhan1998ralla0405
authored andcommitted
fix: spring-projectsGH-4414 OllamaApiHelper
- Fixed the issue in `OllamaApiHelper` where `thinking` and `toolName` were omitted when merging messages; added unit tests for `OllamaApiHelper` Auto-cherry-pick to 1.0.x Fixes spring-projects#4414 Signed-off-by: Sun Yuhan <[email protected]> Signed-off-by: logan-mac <[email protected]>
1 parent d143b4e commit db67682

File tree

2 files changed

+271
-3
lines changed

2 files changed

+271
-3
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2024-2024 the original author or authors.
2+
* Copyright 2024-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@
2525

2626
/**
2727
* @author Christian Tzolov
28+
* @author Sun Yuhan
2829
* @since 1.0.0
2930
*/
3031
public final class OllamaApiHelper {
@@ -81,12 +82,20 @@ public static ChatResponse merge(ChatResponse previous, ChatResponse current) {
8182
private static OllamaApi.Message merge(OllamaApi.Message previous, OllamaApi.Message current) {
8283

8384
String content = mergeContent(previous, current);
85+
String thinking = mergeThinking(previous, current);
8486
OllamaApi.Message.Role role = (current.role() != null ? current.role() : previous.role());
8587
role = (role != null ? role : OllamaApi.Message.Role.ASSISTANT);
8688
List<String> images = mergeImages(previous, current);
8789
List<OllamaApi.Message.ToolCall> toolCalls = mergeToolCall(previous, current);
88-
89-
return OllamaApi.Message.builder(role).content(content).images(images).toolCalls(toolCalls).build();
90+
String toolName = mergeToolName(previous, current);
91+
92+
return OllamaApi.Message.builder(role)
93+
.content(content)
94+
.thinking(thinking)
95+
.images(images)
96+
.toolCalls(toolCalls)
97+
.toolName(toolName)
98+
.build();
9099
}
91100

92101
private static Instant merge(Instant previous, Instant current) {
@@ -145,6 +154,28 @@ private static List<OllamaApi.Message.ToolCall> mergeToolCall(OllamaApi.Message
145154
return merge(previous.toolCalls(), current.toolCalls());
146155
}
147156

157+
private static String mergeThinking(OllamaApi.Message previous, OllamaApi.Message current) {
158+
if (previous == null || previous.thinking() == null) {
159+
return (current != null ? current.thinking() : null);
160+
}
161+
if (current == null || current.thinking() == null) {
162+
return (previous.thinking());
163+
}
164+
165+
return previous.thinking() + current.thinking();
166+
}
167+
168+
private static String mergeToolName(OllamaApi.Message previous, OllamaApi.Message current) {
169+
if (previous == null || previous.toolName() == null) {
170+
return (current != null ? current.toolName() : null);
171+
}
172+
if (current == null || current.toolName() == null) {
173+
return (previous.toolName());
174+
}
175+
176+
return previous.toolName() + current.toolName();
177+
}
178+
148179
private static List<String> mergeImages(OllamaApi.Message previous, OllamaApi.Message current) {
149180
if (previous == null) {
150181
return (current != null ? current.images() : null);
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.ollama.api;
18+
19+
import java.time.Instant;
20+
import java.util.Arrays;
21+
import java.util.Collections;
22+
import java.util.List;
23+
24+
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.api.extension.ExtendWith;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.mockito.Mockito.mock;
30+
import static org.mockito.Mockito.when;
31+
32+
/**
33+
* Tests for {@link OllamaApiHelper}
34+
*
35+
* @author Sun Yuhan
36+
*/
37+
@ExtendWith(MockitoExtension.class)
38+
class OllamaApiHelperTests {
39+
40+
@Test
41+
void isStreamingToolCallWhenResponseIsNullShouldReturnFalse() {
42+
boolean result = OllamaApiHelper.isStreamingToolCall(null);
43+
assertThat(result).isFalse();
44+
}
45+
46+
@Test
47+
void isStreamingToolCallWhenMessageIsNullShouldReturnFalse() {
48+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
49+
when(response.message()).thenReturn(null);
50+
51+
boolean result = OllamaApiHelper.isStreamingToolCall(response);
52+
assertThat(result).isFalse();
53+
}
54+
55+
@Test
56+
void isStreamingToolCallWhenToolCallsIsNullShouldReturnFalse() {
57+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
58+
OllamaApi.Message message = mock(OllamaApi.Message.class);
59+
when(response.message()).thenReturn(message);
60+
when(message.toolCalls()).thenReturn(null);
61+
62+
boolean result = OllamaApiHelper.isStreamingToolCall(response);
63+
assertThat(result).isFalse();
64+
}
65+
66+
@Test
67+
void isStreamingToolCallWhenToolCallsIsEmptyShouldReturnFalse() {
68+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
69+
OllamaApi.Message message = mock(OllamaApi.Message.class);
70+
when(response.message()).thenReturn(message);
71+
when(message.toolCalls()).thenReturn(Collections.emptyList());
72+
73+
boolean result = OllamaApiHelper.isStreamingToolCall(response);
74+
assertThat(result).isFalse();
75+
}
76+
77+
@Test
78+
void isStreamingToolCallWhenToolCallsHasElementsShouldReturnTrue() {
79+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
80+
OllamaApi.Message message = mock(OllamaApi.Message.class);
81+
List<OllamaApi.Message.ToolCall> toolCalls = Arrays.asList(mock(OllamaApi.Message.ToolCall.class));
82+
when(response.message()).thenReturn(message);
83+
when(message.toolCalls()).thenReturn(toolCalls);
84+
85+
boolean result = OllamaApiHelper.isStreamingToolCall(response);
86+
assertThat(result).isTrue();
87+
}
88+
89+
@Test
90+
void isStreamingDoneWhenResponseIsNullShouldReturnFalse() {
91+
boolean result = OllamaApiHelper.isStreamingDone(null);
92+
assertThat(result).isFalse();
93+
}
94+
95+
@Test
96+
void isStreamingDoneWhenDoneIsFalseShouldReturnFalse() {
97+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
98+
when(response.done()).thenReturn(false);
99+
100+
boolean result = OllamaApiHelper.isStreamingDone(response);
101+
assertThat(result).isFalse();
102+
}
103+
104+
@Test
105+
void isStreamingDoneWhenDoneReasonIsNotStopShouldReturnFalse() {
106+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
107+
when(response.done()).thenReturn(true);
108+
when(response.doneReason()).thenReturn("other");
109+
110+
boolean result = OllamaApiHelper.isStreamingDone(response);
111+
assertThat(result).isFalse();
112+
}
113+
114+
@Test
115+
void isStreamingDoneWhenDoneIsTrueAndDoneReasonIsStopShouldReturnTrue() {
116+
OllamaApi.ChatResponse response = mock(OllamaApi.ChatResponse.class);
117+
when(response.done()).thenReturn(true);
118+
when(response.doneReason()).thenReturn("stop");
119+
120+
boolean result = OllamaApiHelper.isStreamingDone(response);
121+
assertThat(result).isTrue();
122+
}
123+
124+
@Test
125+
void mergeWhenBothResponsesHaveValuesShouldMergeCorrectly() {
126+
Instant previousCreatedAt = Instant.now().minusSeconds(10);
127+
OllamaApi.Message previousMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
128+
.content("Previous content")
129+
.thinking("Previous thinking")
130+
.images(Arrays.asList("image1"))
131+
.toolCalls(Arrays.asList(mock(OllamaApi.Message.ToolCall.class)))
132+
.toolName("Previous tool")
133+
.build();
134+
135+
OllamaApi.ChatResponse previous = new OllamaApi.ChatResponse("previous-model", previousCreatedAt,
136+
previousMessage, "previous-reason", false, 100L, 50L, 10, 200L, 5, 100L);
137+
138+
Instant currentCreatedAt = Instant.now();
139+
OllamaApi.Message currentMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.USER)
140+
.content("Current content")
141+
.thinking("Current thinking")
142+
.images(Arrays.asList("image2"))
143+
.toolCalls(Arrays.asList(mock(OllamaApi.Message.ToolCall.class)))
144+
.toolName("Current tool")
145+
.build();
146+
147+
OllamaApi.ChatResponse current = new OllamaApi.ChatResponse("current-model", currentCreatedAt, currentMessage,
148+
"stop", true, 200L, 100L, 20, 400L, 10, 200L);
149+
150+
OllamaApi.ChatResponse result = OllamaApiHelper.merge(previous, current);
151+
152+
assertThat(result.model()).isEqualTo("previous-modelcurrent-model");
153+
assertThat(result.createdAt()).isEqualTo(currentCreatedAt);
154+
assertThat(result.message().content()).isEqualTo("Previous contentCurrent content");
155+
assertThat(result.message().thinking()).isEqualTo("Previous thinkingCurrent thinking");
156+
assertThat(result.message().role()).isEqualTo(OllamaApi.Message.Role.USER);
157+
assertThat(result.message().images()).containsExactly("image1", "image2");
158+
assertThat(result.message().toolCalls()).hasSize(2);
159+
assertThat(result.message().toolName()).isEqualTo("Previous toolCurrent tool");
160+
assertThat(result.doneReason()).isEqualTo("stop");
161+
assertThat(result.done()).isTrue();
162+
assertThat(result.totalDuration()).isEqualTo(300L);
163+
assertThat(result.loadDuration()).isEqualTo(150L);
164+
assertThat(result.promptEvalCount()).isEqualTo(30);
165+
assertThat(result.promptEvalDuration()).isEqualTo(600L);
166+
assertThat(result.evalCount()).isEqualTo(15);
167+
assertThat(result.evalDuration()).isEqualTo(300L);
168+
}
169+
170+
@Test
171+
void mergeStringsShouldConcatenate() {
172+
OllamaApi.Message previousMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
173+
.content("Hello")
174+
.thinking("Think")
175+
.toolName("Tool")
176+
.build();
177+
OllamaApi.ChatResponse previous = new OllamaApi.ChatResponse("model1", Instant.now(), previousMessage,
178+
"reason1", false, null, null, null, null, null, null);
179+
180+
OllamaApi.Message currentMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
181+
.content(" World")
182+
.thinking("ing")
183+
.toolName("Box")
184+
.build();
185+
OllamaApi.ChatResponse current = new OllamaApi.ChatResponse("model2", Instant.now(), currentMessage, "reason2",
186+
true, null, null, null, null, null, null);
187+
188+
OllamaApi.ChatResponse result = OllamaApiHelper.merge(previous, current);
189+
190+
assertThat(result.model()).isEqualTo("model1model2");
191+
assertThat(result.message().content()).isEqualTo("Hello World");
192+
assertThat(result.message().thinking()).isEqualTo("Thinking");
193+
assertThat(result.message().toolName()).isEqualTo("ToolBox");
194+
assertThat(result.doneReason()).isEqualTo("reason2");
195+
assertThat(result.done()).isTrue();
196+
}
197+
198+
@Test
199+
void mergeNumbersShouldSum() {
200+
OllamaApi.Message dummyMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).build();
201+
202+
OllamaApi.ChatResponse previous = new OllamaApi.ChatResponse(null, null, dummyMessage, null, null, 100L, 50L,
203+
10, 200L, 5, 100L);
204+
205+
OllamaApi.ChatResponse current = new OllamaApi.ChatResponse(null, null, dummyMessage, null, null, 200L, 100L,
206+
20, 400L, 10, 200L);
207+
208+
OllamaApi.ChatResponse result = OllamaApiHelper.merge(previous, current);
209+
210+
assertThat(result.totalDuration()).isEqualTo(300L);
211+
assertThat(result.loadDuration()).isEqualTo(150L);
212+
assertThat(result.promptEvalCount()).isEqualTo(30);
213+
assertThat(result.promptEvalDuration()).isEqualTo(600L);
214+
assertThat(result.evalCount()).isEqualTo(15);
215+
assertThat(result.evalDuration()).isEqualTo(300L);
216+
}
217+
218+
@Test
219+
void mergeListsShouldCombine() {
220+
OllamaApi.Message previousMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
221+
.images(Arrays.asList("image1", "image2"))
222+
.build();
223+
OllamaApi.ChatResponse previous = new OllamaApi.ChatResponse(null, null, previousMessage, null, null, null,
224+
null, null, null, null, null);
225+
226+
OllamaApi.Message currentMessage = OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
227+
.images(Arrays.asList("image3", "image4"))
228+
.build();
229+
OllamaApi.ChatResponse current = new OllamaApi.ChatResponse(null, null, currentMessage, null, null, null, null,
230+
null, null, null, null);
231+
232+
OllamaApi.ChatResponse result = OllamaApiHelper.merge(previous, current);
233+
234+
assertThat(result.message().images()).containsExactly("image1", "image2", "image3", "image4");
235+
}
236+
237+
}

0 commit comments

Comments
 (0)