diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 7316e90b017..d9a8a441225 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -439,7 +439,10 @@ Prompt buildRequestPrompt(Prompt prompt) { OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) { List ollamaMessages = prompt.getInstructions().stream().map(message -> { - if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { + if (message.getMessageType() == MessageType.SYSTEM) { + return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(message.getText()).build()); + } + else if (message.getMessageType() == MessageType.USER) { var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText()); if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index d03de073b7e..bda1304b867 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -16,10 +16,16 @@ package org.springframework.ai.ollama; +import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.tool.ToolCallingChatOptions; @@ -36,10 +42,11 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Alexandros Pappas + * @author Nicolas Krier */ class OllamaChatRequestTests { - OllamaChatModel chatModel = OllamaChatModel.builder() + private final OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(OllamaApi.builder().build()) .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) @@ -167,6 +174,54 @@ public void createRequestWithDefaultOptionsModelOverride() { assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } + @Test + void createRequestWithAllMessageTypes() { + var prompt = this.chatModel.buildRequestPrompt(new Prompt(createMessagesWithAllMessageTypes())); + + var request = this.chatModel.ollamaChatRequest(prompt, false); + + assertThat(request.messages()).hasSize(6); + + var ollamaSystemMessage = request.messages().get(0); + assertThat(ollamaSystemMessage.role()).isEqualTo(OllamaApi.Message.Role.SYSTEM); + assertThat(ollamaSystemMessage.content()).isEqualTo("Test system message"); + + var ollamaUserMessage = request.messages().get(1); + assertThat(ollamaUserMessage.role()).isEqualTo(OllamaApi.Message.Role.USER); + assertThat(ollamaUserMessage.content()).isEqualTo("Test user message"); + + var ollamaToolResponse1 = request.messages().get(2); + assertThat(ollamaToolResponse1.role()).isEqualTo(OllamaApi.Message.Role.TOOL); + assertThat(ollamaToolResponse1.content()).isEqualTo("Test tool response 1"); + + var ollamaToolResponse2 = request.messages().get(3); + assertThat(ollamaToolResponse2.role()).isEqualTo(OllamaApi.Message.Role.TOOL); + assertThat(ollamaToolResponse2.content()).isEqualTo("Test tool response 2"); + + var ollamaToolResponse3 = request.messages().get(4); + assertThat(ollamaToolResponse3.role()).isEqualTo(OllamaApi.Message.Role.TOOL); + assertThat(ollamaToolResponse3.content()).isEqualTo("Test tool response 3"); + + var ollamaAssistantMessage = request.messages().get(5); + assertThat(ollamaAssistantMessage.role()).isEqualTo(OllamaApi.Message.Role.ASSISTANT); + assertThat(ollamaAssistantMessage.content()).isEqualTo("Test assistant message"); + } + + private static List createMessagesWithAllMessageTypes() { + var systemMessage = new SystemMessage("Test system message"); + var userMessage = new UserMessage("Test user message"); + // @formatter:off + var toolResponseMessage = new ToolResponseMessage(List.of( + new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"), + new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"), + new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3")) + ); + // @formatter:on + var assistantMessage = new AssistantMessage("Test assistant message"); + + return List.of(systemMessage, userMessage, toolResponseMessage, assistantMessage); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition;