diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index e05130ffb7f..526e1a2d317 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; @@ -32,7 +33,7 @@ import reactor.core.scheduler.Schedulers; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; @@ -84,6 +85,7 @@ * @author luocongqiu * @author Ilayaperumal Gopinathan * @author Alexandros Pappas + * @author Nicolas Krier * @since 1.0.0 */ public class MistralAiChatModel implements ChatModel { @@ -425,52 +427,12 @@ Prompt buildRequestPrompt(Prompt prompt) { * Accessible for testing. */ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { - if (message instanceof UserMessage userMessage) { - Object content = message.getText(); - - if (!CollectionUtils.isEmpty(userMessage.getMedia())) { - List contentList = new ArrayList<>( - List.of(new ChatCompletionMessage.MediaContent(message.getText()))); - - contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); - - content = contentList; - } - - return List - .of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER)); - } - else if (message instanceof SystemMessage systemMessage) { - return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(), - MistralAiApi.ChatCompletionMessage.Role.SYSTEM)); - } - else if (message instanceof AssistantMessage assistantMessage) { - List toolCalls = null; - if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { - toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { - var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); - return new ToolCall(toolCall.id(), toolCall.type(), function, null); - }).toList(); - } - - return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(), - MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null)); - } - else if (message instanceof ToolResponseMessage toolResponseMessage) { - toolResponseMessage.getResponses() - .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); - - return toolResponseMessage.getResponses() - .stream() - .map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(), - MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id())) - .toList(); - } - else { - throw new IllegalStateException("Unexpected message type: " + message); - } - }).flatMap(List::stream).toList(); + // @formatter:off + List chatCompletionMessages = prompt.getInstructions() + .stream() + .flatMap(this::createChatCompletionMessages) + .toList(); + // @formatter:on var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream); @@ -488,6 +450,78 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) { return request; } + private Stream createChatCompletionMessages(Message message) { + switch (message.getMessageType()) { + case USER: + return Stream.of(createUserChatCompletionMessage(message)); + case SYSTEM: + return Stream.of(createSystemChatCompletionMessage(message)); + case ASSISTANT: + return Stream.of(createAssistantChatCompletionMessage(message)); + case TOOL: + return createToolChatCompletionMessages(message); + default: + throw new IllegalStateException("Unknown message type: " + message.getMessageType()); + } + } + + private Stream createToolChatCompletionMessages(Message message) { + if (message instanceof ToolResponseMessage toolResponseMessage) { + var chatCompletionMessages = new ArrayList(); + + for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) { + Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage.ToolResponse must have an id."); + var chatCompletionMessage = new ChatCompletionMessage(toolResponse.responseData(), + ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()); + chatCompletionMessages.add(chatCompletionMessage); + } + + return chatCompletionMessages.stream(); + } + else { + throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName()); + } + } + + private ChatCompletionMessage createAssistantChatCompletionMessage(Message message) { + if (message instanceof AssistantMessage assistantMessage) { + List toolCalls = null; + + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList(); + } + + return new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null, + toolCalls, null); + } + else { + throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName()); + } + } + + private ChatCompletionMessage createSystemChatCompletionMessage(Message message) { + return new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM); + } + + private ChatCompletionMessage createUserChatCompletionMessage(Message message) { + Object content = message.getText(); + + if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) { + List contentList = new ArrayList<>( + List.of(new ChatCompletionMessage.MediaContent(message.getText()))); + contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); + content = contentList; + } + + return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER); + } + + private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) { + var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); + + return new ToolCall(toolCall.id(), toolCall.type(), function, null); + } + private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) { return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl( this.fromMediaData(media.getMimeType(), media.getData()))); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java deleted file mode 100644 index e6bf2490cc0..00000000000 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mistralai; - -import java.util.Map; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.definition.DefaultToolDefinition; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.boot.test.context.SpringBootTest; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Ricken Bazolo - * @author Alexandros Pappas - * @author Thomas Vitale - * @since 0.8.1 - */ -@SpringBootTest(classes = MistralAiTestConfiguration.class) -@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") -public class MistralAiChatCompletionRequestTest { - - MistralAiChatModel chatModel = MistralAiChatModel.builder().mistralAiApi(new MistralAiApi("test")).build(); - - @Test - void chatCompletionDefaultRequestTest() { - var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content")); - var request = this.chatModel.createRequest(prompt, false); - - assertThat(request.messages()).hasSize(1); - assertThat(request.topP()).isEqualTo(1); - assertThat(request.temperature()).isEqualTo(0.7); - assertThat(request.safePrompt()).isFalse(); - assertThat(request.maxTokens()).isNull(); - assertThat(request.stream()).isFalse(); - } - - @Test - void chatCompletionRequestWithOptionsTest() { - var options = MistralAiChatOptions.builder().temperature(0.5).topP(0.8).build(); - var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content", options)); - var request = this.chatModel.createRequest(prompt, true); - - assertThat(request.messages().size()).isEqualTo(1); - assertThat(request.topP()).isEqualTo(0.8); - assertThat(request.temperature()).isEqualTo(0.5); - assertThat(request.stream()).isTrue(); - } - - @Test - void whenToolRuntimeOptionsThenMergeWithDefaults() { - MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() - .model("DEFAULT_MODEL") - .internalToolExecutionEnabled(true) - .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) - .toolNames("tool1", "tool2") - .toolContext(Map.of("key1", "value1", "key2", "valueA")) - .build(); - - MistralAiChatModel chatModel = MistralAiChatModel.builder() - .mistralAiApi(new MistralAiApi("test")) - .defaultOptions(defaultOptions) - .build(); - - MistralAiChatOptions runtimeOptions = MistralAiChatOptions.builder() - .internalToolExecutionEnabled(false) - .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) - .toolNames("tool3") - .toolContext(Map.of("key2", "valueB")) - .build(); - Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); - - assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() - .stream() - .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); - assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") - .containsEntry("key2", "valueB"); - } - - static class TestToolCallback implements ToolCallback { - - private final ToolDefinition toolDefinition; - - TestToolCallback(String name) { - this.toolDefinition = DefaultToolDefinition.builder().name(name).inputSchema("{}").build(); - } - - @Override - public ToolDefinition getToolDefinition() { - return this.toolDefinition; - } - - @Override - public String call(String toolInput) { - return "Mission accomplished!"; - } - - } - -} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java new file mode 100644 index 00000000000..163cf2a7119 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java @@ -0,0 +1,320 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AbstractMessage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Ricken Bazolo + * @author Alexandros Pappas + * @author Thomas Vitale + * @author Nicolas Krier + * @since 0.8.1 + */ +class MistralAiChatCompletionRequestTests { + + private static final String BASE_URL = "https://faked.url"; + + private static final String API_KEY = "FAKED_API_KEY"; + + private static final String TEXT_CONTENT = "Hello world!"; + + private static final String IMAGE_URL = "https://example.com/image.png"; + + private static final Media IMAGE_MEDIA = new Media(Media.Format.IMAGE_PNG, URI.create(IMAGE_URL)); + + private final MistralAiChatModel chatModel = MistralAiChatModel.builder() + .mistralAiApi(new MistralAiApi(BASE_URL, API_KEY)) + .build(); + + @Test + void chatCompletionDefaultRequestTest() { + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content")); + var request = this.chatModel.createRequest(prompt, false); + + assertThat(request.messages()).hasSize(1); + assertThat(request.topP()).isEqualTo(1); + assertThat(request.temperature()).isEqualTo(0.7); + assertThat(request.safePrompt()).isFalse(); + assertThat(request.maxTokens()).isNull(); + assertThat(request.stream()).isFalse(); + } + + @Test + void chatCompletionRequestWithOptionsTest() { + var options = MistralAiChatOptions.builder().temperature(0.5).topP(0.8).build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("test content", options)); + var request = this.chatModel.createRequest(prompt, true); + + assertThat(request.messages()).hasSize(1); + assertThat(request.topP()).isEqualTo(0.8); + assertThat(request.temperature()).isEqualTo(0.5); + assertThat(request.stream()).isTrue(); + } + + @Test + void whenToolRuntimeOptionsThenMergeWithDefaults() { + MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() + .model("DEFAULT_MODEL") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1", "key2", "valueA")) + .build(); + + MistralAiChatModel anotherChatModel = MistralAiChatModel.builder() + .mistralAiApi(new MistralAiApi(BASE_URL, API_KEY)) + .defaultOptions(defaultOptions) + .build(); + + MistralAiChatOptions runtimeOptions = MistralAiChatOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "valueB")) + .build(); + Prompt prompt = anotherChatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() + .stream() + .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "valueB"); + } + + @Test + void createChatCompletionMessagesWithUserMessage() { + var userMessage = new UserMessage(TEXT_CONTENT); + userMessage.getMedia().add(IMAGE_MEDIA); + var prompt = createPrompt(userMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifyUserChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithSimpleUserMessage() { + var simpleUserMessage = new SimpleMessage(MessageType.USER, TEXT_CONTENT); + var prompt = createPrompt(simpleUserMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + } + + @Test + void createChatCompletionMessagesWithSystemMessage() { + var systemMessage = new SystemMessage(TEXT_CONTENT); + var prompt = createPrompt(systemMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifySystemChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithSimpleSystemMessage() { + var simpleSystemMessage = new SimpleMessage(MessageType.SYSTEM, TEXT_CONTENT); + var prompt = createPrompt(simpleSystemMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + verifySystemChatCompletionMessages(chatCompletionRequest.messages()); + } + + @Test + void createChatCompletionMessagesWithAssistantMessage() { + var toolCall1 = createToolCall(1); + var toolCall2 = createToolCall(2); + var toolCall3 = createToolCall(3); + // @formatter:off + var assistantMessage = AssistantMessage.builder() + .content(TEXT_CONTENT) + .toolCalls(List.of(toolCall1, toolCall2, toolCall3)) + .build(); + // @formatter:on + var prompt = createPrompt(assistantMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.ASSISTANT); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + var toolCalls = chatCompletionMessage.toolCalls(); + assertThat(toolCalls).hasSize(3); + verifyToolCall(toolCalls.get(0), toolCall1); + verifyToolCall(toolCalls.get(1), toolCall2); + verifyToolCall(toolCalls.get(2), toolCall3); + } + + @Test + void createChatCompletionMessagesWithSimpleAssistantMessage() { + var simpleAssistantMessage = new SimpleMessage(MessageType.ASSISTANT, TEXT_CONTENT); + var prompt = createPrompt(simpleAssistantMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unsupported assistant message class: " + SimpleMessage.class.getName()); + } + + @Test + void createChatCompletionMessagesWithToolResponseMessage() { + var toolResponse1 = createToolResponse(1); + var toolResponse2 = createToolResponse(2); + var toolResponse3 = createToolResponse(3); + var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3)); + var prompt = createPrompt(toolResponseMessage); + var chatCompletionRequest = this.chatModel.createRequest(prompt, false); + var chatCompletionMessages = chatCompletionRequest.messages(); + assertThat(chatCompletionMessages).hasSize(3); + verifyToolChatCompletionMessage(chatCompletionMessages.get(0), toolResponse1); + verifyToolChatCompletionMessage(chatCompletionMessages.get(1), toolResponse2); + verifyToolChatCompletionMessage(chatCompletionMessages.get(2), toolResponse3); + } + + @Test + void createChatCompletionMessagesWithInvalidToolResponseMessage() { + var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null); + var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse)); + var prompt = createPrompt(toolResponseMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("ToolResponseMessage.ToolResponse must have an id."); + } + + @Test + void createChatCompletionMessagesWithSimpleToolMessage() { + var simpleToolMessage = new SimpleMessage(MessageType.TOOL, TEXT_CONTENT); + var prompt = createPrompt(simpleToolMessage); + assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Unsupported tool message class: " + SimpleMessage.class.getName()); + } + + private Prompt createPrompt(Message message) { + var chatOptions = MistralAiChatOptions.builder().temperature(0.7d).build(); + var prompt = new Prompt(message, chatOptions); + + return this.chatModel.buildRequestPrompt(prompt); + } + + private static void verifyToolChatCompletionMessage(ChatCompletionMessage chatCompletionMessage, + ToolResponseMessage.ToolResponse toolResponse) { + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.TOOL); + assertThat(chatCompletionMessage.content()).isEqualTo(toolResponse.responseData()); + assertThat(chatCompletionMessage.name()).isEqualTo(toolResponse.name()); + assertThat(chatCompletionMessage.toolCalls()).isNull(); + assertThat(chatCompletionMessage.toolCallId()).isEqualTo(toolResponse.id()); + } + + private static ToolResponseMessage.ToolResponse createToolResponse(int number) { + return new ToolResponseMessage.ToolResponse("id" + number, "name" + number, "responseData" + number); + } + + private static void verifyToolCall(ChatCompletionMessage.ToolCall mistralToolCall, + AssistantMessage.ToolCall toolCall) { + assertThat(mistralToolCall.id()).isEqualTo(toolCall.id()); + assertThat(mistralToolCall.type()).isEqualTo(toolCall.type()); + var function = mistralToolCall.function(); + assertThat(function).isNotNull(); + assertThat(function.name()).isEqualTo(toolCall.name()); + assertThat(function.arguments()).isEqualTo(toolCall.arguments()); + } + + private static AssistantMessage.ToolCall createToolCall(int number) { + return new AssistantMessage.ToolCall("id" + number, "type" + number, "name" + number, "arguments " + number); + } + + private static void verifySystemChatCompletionMessages(List chatCompletionMessages) { + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.SYSTEM); + assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT); + } + + private static void verifyUserChatCompletionMessages(List chatCompletionMessages) { + assertThat(chatCompletionMessages).hasSize(1); + var chatCompletionMessage = chatCompletionMessages.get(0); + assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER); + var rawContent = chatCompletionMessage.rawContent(); + assertThat(rawContent).isNotNull(); + var maps = (List>) rawContent; + assertThat(maps).hasSize(2); + // @formatter:off + var textMap = maps.get(0); + assertThat(textMap).hasSize(2) + .containsEntry("type", "text") + .containsEntry("text", TEXT_CONTENT); + var imageUrlMap = maps.get(1); + assertThat(imageUrlMap).hasSize(2) + .containsEntry("type", "image_url") + .containsEntry("image_url", Map.of("url", IMAGE_URL)); + // @formatter:on + } + + static class SimpleMessage extends AbstractMessage { + + SimpleMessage(MessageType messageType, String textContent) { + super(messageType, textContent, Map.of()); + } + + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + TestToolCallback(String name) { + this.toolDefinition = DefaultToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return this.toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + +}