|
16 | 16 |
|
17 | 17 | package org.springframework.ai.ollama; |
18 | 18 |
|
19 | | -import java.util.Map; |
20 | | - |
21 | 19 | import org.junit.jupiter.api.Test; |
22 | | - |
| 20 | +import org.springframework.ai.chat.messages.*; |
23 | 21 | import org.springframework.ai.chat.prompt.ChatOptions; |
24 | 22 | import org.springframework.ai.chat.prompt.Prompt; |
25 | 23 | import org.springframework.ai.model.tool.ToolCallingChatOptions; |
|
30 | 28 | import org.springframework.ai.tool.definition.DefaultToolDefinition; |
31 | 29 | import org.springframework.ai.tool.definition.ToolDefinition; |
32 | 30 |
|
| 31 | +import java.util.List; |
| 32 | +import java.util.Map; |
| 33 | + |
33 | 34 | import static org.assertj.core.api.Assertions.assertThat; |
34 | 35 |
|
35 | 36 | /** |
36 | 37 | * @author Christian Tzolov |
37 | 38 | * @author Thomas Vitale |
38 | 39 | * @author Alexandros Pappas |
| 40 | + * @author Nicolas Krier |
39 | 41 | */ |
40 | 42 | class OllamaChatRequestTests { |
41 | 43 |
|
42 | | - OllamaChatModel chatModel = OllamaChatModel.builder() |
| 44 | + private final OllamaChatModel chatModel = OllamaChatModel.builder() |
43 | 45 | .ollamaApi(OllamaApi.builder().build()) |
44 | 46 | .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) |
45 | 47 | .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) |
@@ -167,6 +169,54 @@ public void createRequestWithDefaultOptionsModelOverride() { |
167 | 169 | assertThat(request.model()).isEqualTo("PROMPT_MODEL"); |
168 | 170 | } |
169 | 171 |
|
| 172 | + @Test |
| 173 | + void createRequestWithAllMessageTypes() { |
| 174 | + var prompt = this.chatModel.buildRequestPrompt(new Prompt(createMessagesWithAllMessageTypes())); |
| 175 | + |
| 176 | + var request = this.chatModel.ollamaChatRequest(prompt, false); |
| 177 | + |
| 178 | + assertThat(request.messages()).hasSize(6); |
| 179 | + |
| 180 | + var ollamaSystemMessage = request.messages().get(0); |
| 181 | + assertThat(ollamaSystemMessage.role()).isEqualTo(OllamaApi.Message.Role.SYSTEM); |
| 182 | + assertThat(ollamaSystemMessage.content()).isEqualTo("Test system message"); |
| 183 | + |
| 184 | + var ollamaUserMessage = request.messages().get(1); |
| 185 | + assertThat(ollamaUserMessage.role()).isEqualTo(OllamaApi.Message.Role.USER); |
| 186 | + assertThat(ollamaUserMessage.content()).isEqualTo("Test user message"); |
| 187 | + |
| 188 | + var ollamaToolResponse1 = request.messages().get(2); |
| 189 | + assertThat(ollamaToolResponse1.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 190 | + assertThat(ollamaToolResponse1.content()).isEqualTo("Test tool response 1"); |
| 191 | + |
| 192 | + var ollamaToolResponse2 = request.messages().get(3); |
| 193 | + assertThat(ollamaToolResponse2.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 194 | + assertThat(ollamaToolResponse2.content()).isEqualTo("Test tool response 2"); |
| 195 | + |
| 196 | + var ollamaToolResponse3 = request.messages().get(4); |
| 197 | + assertThat(ollamaToolResponse3.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 198 | + assertThat(ollamaToolResponse3.content()).isEqualTo("Test tool response 3"); |
| 199 | + |
| 200 | + var ollamaAssistantMessage = request.messages().get(5); |
| 201 | + assertThat(ollamaAssistantMessage.role()).isEqualTo(OllamaApi.Message.Role.ASSISTANT); |
| 202 | + assertThat(ollamaAssistantMessage.content()).isEqualTo("Test assistant message"); |
| 203 | + } |
| 204 | + |
| 205 | + private static List<Message> createMessagesWithAllMessageTypes() { |
| 206 | + var systemMessage = new SystemMessage("Test system message"); |
| 207 | + var userMessage = new UserMessage("Test user message"); |
| 208 | + // @formatter:off |
| 209 | + var toolResponseMessage = new ToolResponseMessage(List.of( |
| 210 | + new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"), |
| 211 | + new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"), |
| 212 | + new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3")) |
| 213 | + ); |
| 214 | + // @formatter:on |
| 215 | + var assistantMessage = new AssistantMessage("Test assistant message"); |
| 216 | + |
| 217 | + return List.of(systemMessage, userMessage, toolResponseMessage, assistantMessage); |
| 218 | + } |
| 219 | + |
170 | 220 | static class TestToolCallback implements ToolCallback { |
171 | 221 |
|
172 | 222 | private final ToolDefinition toolDefinition; |
|
0 commit comments