|
21 | 21 | import java.util.List; |
22 | 22 | import java.util.Map; |
23 | 23 | import java.util.concurrent.ConcurrentHashMap; |
| 24 | +import java.util.stream.Stream; |
24 | 25 |
|
25 | 26 | import io.micrometer.observation.Observation; |
26 | 27 | import io.micrometer.observation.ObservationRegistry; |
|
32 | 33 | import reactor.core.scheduler.Schedulers; |
33 | 34 |
|
34 | 35 | import org.springframework.ai.chat.messages.AssistantMessage; |
35 | | -import org.springframework.ai.chat.messages.SystemMessage; |
| 36 | +import org.springframework.ai.chat.messages.Message; |
36 | 37 | import org.springframework.ai.chat.messages.ToolResponseMessage; |
37 | | -import org.springframework.ai.chat.messages.UserMessage; |
38 | 38 | import org.springframework.ai.chat.metadata.ChatGenerationMetadata; |
39 | 39 | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
40 | 40 | import org.springframework.ai.chat.metadata.DefaultUsage; |
|
50 | 50 | import org.springframework.ai.chat.prompt.ChatOptions; |
51 | 51 | import org.springframework.ai.chat.prompt.Prompt; |
52 | 52 | import org.springframework.ai.content.Media; |
| 53 | +import org.springframework.ai.content.MediaContent; |
53 | 54 | import org.springframework.ai.mistralai.api.MistralAiApi; |
54 | 55 | import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion; |
55 | 56 | import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion.Choice; |
|
84 | 85 | * @author luocongqiu |
85 | 86 | * @author Ilayaperumal Gopinathan |
86 | 87 | * @author Alexandros Pappas |
| 88 | + * @author Nicolas Krier |
87 | 89 | * @since 1.0.0 |
88 | 90 | */ |
89 | 91 | public class MistralAiChatModel implements ChatModel { |
@@ -425,67 +427,89 @@ Prompt buildRequestPrompt(Prompt prompt) { |
425 | 427 | * Accessible for testing. |
426 | 428 | */ |
427 | 429 | MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { |
428 | | - List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> { |
429 | | - if (message instanceof UserMessage userMessage) { |
| 430 | + // @formatter:off |
| 431 | + List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions() |
| 432 | + .stream() |
| 433 | + .flatMap(this::createMessages) |
| 434 | + .toList(); |
| 435 | + // @formatter:on |
| 436 | + |
| 437 | + var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream); |
| 438 | + |
| 439 | + MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions(); |
| 440 | + request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class); |
| 441 | + |
| 442 | + // Add the tool definitions to the request's tools parameter. |
| 443 | + List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); |
| 444 | + if (!CollectionUtils.isEmpty(toolDefinitions)) { |
| 445 | + request = ModelOptionsUtils.merge( |
| 446 | + MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, |
| 447 | + ChatCompletionRequest.class); |
| 448 | + } |
| 449 | + |
| 450 | + return request; |
| 451 | + } |
| 452 | + |
| 453 | + /** |
| 454 | + * Accessible for testing. |
| 455 | + */ |
| 456 | + Stream<ChatCompletionMessage> createMessages(Message message) { |
| 457 | + return switch (message.getMessageType()) { |
| 458 | + case USER -> { |
430 | 459 | Object content = message.getText(); |
431 | 460 |
|
432 | | - if (!CollectionUtils.isEmpty(userMessage.getMedia())) { |
| 461 | + if (message instanceof MediaContent mediaContent && !CollectionUtils.isEmpty(mediaContent.getMedia())) { |
433 | 462 | List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>( |
434 | 463 | List.of(new ChatCompletionMessage.MediaContent(message.getText()))); |
435 | 464 |
|
436 | | - contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList()); |
| 465 | + contentList.addAll(mediaContent.getMedia().stream().map(this::mapToMediaContent).toList()); |
437 | 466 |
|
438 | 467 | content = contentList; |
439 | 468 | } |
440 | 469 |
|
441 | | - return List |
442 | | - .of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER)); |
443 | | - } |
444 | | - else if (message instanceof SystemMessage systemMessage) { |
445 | | - return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(), |
446 | | - MistralAiApi.ChatCompletionMessage.Role.SYSTEM)); |
| 470 | + yield Stream.of(new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER)); |
447 | 471 | } |
448 | | - else if (message instanceof AssistantMessage assistantMessage) { |
449 | | - List<ToolCall> toolCalls = null; |
450 | | - if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { |
451 | | - toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { |
452 | | - var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); |
453 | | - return new ToolCall(toolCall.id(), toolCall.type(), function, null); |
454 | | - }).toList(); |
455 | | - } |
| 472 | + case SYSTEM -> Stream.of(new ChatCompletionMessage(message.getText(), ChatCompletionMessage.Role.SYSTEM)); |
| 473 | + case ASSISTANT -> { |
| 474 | + if (message instanceof AssistantMessage assistantMessage) { |
| 475 | + List<ToolCall> toolCalls = null; |
456 | 476 |
|
457 | | - return List.of(new MistralAiApi.ChatCompletionMessage(assistantMessage.getText(), |
458 | | - MistralAiApi.ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null)); |
459 | | - } |
460 | | - else if (message instanceof ToolResponseMessage toolResponseMessage) { |
461 | | - toolResponseMessage.getResponses() |
462 | | - .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); |
| 477 | + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { |
| 478 | + toolCalls = assistantMessage.getToolCalls().stream().map(this::mapToolCall).toList(); |
| 479 | + } |
463 | 480 |
|
464 | | - return toolResponseMessage.getResponses() |
465 | | - .stream() |
466 | | - .map(toolResponse -> new MistralAiApi.ChatCompletionMessage(toolResponse.responseData(), |
467 | | - MistralAiApi.ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id())) |
468 | | - .toList(); |
469 | | - } |
470 | | - else { |
471 | | - throw new IllegalStateException("Unexpected message type: " + message); |
| 481 | + yield Stream.of(new ChatCompletionMessage(assistantMessage.getText(), |
| 482 | + ChatCompletionMessage.Role.ASSISTANT, null, toolCalls, null)); |
| 483 | + } |
| 484 | + else { |
| 485 | + throw new IllegalArgumentException( |
| 486 | + "Unexpected assistant message class: " + message.getClass().getName()); |
| 487 | + } |
472 | 488 | } |
473 | | - }).flatMap(List::stream).toList(); |
474 | | - |
475 | | - var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream); |
| 489 | + case TOOL -> { |
| 490 | + if (message instanceof ToolResponseMessage toolResponseMessage) { |
| 491 | + var chatCompletionMessages = new ArrayList<ChatCompletionMessage>(); |
| 492 | + |
| 493 | + for (ToolResponseMessage.ToolResponse toolResponse : toolResponseMessage.getResponses()) { |
| 494 | + Assert.isTrue(toolResponse.id() != null, "ToolResponseMessage must have an id"); |
| 495 | + chatCompletionMessages.add(new ChatCompletionMessage(toolResponse.responseData(), |
| 496 | + ChatCompletionMessage.Role.TOOL, toolResponse.name(), null, toolResponse.id())); |
| 497 | + } |
476 | 498 |
|
477 | | - MistralAiChatOptions requestOptions = (MistralAiChatOptions) prompt.getOptions(); |
478 | | - request = ModelOptionsUtils.merge(requestOptions, request, MistralAiApi.ChatCompletionRequest.class); |
| 499 | + yield chatCompletionMessages.stream(); |
| 500 | + } |
| 501 | + else { |
| 502 | + throw new IllegalArgumentException( |
| 503 | + "Unexpected tool message class: " + message.getClass().getName()); |
| 504 | + } |
| 505 | + } |
| 506 | + }; |
| 507 | + } |
479 | 508 |
|
480 | | - // Add the tool definitions to the request's tools parameter. |
481 | | - List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); |
482 | | - if (!CollectionUtils.isEmpty(toolDefinitions)) { |
483 | | - request = ModelOptionsUtils.merge( |
484 | | - MistralAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request, |
485 | | - ChatCompletionRequest.class); |
486 | | - } |
| 509 | + private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) { |
| 510 | + var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); |
487 | 511 |
|
488 | | - return request; |
| 512 | + return new ToolCall(toolCall.id(), toolCall.type(), function, null); |
489 | 513 | } |
490 | 514 |
|
491 | 515 | private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) { |
|
0 commit comments