Skip to content

Commit a77d1e5

Browse files
committed
Address PR review feedbacks and improve code readability
Signed-off-by: Nicolas Krier <[email protected]>
1 parent e9d3c90 commit a77d1e5

File tree

2 files changed

+118
-89
lines changed

2 files changed

+118
-89
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
430430
// @formatter:off
431431
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
432432
.stream()
433-
.flatMap(this::createMessages)
433+
.flatMap(this::createChatCompletionMessages)
434434
.toList();
435435
// @formatter:on
436436

@@ -450,60 +450,70 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
450450
return request;
451451
}
452452

453-
/**
454-
* Accessible for testing.
455-
*/
456-
Stream<ChatCompletionMessage> createMessages(Message message) {
457-
return switch (message.getMessageType()) {
458-
case USER -> {
459-
Object content = message.getText();
453+
private Stream<ChatCompletionMessage> createChatCompletionMessages(Message message) {
454+
switch (message.getMessageType()) {
455+
case USER:
456+
return Stream.of(createUserChatCompletionMessage(message));
457+
case SYSTEM:
458+
return Stream.of(createSystemChatCompletionMessage(message));
459+
case ASSISTANT:
460+
return Stream.of(createAssistantChatCompletionMessage(message));
461+
case TOOL:
462+
return createToolChatCompletionMessages(message);
463+
default:
464+
throw new IllegalStateException("Unknown message type: " + message.getMessageType());
465+
}
466+
}
460467

461-
if (message instanceof MediaContent mediaContent && !CollectionUtils.isEmpty(mediaContent.getMedia())) {
462-
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
463-
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
468+
private Stream<ChatCompletionMessage> createToolChatCompletionMessages(Message message) {
469+
if (message instanceof ToolResponseMessage toolResponseMessage) {
470+
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
464471

465-
contentList.addAll(mediaContent.getMedia().stream().map(this::mapToMediaContent).toList());
472+
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
473+
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage.ToolResponse must have an id.");
474+
var chatCompletionMessage = new ChatCompletionMessage(toolResponse.responseData(),
475+
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id());
476+
chatCompletionMessages.add(chatCompletionMessage);
477+
}
466478

467-
content = contentList;
468-
}
479+
return chatCompletionMessages.stream();
480+
}
481+
else {
482+
throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName());
483+
}
484+
}
485+
486+
private ChatCompletionMessage createAssistantChatCompletionMessage(Message message) {
487+
if (message instanceof AssistantMessage assistantMessage) {
488+
List<ToolCall> toolCalls = null;
469489

470-
yield Stream.of(new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER));
490+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
491+
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
471492
}
472-
case SYSTEM -> Stream.of(new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM));
473-
case ASSISTANT -> {
474-
if (message instanceof AssistantMessage assistantMessage) {
475-
List<ToolCall> toolCalls = null;
476493

477-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
478-
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
479-
}
494+
return new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null,
495+
toolCalls, null);
496+
}
497+
else {
498+
throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName());
499+
}
500+
}
480501

481-
yield Stream.of(new ChatCompletionMessage(assistantMessage.getText(),
482-
ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
483-
}
484-
else {
485-
throw new IllegalArgumentException(
486-
"Unexpected assistant message class: " + message.getClass().getName());
487-
}
488-
}
489-
case TOOL -> {
490-
if (message instanceof ToolResponseMessage toolResponseMessage) {
491-
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
492-
493-
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
494-
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage must have an id");
495-
chatCompletionMessages.add(new ChatCompletionMessage(toolResponse.responseData(),
496-
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()));
497-
}
502+
private ChatCompletionMessage createSystemChatCompletionMessage(Message message) {
503+
return new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM);
504+
}
498505

499-
yield chatCompletionMessages.stream();
500-
}
501-
else {
502-
throw new IllegalArgumentException(
503-
"Unexpected tool message class: " + message.getClass().getName());
504-
}
505-
}
506-
};
506+
private ChatCompletionMessage createUserChatCompletionMessage(Message message) {
507+
Object content = message.getText();
508+
509+
if (message instanceof MediaContent mediaContent && !CollectionUtils.isEmpty(mediaContent.getMedia())) {
510+
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
511+
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
512+
contentList.addAll(mediaContent.getMedia().stream().map(this::mapToMediaContent).toList());
513+
content = contentList;
514+
}
515+
516+
return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER);
507517
}
508518

509519
private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelTests.java

Lines changed: 61 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import org.springframework.ai.chat.messages.AbstractMessage;
2626
import org.springframework.ai.chat.messages.AssistantMessage;
27+
import org.springframework.ai.chat.messages.Message;
2728
import org.springframework.ai.chat.messages.MessageType;
2829
import org.springframework.ai.chat.messages.SystemMessage;
2930
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -124,51 +125,59 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
124125
}
125126

126127
@Test
127-
void createMessagesWithUserMessage() {
128+
void createChatCompletionMessagesWithUserMessage() {
128129
var userMessage = new UserMessage(TEXT_CONTENT);
129130
userMessage.getMedia().add(IMAGE_MEDIA);
130-
var chatCompletionMessages = this.chatModel.createMessages(userMessage).toList();
131-
verifyUserChatCompletionMessages(chatCompletionMessages);
131+
var prompt = createPrompt(userMessage);
132+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
133+
verifyUserChatCompletionMessages(chatCompletionRequest.messages());
132134
}
133135

134136
@Test
135-
void createMessagesWithAnotherUserMessage() {
137+
void createChatCompletionMessagesWithAnotherUserMessage() {
136138
var anotherUserMessage = new AnotherUserMessage(TEXT_CONTENT, List.of(IMAGE_MEDIA));
137-
var chatCompletionMessages = this.chatModel.createMessages(anotherUserMessage).toList();
138-
verifyUserChatCompletionMessages(chatCompletionMessages);
139+
var prompt = createPrompt(anotherUserMessage);
140+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
141+
verifyUserChatCompletionMessages(chatCompletionRequest.messages());
139142
}
140143

141144
@Test
142-
void createMessagesWithSimpleUserMessage() {
145+
void createChatCompletionMessagesWithSimpleUserMessage() {
143146
var simpleUserMessage = new SimpleMessage(MessageType.USER, TEXT_CONTENT);
144-
var chatCompletionMessages = this.chatModel.createMessages(simpleUserMessage).toList();
147+
var prompt = createPrompt(simpleUserMessage);
148+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
149+
var chatCompletionMessages = chatCompletionRequest.messages();
145150
assertThat(chatCompletionMessages).hasSize(1);
146151
var chatCompletionMessage = chatCompletionMessages.get(0);
147152
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
148153
assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT);
149154
}
150155

151156
@Test
152-
void createMessagesWithSystemMessage() {
157+
void createChatCompletionMessagesWithSystemMessage() {
153158
var systemMessage = new SystemMessage(TEXT_CONTENT);
154-
var chatCompletionMessages = this.chatModel.createMessages(systemMessage).toList();
155-
verifySystemChatCompletionMessages(chatCompletionMessages);
159+
var prompt = createPrompt(systemMessage);
160+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
161+
verifySystemChatCompletionMessages(chatCompletionRequest.messages());
156162
}
157163

158164
@Test
159-
void createMessagesWithSimpleSystemMessage() {
165+
void createChatCompletionMessagesWithSimpleSystemMessage() {
160166
var simpleSystemMessage = new SimpleMessage(MessageType.SYSTEM, TEXT_CONTENT);
161-
var chatCompletionMessages = this.chatModel.createMessages(simpleSystemMessage).toList();
162-
verifySystemChatCompletionMessages(chatCompletionMessages);
167+
var prompt = createPrompt(simpleSystemMessage);
168+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
169+
verifySystemChatCompletionMessages(chatCompletionRequest.messages());
163170
}
164171

165172
@Test
166-
void createMessagesWithAssistantMessage() {
173+
void createChatCompletionMessagesWithAssistantMessage() {
167174
var toolCall1 = createToolCall(1);
168175
var toolCall2 = createToolCall(2);
169176
var toolCall3 = createToolCall(3);
170177
var assistantMessage = new AssistantMessage(TEXT_CONTENT, Map.of(), List.of(toolCall1, toolCall2, toolCall3));
171-
var chatCompletionMessages = this.chatModel.createMessages(assistantMessage).toList();
178+
var prompt = createPrompt(assistantMessage);
179+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
180+
var chatCompletionMessages = chatCompletionRequest.messages();
172181
assertThat(chatCompletionMessages).hasSize(1);
173182
var chatCompletionMessage = chatCompletionMessages.get(0);
174183
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.ASSISTANT);
@@ -181,41 +190,55 @@ void createMessagesWithAssistantMessage() {
181190
}
182191

183192
@Test
184-
void createMessagesWithSimpleAssistantMessage() {
193+
void createChatCompletionMessagesWithSimpleAssistantMessage() {
185194
var simpleAssistantMessage = new SimpleMessage(MessageType.ASSISTANT, TEXT_CONTENT);
186-
assertThatThrownBy(() -> this.chatModel.createMessages(simpleAssistantMessage))
195+
var prompt = createPrompt(simpleAssistantMessage);
196+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
187197
.isInstanceOf(IllegalArgumentException.class)
188-
.hasMessage("Unexpected assistant message class: " + SimpleMessage.class.getName());
198+
.hasMessage("Unsupported assistant message class: " + SimpleMessage.class.getName());
189199
}
190200

191201
@Test
192-
void createMessagesWithToolResponseMessage() {
202+
void createChatCompletionMessagesWithToolResponseMessage() {
193203
var toolResponse1 = createToolResponse(1);
194204
var toolResponse2 = createToolResponse(2);
195205
var toolResponse3 = createToolResponse(3);
196206
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3));
197-
var chatCompletionMessages = this.chatModel.createMessages(toolResponseMessage).toList();
207+
var prompt = createPrompt(toolResponseMessage);
208+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
209+
var chatCompletionMessages = chatCompletionRequest.messages();
198210
assertThat(chatCompletionMessages).hasSize(3);
199211
verifyToolChatCompletionMessage(chatCompletionMessages.get(0), toolResponse1);
200212
verifyToolChatCompletionMessage(chatCompletionMessages.get(1), toolResponse2);
201213
verifyToolChatCompletionMessage(chatCompletionMessages.get(2), toolResponse3);
202214
}
203215

204216
@Test
205-
void createMessagesWithInvalidToolResponseMessage() {
217+
void createChatCompletionMessagesWithInvalidToolResponseMessage() {
206218
var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null);
207219
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse));
208-
assertThatThrownBy(() -> this.chatModel.createMessages(toolResponseMessage))
220+
var prompt = createPrompt(toolResponseMessage);
221+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
209222
.isInstanceOf(IllegalArgumentException.class)
210-
.hasMessage("ToolResponseMessage must have an id");
223+
.hasMessage("ToolResponseMessage.ToolResponse must have an id.");
211224
}
212225

213226
@Test
214-
void createMessagesWithSimpleToolMessage() {
227+
void createChatCompletionMessagesWithSimpleToolMessage() {
215228
var simpleToolMessage = new SimpleMessage(MessageType.TOOL, TEXT_CONTENT);
216-
assertThatThrownBy(() -> this.chatModel.createMessages(simpleToolMessage))
229+
var prompt = createPrompt(simpleToolMessage);
230+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
217231
.isInstanceOf(IllegalArgumentException.class)
218-
.hasMessage("Unexpected tool message class: " + SimpleMessage.class.getName());
232+
.hasMessage("Unsupported tool message class: " + SimpleMessage.class.getName());
233+
}
234+
235+
private Prompt createPrompt(Message message) {
236+
var chatOptions = MistralAiChatOptions.builder()
237+
.temperature(0.7d)
238+
.build();
239+
var prompt = new Prompt(message, chatOptions);
240+
241+
return this.chatModel.buildRequestPrompt(prompt);
219242
}
220243

221244
private static void verifyToolChatCompletionMessage(ChatCompletionMessage chatCompletionMessage,
@@ -236,6 +259,7 @@ private static void verifyToolCall(ChatCompletionMessage.ToolCall mistralToolCal
236259
assertThat(mistralToolCall.id()).isEqualTo(toolCall.id());
237260
assertThat(mistralToolCall.type()).isEqualTo(toolCall.type());
238261
var function = mistralToolCall.function();
262+
assertThat(function).isNotNull();
239263
assertThat(function.name()).isEqualTo(toolCall.name());
240264
assertThat(function.arguments()).isEqualTo(toolCall.arguments());
241265
}
@@ -257,21 +281,16 @@ private static void verifyUserChatCompletionMessages(List<ChatCompletionMessage>
257281
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
258282
var rawContent = chatCompletionMessage.rawContent();
259283
assertThat(rawContent).isNotNull();
260-
var mediaContents = (List<ChatCompletionMessage.MediaContent>) rawContent;
261-
assertThat(mediaContents).hasSize(2);
262-
var textMediaContent = mediaContents.get(0);
263-
assertThat(textMediaContent).isNotNull();
264-
assertThat(textMediaContent.type()).isEqualTo("text");
265-
assertThat(textMediaContent.text()).isEqualTo(TEXT_CONTENT);
266-
assertThat(textMediaContent.imageUrl()).isNull();
267-
var imageUrlMediaContent = mediaContents.get(1);
268-
assertThat(imageUrlMediaContent).isNotNull();
269-
assertThat(imageUrlMediaContent.type()).isEqualTo("image_url");
270-
assertThat(imageUrlMediaContent.text()).isNull();
271-
var imageUrl = imageUrlMediaContent.imageUrl();
272-
assertThat(imageUrl).isNotNull();
273-
assertThat(imageUrl.url()).isEqualTo(IMAGE_URL);
274-
assertThat(imageUrl.detail()).isNull();
284+
var maps = (List<Map<String, Object>>) rawContent;
285+
assertThat(maps).hasSize(2);
286+
var textMap = maps.get(0);
287+
assertThat(textMap).hasSize(2)
288+
.containsEntry("type", "text")
289+
.containsEntry("text", TEXT_CONTENT);
290+
var imageUrlMap = maps.get(1);
291+
assertThat(imageUrlMap).hasSize(2)
292+
.containsEntry("type", "image_url")
293+
.containsEntry("image_url", Map.of("url", IMAGE_URL));
275294
}
276295

277296
static class SimpleMessage extends AbstractMessage {

0 commit comments

Comments
 (0)