Skip to content

Commit e1a951c

Browse files
committed
Refactor request creation of Mistral AI chat model
Signed-off-by: Nicolas Krier <[email protected]>
1 parent e56eb12 commit e1a951c

File tree

2 files changed

+277
-57
lines changed

2 files changed

+277
-57
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.stream.Stream;
2425

2526
import io.micrometer.observation.Observation;
2627
import io.micrometer.observation.ObservationRegistry;
@@ -32,9 +33,8 @@
3233
import reactor.core.scheduler.Schedulers;
3334

3435
import org.springframework.ai.chat.messages.AssistantMessage;
35-
import org.springframework.ai.chat.messages.SystemMessage;
36+
import org.springframework.ai.chat.messages.Message;
3637
import org.springframework.ai.chat.messages.ToolResponseMessage;
37-
import org.springframework.ai.chat.messages.UserMessage;
3838
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3939
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
4040
import org.springframework.ai.chat.metadata.DefaultUsage;
@@ -50,6 +50,7 @@
5050
import org.springframework.ai.chat.prompt.ChatOptions;
5151
import org.springframework.ai.chat.prompt.Prompt;
5252
import org.springframework.ai.content.Media;
53+
import org.springframework.ai.content.MediaContent;
5354
import org.springframework.ai.mistralai.api.MistralAiApi;
5455
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
5556
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion.Choice;
@@ -84,6 +85,7 @@
8485
* @author luocongqiu
8586
* @author Ilayaperumal Gopinathan
8687
* @author Alexandros Pappas
88+
* @author Nicolas Krier
8789
* @since 1.0.0
8890
*/
8991
public class MistralAiChatModel implements ChatModel {
@@ -425,67 +427,89 @@ Prompt buildRequestPrompt(Prompt prompt) {
425427
* Accessible for testing.
426428
*/
427429
MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
428-
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
429-
if (message instanceof UserMessage userMessage) {
430+
// @formatter:off
431+
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
432+
.stream()
433+
.flatMap(this::createMessages)
434+
.toList();
435+
// @formatter:on
436+
437+
var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
438+
439+
MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions();
440+
request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class);
441+
442+
// Add the tool definitions to the request's tools parameter.
443+
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
444+
if (!CollectionUtils.isEmpty(toolDefinitions)) {
445+
request = ModelOptionsUtils.merge(
446+
MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
447+
ChatCompletionRequest.class);
448+
}
449+
450+
return request;
451+
}
452+
453+
/**
454+
* Accessible for testing.
455+
*/
456+
Stream<ChatCompletionMessage> createMessages(Message message) {
457+
return switch (message.getMessageType()) {
458+
case USER -> {
430459
Object content = message.getText();
431460

432-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
461+
if (message instanceof MediaContent mediaContent && !CollectionUtils.isEmpty(mediaContent.getMedia())) {
433462
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
434463
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
435464

436-
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
465+
contentList.addAll(mediaContent.getMedia().stream().map(this::mapToMediaContent).toList());
437466

438467
content = contentList;
439468
}
440469

441-
return List
442-
.of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER));
443-
}
444-
else if (message instanceof SystemMessage systemMessage) {
445-
return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
446-
MistralAiApi.ChatCompletionMessage.Role.SYSTEM));
470+
yield Stream.of(new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER));
447471
}
448-
else if (message instanceof AssistantMessage assistantMessage) {
449-
List<ToolCall> toolCalls = null;
450-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
451-
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
452-
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
453-
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
454-
}).toList();
455-
}
472+
case SYSTEM -> Stream.of(new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM));
473+
case ASSISTANT -> {
474+
if (message instanceof AssistantMessage assistantMessage) {
475+
List<ToolCall> toolCalls = null;
456476

457-
return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(),
458-
MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
459-
}
460-
else if (message instanceof ToolResponseMessage toolResponseMessage) {
461-
toolResponseMessage.getResponses()
462-
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
477+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
478+
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
479+
}
463480

464-
return toolResponseMessage.getResponses()
465-
.stream()
466-
.map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(),
467-
MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()))
468-
.toList();
469-
}
470-
else {
471-
throw new IllegalStateException("Unexpected message type: " + message);
481+
yield Stream.of(new ChatCompletionMessage(assistantMessage.getText(),
482+
ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
483+
}
484+
else {
485+
throw new IllegalArgumentException(
486+
"Unexpected assistant message class: " + message.getClass().getName());
487+
}
472488
}
473-
}).flatMap(List::stream).toList();
474-
475-
var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
489+
case TOOL -> {
490+
if (message instanceof ToolResponseMessage toolResponseMessage) {
491+
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
492+
493+
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
494+
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage must have an id");
495+
chatCompletionMessages.add(new ChatCompletionMessage(toolResponse.responseData(),
496+
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()));
497+
}
476498

477-
MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions();
478-
request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class);
499+
yield chatCompletionMessages.stream();
500+
}
501+
else {
502+
throw new IllegalArgumentException(
503+
"Unexpected tool message class: " + message.getClass().getName());
504+
}
505+
}
506+
};
507+
}
479508

480-
// Add the tool definitions to the request's tools parameter.
481-
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
482-
if (!CollectionUtils.isEmpty(toolDefinitions)) {
483-
request = ModelOptionsUtils.merge(
484-
MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
485-
ChatCompletionRequest.class);
486-
}
509+
private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {
510+
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
487511

488-
return request;
512+
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
489513
}
490514

491515
private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {

0 commit comments

Comments
 (0)