Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
Expand All @@ -32,7 +33,7 @@
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
Expand Down Expand Up @@ -84,6 +85,7 @@
* @author luocongqiu
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @author Nicolas Krier
* @since 1.0.0
*/
public class MistralAiChatModel implements ChatModel {
Expand Down Expand Up @@ -425,52 +427,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
* Accessible for testing.
*/
MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
if (message instanceof UserMessage userMessage) {
Object content = message.getText();

if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
List.of(new ChatCompletionMessage.MediaContent(message.getText())));

contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());

content = contentList;
}

return List
.of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER));
}
else if (message instanceof SystemMessage systemMessage) {
return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
MistralAiApi.ChatCompletionMessage.Role.SYSTEM));
}
else if (message instanceof AssistantMessage assistantMessage) {
List<ToolCall> toolCalls = null;
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
}).toList();
}

return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(),
MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
}
else if (message instanceof ToolResponseMessage toolResponseMessage) {
toolResponseMessage.getResponses()
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));

return toolResponseMessage.getResponses()
.stream()
.map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(),
MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()))
.toList();
}
else {
throw new IllegalStateException("Unexpected message type: " + message);
}
}).flatMap(List::stream).toList();
// @formatter:off
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
.stream()
.flatMap(this::createChatCompletionMessages)
.toList();
// @formatter:on

var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream);

Expand All @@ -488,6 +450,78 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
return request;
}

private Stream<ChatCompletionMessage> createChatCompletionMessages(Message message) {
switch (message.getMessageType()) {
case USER:
return Stream.of(createUserChatCompletionMessage(message));
case SYSTEM:
return Stream.of(createSystemChatCompletionMessage(message));
case ASSISTANT:
return Stream.of(createAssistantChatCompletionMessage(message));
case TOOL:
return createToolChatCompletionMessages(message);
default:
throw new IllegalStateException("Unknown message type: " + message.getMessageType());
}
}

private Stream<ChatCompletionMessage> createToolChatCompletionMessages(Message message) {
if (message instanceof ToolResponseMessage toolResponseMessage) {
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();

for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage.ToolResponse must have an id.");
var chatCompletionMessage = new ChatCompletionMessage(toolResponse.responseData(),
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id());
chatCompletionMessages.add(chatCompletionMessage);
}

return chatCompletionMessages.stream();
}
else {
throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName());
}
}

private ChatCompletionMessage createAssistantChatCompletionMessage(Message message) {
if (message instanceof AssistantMessage assistantMessage) {
List<ToolCall> toolCalls = null;

if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
}

return new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null,
toolCalls, null);
}
else {
throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName());
}
}

private ChatCompletionMessage createSystemChatCompletionMessage(Message message) {
return new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM);
}

private ChatCompletionMessage createUserChatCompletionMessage(Message message) {
Object content = message.getText();

if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
content = contentList;
}

return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER);
}

private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());

return new ToolCall(toolCall.id(), toolCall.type(), function, null);
}

private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
this.fromMediaData(media.getMimeType(), media.getData())));
Expand Down

This file was deleted.

Loading