Skip to content

Commit f337ee0

Browse files
committed
Address PR review feedbacks and improve code readability
Signed-off-by: Nicolas Krier <[email protected]>
1 parent cf1ba6c commit f337ee0

File tree

2 files changed

+115
-111
lines changed

2 files changed

+115
-111
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.chat.messages.AssistantMessage;
3636
import org.springframework.ai.chat.messages.Message;
3737
import org.springframework.ai.chat.messages.ToolResponseMessage;
38+
import org.springframework.ai.chat.messages.UserMessage;
3839
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3940
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
4041
import org.springframework.ai.chat.metadata.DefaultUsage;
@@ -50,7 +51,6 @@
5051
import org.springframework.ai.chat.prompt.ChatOptions;
5152
import org.springframework.ai.chat.prompt.Prompt;
5253
import org.springframework.ai.content.Media;
53-
import org.springframework.ai.content.MediaContent;
5454
import org.springframework.ai.mistralai.api.MistralAiApi;
5555
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
5656
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion.Choice;
@@ -430,7 +430,7 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
430430
// @formatter:off
431431
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
432432
.stream()
433-
.flatMap(this::createMessages)
433+
.flatMap(this::createChatCompletionMessages)
434434
.toList();
435435
// @formatter:on
436436

@@ -450,60 +450,70 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
450450
return request;
451451
}
452452

453-
/**
454-
* Accessible for testing.
455-
*/
456-
Stream<ChatCompletionMessage> createMessages(Message message) {
457-
return switch (message.getMessageType()) {
458-
case USER -> {
459-
Object content = message.getText();
453+
private Stream<ChatCompletionMessage> createChatCompletionMessages(Message message) {
454+
switch (message.getMessageType()) {
455+
case USER:
456+
return Stream.of(createUserChatCompletionMessage(message));
457+
case SYSTEM:
458+
return Stream.of(createSystemChatCompletionMessage(message));
459+
case ASSISTANT:
460+
return Stream.of(createAssistantChatCompletionMessage(message));
461+
case TOOL:
462+
return createToolChatCompletionMessages(message);
463+
default:
464+
throw new IllegalStateException("Unknown message type: " + message.getMessageType());
465+
}
466+
}
460467

461-
if (message instanceof MediaContent mediaContent && !CollectionUtils.isEmpty(mediaContent.getMedia())) {
462-
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
463-
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
468+
private Stream<ChatCompletionMessage> createToolChatCompletionMessages(Message message) {
469+
if (message instanceof ToolResponseMessage toolResponseMessage) {
470+
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
464471

465-
contentList.addAll(mediaContent.getMedia().stream().map(this::mapToMediaContent).toList());
472+
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
473+
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage.ToolResponse must have an id.");
474+
var chatCompletionMessage = new ChatCompletionMessage(toolResponse.responseData(),
475+
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id());
476+
chatCompletionMessages.add(chatCompletionMessage);
477+
}
466478

467-
content = contentList;
468-
}
479+
return chatCompletionMessages.stream();
480+
}
481+
else {
482+
throw new IllegalArgumentException("Unsupported tool message class: " + message.getClass().getName());
483+
}
484+
}
485+
486+
private ChatCompletionMessage createAssistantChatCompletionMessage(Message message) {
487+
if (message instanceof AssistantMessage assistantMessage) {
488+
List<ToolCall> toolCalls = null;
469489

470-
yield Stream.of(new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER));
490+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
491+
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
471492
}
472-
case SYSTEM -> Stream.of(new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM));
473-
case ASSISTANT -> {
474-
if (message instanceof AssistantMessage assistantMessage) {
475-
List<ToolCall> toolCalls = null;
476493

477-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
478-
toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList();
479-
}
494+
return new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null,
495+
toolCalls, null);
496+
}
497+
else {
498+
throw new IllegalArgumentException("Unsupported assistant message class: " + message.getClass().getName());
499+
}
500+
}
480501

481-
yield Stream.of(new ChatCompletionMessage(assistantMessage.getText(),
482-
ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null));
483-
}
484-
else {
485-
throw new IllegalArgumentException(
486-
"Unexpected assistant message class: " + message.getClass().getName());
487-
}
488-
}
489-
case TOOL -> {
490-
if (message instanceof ToolResponseMessage toolResponseMessage) {
491-
var chatCompletionMessages = new ArrayList<ChatCompletionMessage>();
492-
493-
for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) {
494-
Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage must have an id");
495-
chatCompletionMessages.add(new ChatCompletionMessage(toolResponse.responseData(),
496-
ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id()));
497-
}
502+
private ChatCompletionMessage createSystemChatCompletionMessage(Message message) {
503+
return new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM);
504+
}
498505

499-
yield chatCompletionMessages.stream();
500-
}
501-
else {
502-
throw new IllegalArgumentException(
503-
"Unexpected tool message class: " + message.getClass().getName());
504-
}
505-
}
506-
};
506+
private ChatCompletionMessage createUserChatCompletionMessage(Message message) {
507+
Object content = message.getText();
508+
509+
if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
510+
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
511+
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
512+
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
513+
content = contentList;
514+
}
515+
516+
return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER);
507517
}
508518

509519
private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java

Lines changed: 57 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424

2525
import org.springframework.ai.chat.messages.AbstractMessage;
2626
import org.springframework.ai.chat.messages.AssistantMessage;
27+
import org.springframework.ai.chat.messages.Message;
2728
import org.springframework.ai.chat.messages.MessageType;
2829
import org.springframework.ai.chat.messages.SystemMessage;
2930
import org.springframework.ai.chat.messages.ToolResponseMessage;
3031
import org.springframework.ai.chat.messages.UserMessage;
3132
import org.springframework.ai.chat.prompt.Prompt;
3233
import org.springframework.ai.content.Media;
33-
import org.springframework.ai.content.MediaContent;
3434
import org.springframework.ai.mistralai.api.MistralAiApi;
3535
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
3636
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -124,51 +124,51 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
124124
}
125125

126126
@Test
127-
void createMessagesWithUserMessage() {
127+
void createChatCompletionMessagesWithUserMessage() {
128128
var userMessage = new UserMessage(TEXT_CONTENT);
129129
userMessage.getMedia().add(IMAGE_MEDIA);
130-
var chatCompletionMessages = this.chatModel.createMessages(userMessage).toList();
131-
verifyUserChatCompletionMessages(chatCompletionMessages);
130+
var prompt = createPrompt(userMessage);
131+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
132+
verifyUserChatCompletionMessages(chatCompletionRequest.messages());
132133
}
133134

134135
@Test
135-
void createMessagesWithAnotherUserMessage() {
136-
var anotherUserMessage = new AnotherUserMessage(TEXT_CONTENT, List.of(IMAGE_MEDIA));
137-
var chatCompletionMessages = this.chatModel.createMessages(anotherUserMessage).toList();
138-
verifyUserChatCompletionMessages(chatCompletionMessages);
139-
}
140-
141-
@Test
142-
void createMessagesWithSimpleUserMessage() {
136+
void createChatCompletionMessagesWithSimpleUserMessage() {
143137
var simpleUserMessage = new SimpleMessage(MessageType.USER, TEXT_CONTENT);
144-
var chatCompletionMessages = this.chatModel.createMessages(simpleUserMessage).toList();
138+
var prompt = createPrompt(simpleUserMessage);
139+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
140+
var chatCompletionMessages = chatCompletionRequest.messages();
145141
assertThat(chatCompletionMessages).hasSize(1);
146142
var chatCompletionMessage = chatCompletionMessages.get(0);
147143
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
148144
assertThat(chatCompletionMessage.content()).isEqualTo(TEXT_CONTENT);
149145
}
150146

151147
@Test
152-
void createMessagesWithSystemMessage() {
148+
void createChatCompletionMessagesWithSystemMessage() {
153149
var systemMessage = new SystemMessage(TEXT_CONTENT);
154-
var chatCompletionMessages = this.chatModel.createMessages(systemMessage).toList();
155-
verifySystemChatCompletionMessages(chatCompletionMessages);
150+
var prompt = createPrompt(systemMessage);
151+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
152+
verifySystemChatCompletionMessages(chatCompletionRequest.messages());
156153
}
157154

158155
@Test
159-
void createMessagesWithSimpleSystemMessage() {
156+
void createChatCompletionMessagesWithSimpleSystemMessage() {
160157
var simpleSystemMessage = new SimpleMessage(MessageType.SYSTEM, TEXT_CONTENT);
161-
var chatCompletionMessages = this.chatModel.createMessages(simpleSystemMessage).toList();
162-
verifySystemChatCompletionMessages(chatCompletionMessages);
158+
var prompt = createPrompt(simpleSystemMessage);
159+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
160+
verifySystemChatCompletionMessages(chatCompletionRequest.messages());
163161
}
164162

165163
@Test
166-
void createMessagesWithAssistantMessage() {
164+
void createChatCompletionMessagesWithAssistantMessage() {
167165
var toolCall1 = createToolCall(1);
168166
var toolCall2 = createToolCall(2);
169167
var toolCall3 = createToolCall(3);
170168
var assistantMessage = new AssistantMessage(TEXT_CONTENT, Map.of(), List.of(toolCall1, toolCall2, toolCall3));
171-
var chatCompletionMessages = this.chatModel.createMessages(assistantMessage).toList();
169+
var prompt = createPrompt(assistantMessage);
170+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
171+
var chatCompletionMessages = chatCompletionRequest.messages();
172172
assertThat(chatCompletionMessages).hasSize(1);
173173
var chatCompletionMessage = chatCompletionMessages.get(0);
174174
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.ASSISTANT);
@@ -181,41 +181,53 @@ void createMessagesWithAssistantMessage() {
181181
}
182182

183183
@Test
184-
void createMessagesWithSimpleAssistantMessage() {
184+
void createChatCompletionMessagesWithSimpleAssistantMessage() {
185185
var simpleAssistantMessage = new SimpleMessage(MessageType.ASSISTANT, TEXT_CONTENT);
186-
assertThatThrownBy(() -> this.chatModel.createMessages(simpleAssistantMessage))
186+
var prompt = createPrompt(simpleAssistantMessage);
187+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
187188
.isInstanceOf(IllegalArgumentException.class)
188-
.hasMessage("Unexpected assistant message class: " + SimpleMessage.class.getName());
189+
.hasMessage("Unsupported assistant message class: " + SimpleMessage.class.getName());
189190
}
190191

191192
@Test
192-
void createMessagesWithToolResponseMessage() {
193+
void createChatCompletionMessagesWithToolResponseMessage() {
193194
var toolResponse1 = createToolResponse(1);
194195
var toolResponse2 = createToolResponse(2);
195196
var toolResponse3 = createToolResponse(3);
196197
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3));
197-
var chatCompletionMessages = this.chatModel.createMessages(toolResponseMessage).toList();
198+
var prompt = createPrompt(toolResponseMessage);
199+
var chatCompletionRequest = this.chatModel.createRequest(prompt, false);
200+
var chatCompletionMessages = chatCompletionRequest.messages();
198201
assertThat(chatCompletionMessages).hasSize(3);
199202
verifyToolChatCompletionMessage(chatCompletionMessages.get(0), toolResponse1);
200203
verifyToolChatCompletionMessage(chatCompletionMessages.get(1), toolResponse2);
201204
verifyToolChatCompletionMessage(chatCompletionMessages.get(2), toolResponse3);
202205
}
203206

204207
@Test
205-
void createMessagesWithInvalidToolResponseMessage() {
208+
void createChatCompletionMessagesWithInvalidToolResponseMessage() {
206209
var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null);
207210
var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse));
208-
assertThatThrownBy(() -> this.chatModel.createMessages(toolResponseMessage))
211+
var prompt = createPrompt(toolResponseMessage);
212+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
209213
.isInstanceOf(IllegalArgumentException.class)
210-
.hasMessage("ToolResponseMessage must have an id");
214+
.hasMessage("ToolResponseMessage.ToolResponse must have an id.");
211215
}
212216

213217
@Test
214-
void createMessagesWithSimpleToolMessage() {
218+
void createChatCompletionMessagesWithSimpleToolMessage() {
215219
var simpleToolMessage = new SimpleMessage(MessageType.TOOL, TEXT_CONTENT);
216-
assertThatThrownBy(() -> this.chatModel.createMessages(simpleToolMessage))
220+
var prompt = createPrompt(simpleToolMessage);
221+
assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false))
217222
.isInstanceOf(IllegalArgumentException.class)
218-
.hasMessage("Unexpected tool message class: " + SimpleMessage.class.getName());
223+
.hasMessage("Unsupported tool message class: " + SimpleMessage.class.getName());
224+
}
225+
226+
private Prompt createPrompt(Message message) {
227+
var chatOptions = MistralAiChatOptions.builder().temperature(0.7d).build();
228+
var prompt = new Prompt(message, chatOptions);
229+
230+
return this.chatModel.buildRequestPrompt(prompt);
219231
}
220232

221233
private static void verifyToolChatCompletionMessage(ChatCompletionMessage chatCompletionMessage,
@@ -236,6 +248,7 @@ private static void verifyToolCall(ChatCompletionMessage.ToolCall mistralToolCal
236248
assertThat(mistralToolCall.id()).isEqualTo(toolCall.id());
237249
assertThat(mistralToolCall.type()).isEqualTo(toolCall.type());
238250
var function = mistralToolCall.function();
251+
assertThat(function).isNotNull();
239252
assertThat(function.name()).isEqualTo(toolCall.name());
240253
assertThat(function.arguments()).isEqualTo(toolCall.arguments());
241254
}
@@ -257,21 +270,18 @@ private static void verifyUserChatCompletionMessages(List<ChatCompletionMessage>
257270
assertThat(chatCompletionMessage.role()).isEqualTo(ChatCompletionMessage.Role.USER);
258271
var rawContent = chatCompletionMessage.rawContent();
259272
assertThat(rawContent).isNotNull();
260-
var mediaContents = (List<ChatCompletionMessage.MediaContent>) rawContent;
261-
assertThat(mediaContents).hasSize(2);
262-
var textMediaContent = mediaContents.get(0);
263-
assertThat(textMediaContent).isNotNull();
264-
assertThat(textMediaContent.type()).isEqualTo("text");
265-
assertThat(textMediaContent.text()).isEqualTo(TEXT_CONTENT);
266-
assertThat(textMediaContent.imageUrl()).isNull();
267-
var imageUrlMediaContent = mediaContents.get(1);
268-
assertThat(imageUrlMediaContent).isNotNull();
269-
assertThat(imageUrlMediaContent.type()).isEqualTo("image_url");
270-
assertThat(imageUrlMediaContent.text()).isNull();
271-
var imageUrl = imageUrlMediaContent.imageUrl();
272-
assertThat(imageUrl).isNotNull();
273-
assertThat(imageUrl.url()).isEqualTo(IMAGE_URL);
274-
assertThat(imageUrl.detail()).isNull();
273+
var maps = (List<Map<String, Object>>) rawContent;
274+
assertThat(maps).hasSize(2);
275+
// @formatter:off
276+
var textMap = maps.get(0);
277+
assertThat(textMap).hasSize(2)
278+
.containsEntry("type", "text")
279+
.containsEntry("text", TEXT_CONTENT);
280+
var imageUrlMap = maps.get(1);
281+
assertThat(imageUrlMap).hasSize(2)
282+
.containsEntry("type", "image_url")
283+
.containsEntry("image_url", Map.of("url", IMAGE_URL));
284+
// @formatter:on
275285
}
276286

277287
static class SimpleMessage extends AbstractMessage {
@@ -282,22 +292,6 @@ static class SimpleMessage extends AbstractMessage {
282292

283293
}
284294

285-
static class AnotherUserMessage extends AbstractMessage implements MediaContent {
286-
287-
private final List<Media> media;
288-
289-
AnotherUserMessage(String textContent, List<Media> media) {
290-
super(MessageType.USER, textContent, Map.of());
291-
this.media = List.copyOf(media);
292-
}
293-
294-
@Override
295-
public List<Media> getMedia() {
296-
return this.media;
297-
}
298-
299-
}
300-
301295
static class TestToolCallback implements ToolCallback {
302296

303297
private final ToolDefinition toolDefinition;

0 commit comments

Comments
 (0)