Skip to content

Commit 9a30eae

Browse files
committed
Improve MistralAiApi Jackson mapping for message content
- Define different types of chunks instead of media content
- Introduce think chunk required by Magistral models and Mistral Small 4
- Introduce prompt_mode request parameter for Magistral models
- Introduce reasoning_effort request parameter for Mistral Small 4
- Improve existing Javadoc
- Polish integration tests and unit tests

Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com>
1 parent 019267f commit 9a30eae

18 files changed

+965
-180
lines changed

models/spring-ai-mistral-ai/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@
8080
</dependency>
8181

8282
<dependency>
83-
<groupId>io.micrometer</groupId>
84-
<artifactId>micrometer-observation-test</artifactId>
83+
<groupId>org.springframework.boot</groupId>
84+
<artifactId>spring-boot-starter-test</artifactId>
8585
<scope>test</scope>
8686
</dependency>
8787

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.mistralai;
1818

1919
import java.util.ArrayList;
20-
import java.util.Base64;
2120
import java.util.List;
2221
import java.util.Map;
2322
import java.util.Objects;
@@ -350,7 +349,7 @@ private Flux<ChatResponse> internalStream(Prompt prompt, @Nullable ChatResponse
350349
.doOnError(observation::error)
351350
.doFinally(s -> observation.stop())
352351
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
353-
// @formatter:on;
352+
// @formatter:on
354353

355354
return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
356355
});
@@ -366,7 +365,7 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)
366365
toolCall.function().name(), toolCall.function().arguments()))
367366
.toList();
368367

369-
var content = choice.message().content();
368+
var content = choice.message().extractContent();
370369
var assistantMessage = AssistantMessage.builder()
371370
.content(content)
372371
.properties(metadata)
@@ -507,14 +506,15 @@ private ChatCompletionMessage createSystemChatCompletionMessage(Message message)
507506
}
508507

509508
private ChatCompletionMessage createUserChatCompletionMessage(Message message) {
510-
Object content = message.getText();
509+
var content = message.getText();
511510
Assert.state(content != null, "content must not be null");
512511

513512
if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
514-
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
515-
List.of(new ChatCompletionMessage.MediaContent((String) content)));
516-
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
517-
content = contentList;
513+
var contentChunks = Stream.<ChatCompletionMessage.ContentChunk>concat(
514+
Stream.of(new ChatCompletionMessage.TextChunk(content)), this.mapToImageUrlChunks(userMessage))
515+
.toList();
516+
517+
return new ChatCompletionMessage(contentChunks, ChatCompletionMessage.Role.USER);
518518
}
519519

520520
return new ChatCompletionMessage(content, ChatCompletionMessage.Role.USER);
@@ -526,24 +526,24 @@ private ToolCall mapToolCall(AssistantMessage.ToolCall toolCall) {
526526
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
527527
}
528528

529-
private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
530-
return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
531-
this.fromMediaData(media.getMimeType(), media.getData())));
529+
private Stream<ChatCompletionMessage.ImageUrlChunk> mapToImageUrlChunks(UserMessage userMessage) {
530+
return userMessage.getMedia().stream().map(this::mapToImageUrlChunk);
531+
}
532+
533+
private ChatCompletionMessage.ImageUrlChunk mapToImageUrlChunk(Media media) {
534+
return new ChatCompletionMessage.ImageUrlChunk(this.fromMediaData(media.getMimeType(), media.getData()));
532535
}
533536

534-
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
535-
if (mediaContentData instanceof byte[] bytes) {
536-
// Assume the bytes are an image. So, convert the bytes to a base64 encoded
537-
// following the prefix pattern.
538-
return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
537+
private ChatCompletionMessage.ImageUrlChunk.ImageUrl fromMediaData(MimeType mimeType, Object mediaData) {
538+
if (mediaData instanceof byte[] bytes) {
539+
return ChatCompletionMessage.ImageUrlChunk.ImageUrl.fromImageData(mimeType, bytes);
539540
}
540-
else if (mediaContentData instanceof String text) {
541-
// Assume the text is a URLs or a base64 encoded image prefixed by the user.
542-
return text;
541+
else if (mediaData instanceof String text) {
542+
// Assume the text is a URL or a base64 encoded image prefixed by the user.
543+
return new ChatCompletionMessage.ImageUrlChunk.ImageUrl(text, null);
543544
}
544545
else {
545-
throw new IllegalArgumentException(
546-
"Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
546+
throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
547547
}
548548
}
549549

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright 2023-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mistralai.api;
18+
19+
import java.io.IOException;
20+
21+
import org.springframework.http.HttpHeaders;
22+
import org.springframework.http.HttpRequest;
23+
import org.springframework.http.client.ClientHttpRequestExecution;
24+
import org.springframework.http.client.ClientHttpRequestInterceptor;
25+
import org.springframework.http.client.ClientHttpResponse;
26+
27+
/**
28+
* {@link ClientHttpRequestInterceptor} to apply a content length header based on the body
29+
* length if this header was not present in the request headers.
30+
*
31+
* @author Nicolas Krier
32+
*/
33+
public class ContentLengthInterceptor implements ClientHttpRequestInterceptor {
34+
35+
// @formatter:off
36+
// Temporary solution to fix the following error:
37+
// org.springframework.ai.retry.NonTransientAiException: 411 - {
38+
// "message":"A valid Content-Length header is required",
39+
// "request_id":"5108031f4a1e0d3e6d66204d56b2ac60"
40+
// }
41+
// TODO: Discuss with Sébastien Deleuze the opportunity to add this class into Spring Framework if this solution is satisfying.
42+
// @formatter:on
43+
@Override
44+
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
45+
throws IOException {
46+
var headers = request.getHeaders();
47+
48+
if (!headers.containsHeader(HttpHeaders.CONTENT_LENGTH)) {
49+
headers.setContentLength(body.length);
50+
}
51+
52+
return execution.execute(request, body);
53+
}
54+
55+
}

0 commit comments

Comments
 (0)