Skip to content

Commit 71416fc

Browse files
committed
Use Ollama system role for system message
Signed-off-by: Nicolas Krier <[email protected]>
1 parent 4889131 commit 71416fc

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
439439
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
440440

441441
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
442-
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
442+
if (message.getMessageType() == MessageType.SYSTEM) {
443+
return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(message.getText()).build());
444+
}
445+
else if (message.getMessageType() == MessageType.USER) {
443446
var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText());
444447
if (message instanceof UserMessage userMessage) {
445448
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616

1717
package org.springframework.ai.ollama;
1818

19-
import java.util.Map;
20-
2119
import org.junit.jupiter.api.Test;
22-
20+
import org.springframework.ai.chat.messages.*;
2321
import org.springframework.ai.chat.prompt.ChatOptions;
2422
import org.springframework.ai.chat.prompt.Prompt;
2523
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -30,16 +28,20 @@
3028
import org.springframework.ai.tool.definition.DefaultToolDefinition;
3129
import org.springframework.ai.tool.definition.ToolDefinition;
3230

31+
import java.util.List;
32+
import java.util.Map;
33+
3334
import static org.assertj.core.api.Assertions.assertThat;
3435

3536
/**
3637
* @author Christian Tzolov
3738
* @author Thomas Vitale
3839
* @author Alexandros Pappas
40+
* @author Nicolas Krier
3941
*/
4042
class OllamaChatRequestTests {
4143

42-
OllamaChatModel chatModel = OllamaChatModel.builder()
44+
private final OllamaChatModel chatModel = OllamaChatModel.builder()
4345
.ollamaApi(OllamaApi.builder().build())
4446
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
4547
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
@@ -167,6 +169,54 @@ public void createRequestWithDefaultOptionsModelOverride() {
167169
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
168170
}
169171

172+
@Test
173+
void createRequestWithAllMessageTypes() {
174+
var prompt = this.chatModel.buildRequestPrompt(new Prompt(createMessagesWithAllMessageTypes()));
175+
176+
var request = this.chatModel.ollamaChatRequest(prompt, false);
177+
178+
assertThat(request.messages()).hasSize(6);
179+
180+
var ollamaSystemMessage = request.messages().get(0);
181+
assertThat(ollamaSystemMessage.role()).isEqualTo(OllamaApi.Message.Role.SYSTEM);
182+
assertThat(ollamaSystemMessage.content()).isEqualTo("Test system message");
183+
184+
var ollamaUserMessage = request.messages().get(1);
185+
assertThat(ollamaUserMessage.role()).isEqualTo(OllamaApi.Message.Role.USER);
186+
assertThat(ollamaUserMessage.content()).isEqualTo("Test user message");
187+
188+
var ollamaToolResponse1 = request.messages().get(2);
189+
assertThat(ollamaToolResponse1.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
190+
assertThat(ollamaToolResponse1.content()).isEqualTo("Test tool response 1");
191+
192+
var ollamaToolResponse2 = request.messages().get(3);
193+
assertThat(ollamaToolResponse2.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
194+
assertThat(ollamaToolResponse2.content()).isEqualTo("Test tool response 2");
195+
196+
var ollamaToolResponse3 = request.messages().get(4);
197+
assertThat(ollamaToolResponse3.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
198+
assertThat(ollamaToolResponse3.content()).isEqualTo("Test tool response 3");
199+
200+
var ollamaAssistantMessage = request.messages().get(5);
201+
assertThat(ollamaAssistantMessage.role()).isEqualTo(OllamaApi.Message.Role.ASSISTANT);
202+
assertThat(ollamaAssistantMessage.content()).isEqualTo("Test assistant message");
203+
}
204+
205+
private static List<Message> createMessagesWithAllMessageTypes() {
206+
var systemMessage = new SystemMessage("Test system message");
207+
var userMessage = new UserMessage("Test user message");
208+
// @formatter:off
209+
var toolResponseMessage = new ToolResponseMessage(List.of(
210+
new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"),
211+
new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"),
212+
new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3"))
213+
);
214+
// @formatter:on
215+
var assistantMessage = new AssistantMessage("Test assistant message");
216+
217+
return List.of(systemMessage, userMessage, toolResponseMessage, assistantMessage);
218+
}
219+
170220
static class TestToolCallback implements ToolCallback {
171221

172222
private final ToolDefinition toolDefinition;

0 commit comments

Comments
 (0)