Skip to content

Commit c7ab0a6

Browse files
committed
Mistral: Support Vision Modality
Introduce multimodality support for Mistral AI, which currently supports text and vision modalities. Added integration tests and documentation for the new capability. Signed-off-by: Thomas Vitale <[email protected]>
1 parent 9e18652 commit c7ab0a6

File tree

5 files changed

+246
-17
lines changed

5 files changed

+246
-17
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19+
import java.util.ArrayList;
20+
import java.util.Base64;
1921
import java.util.HashSet;
2022
import java.util.List;
2123
import java.util.Map;
@@ -27,6 +29,8 @@
2729
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
2830
import org.slf4j.Logger;
2931
import org.slf4j.LoggerFactory;
32+
import org.springframework.ai.model.Media;
33+
import org.springframework.util.MimeType;
3034
import reactor.core.publisher.Flux;
3135
import reactor.core.publisher.Mono;
3236

@@ -325,8 +329,19 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
325329

326330
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
327331
if (message instanceof UserMessage userMessage) {
328-
return List.of(new MistralAiApi.ChatCompletionMessage(userMessage.getText(),
329-
MistralAiApi.ChatCompletionMessage.Role.USER));
332+
Object content = message.getText();
333+
334+
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
335+
List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
336+
List.of(new ChatCompletionMessage.MediaContent(message.getText())));
337+
338+
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
339+
340+
content = contentList;
341+
}
342+
343+
return List
344+
.of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER));
330345
}
331346
else if (message instanceof SystemMessage systemMessage) {
332347
return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
@@ -396,6 +411,27 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
396411
return request;
397412
}
398413

414+
private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
415+
return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
416+
this.fromMediaData(media.getMimeType(), media.getData())));
417+
}
418+
419+
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
420+
if (mediaContentData instanceof byte[] bytes) {
421+
// Assume the bytes are an image. So, convert the bytes to a base64 encoded string
422+
// following the prefix pattern.
423+
return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
424+
}
425+
else if (mediaContentData instanceof String text) {
426+
// Assume the text is a URL or a base64 encoded image prefixed by the user.
427+
return text;
428+
}
429+
else {
430+
throw new IllegalArgumentException(
431+
"Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
432+
}
433+
}
434+
399435
private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
400436
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
401437
var function = new MistralAiApi.FunctionTool.Function(functionCallback.getDescription(),

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,8 @@ public record ResponseFormat(@JsonProperty("type") String type) {
739739
/**
740740
* Message comprising the conversation.
741741
*
742-
* @param content The contents of the message.
742+
* @param rawContent The contents of the message. Can be either a {@link MediaContent}
743+
* or a {@link String}. The response message content is always a {@link String}.
743744
* @param role The role of the messages author. Could be one of the {@link Role}
744745
* types.
745746
* @param name The name of the author of the message.
@@ -751,7 +752,7 @@ public record ResponseFormat(@JsonProperty("type") String type) {
751752
@JsonInclude(Include.NON_NULL)
752753
public record ChatCompletionMessage(
753754
// @formatter:off
754-
@JsonProperty("content") String content,
755+
@JsonProperty("content") Object rawContent,
755756
@JsonProperty("role") Role role,
756757
@JsonProperty("name") String name,
757758
@JsonProperty("tool_calls") List<ToolCall> toolCalls,
@@ -766,7 +767,7 @@ public record ChatCompletionMessage(
766767
* @param toolCalls The tool calls generated by the model, such as function calls.
767768
* Applicable only for {@link Role#ASSISTANT} role and null otherwise.
768769
*/
769-
public ChatCompletionMessage(String content, Role role, String name, List<ToolCall> toolCalls) {
770+
public ChatCompletionMessage(Object content, Role role, String name, List<ToolCall> toolCalls) {
770771
this(content, role, name, toolCalls, null);
771772
}
772773

@@ -776,10 +777,23 @@ public ChatCompletionMessage(String content, Role role, String name, List<ToolCa
776777
* @param content The contents of the message.
777778
* @param role The role of the author of this message.
778779
*/
779-
public ChatCompletionMessage(String content, Role role) {
780+
public ChatCompletionMessage(Object content, Role role) {
780781
this(content, role, null, null, null);
781782
}
782783

784+
/**
785+
* Get message content as String.
786+
*/
787+
public String content() {
788+
if (this.rawContent == null) {
789+
return null;
790+
}
791+
if (this.rawContent instanceof String text) {
792+
return text;
793+
}
794+
throw new IllegalStateException("The content is not a string!");
795+
}
796+
783797
/**
784798
* The role of the author of this message.
785799
*
@@ -829,6 +843,63 @@ public record ChatCompletionFunction(@JsonProperty("name") String name,
829843

830844
}
831845

846+
/**
847+
* An array of content parts with a defined type. Each MediaContent can be of
848+
* either "text" or "image_url" type. Only one option allowed.
849+
*
850+
* @param type Content type, each can be of type text or image_url.
851+
* @param text The text content of the message.
852+
* @param imageUrl The image content of the message.
853+
*/
854+
@JsonInclude(Include.NON_NULL)
855+
public record MediaContent(
856+
// @formatter:off
857+
@JsonProperty("type") String type,
858+
@JsonProperty("text") String text,
859+
@JsonProperty("image_url") ImageUrl imageUrl
860+
// @formatter:on
861+
) {
862+
863+
/**
864+
* Shortcut constructor for a text content.
865+
* @param text The text content of the message.
866+
*/
867+
public MediaContent(String text) {
868+
this("text", text, null);
869+
}
870+
871+
/**
872+
* Shortcut constructor for an image content.
873+
* @param imageUrl The image content of the message.
874+
*/
875+
public MediaContent(ImageUrl imageUrl) {
876+
this("image_url", null, imageUrl);
877+
}
878+
879+
/**
880+
* Shortcut constructor for an image content.
881+
*
882+
* @param url Either a URL of the image or the base64 encoded image data. The
883+
* base64 encoded image data must have a special prefix in the following
884+
* format: "data:{mimetype};base64,{base64-encoded-image-data}".
885+
* @param detail Specifies the detail level of the image.
886+
*/
887+
@JsonInclude(Include.NON_NULL)
888+
public record ImageUrl(
889+
// @formatter:off
890+
@JsonProperty("url") String url,
891+
@JsonProperty("detail") String detail
892+
// @formatter:on
893+
) {
894+
895+
public ImageUrl(String url) {
896+
this(url, null);
897+
}
898+
899+
}
900+
901+
}
902+
832903
}
833904

834905
/**

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,43 +16,51 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import java.util.ArrayList;
20-
import java.util.Arrays;
21-
import java.util.List;
22-
import java.util.Map;
23-
import java.util.stream.Collectors;
24-
2519
import org.junit.jupiter.api.Test;
2620
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.ValueSource;
2723
import org.slf4j.Logger;
2824
import org.slf4j.LoggerFactory;
29-
import reactor.core.publisher.Flux;
30-
3125
import org.springframework.ai.chat.messages.AssistantMessage;
3226
import org.springframework.ai.chat.messages.Message;
3327
import org.springframework.ai.chat.messages.UserMessage;
3428
import org.springframework.ai.chat.model.ChatModel;
3529
import org.springframework.ai.chat.model.ChatResponse;
3630
import org.springframework.ai.chat.model.Generation;
3731
import org.springframework.ai.chat.model.StreamingChatModel;
32+
import org.springframework.ai.chat.prompt.ChatOptions;
3833
import org.springframework.ai.chat.prompt.Prompt;
3934
import org.springframework.ai.chat.prompt.PromptTemplate;
4035
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
4136
import org.springframework.ai.converter.BeanOutputConverter;
4237
import org.springframework.ai.converter.ListOutputConverter;
4338
import org.springframework.ai.converter.MapOutputConverter;
4439
import org.springframework.ai.mistralai.api.MistralAiApi;
40+
import org.springframework.ai.model.Media;
4541
import org.springframework.ai.model.function.FunctionCallback;
4642
import org.springframework.beans.factory.annotation.Autowired;
4743
import org.springframework.beans.factory.annotation.Value;
4844
import org.springframework.boot.test.context.SpringBootTest;
4945
import org.springframework.core.convert.support.DefaultConversionService;
46+
import org.springframework.core.io.ClassPathResource;
5047
import org.springframework.core.io.Resource;
48+
import org.springframework.util.MimeTypeUtils;
49+
import reactor.core.publisher.Flux;
50+
51+
import java.io.IOException;
52+
import java.net.URL;
53+
import java.util.ArrayList;
54+
import java.util.Arrays;
55+
import java.util.List;
56+
import java.util.Map;
57+
import java.util.stream.Collectors;
5158

5259
import static org.assertj.core.api.Assertions.assertThat;
5360

5461
/**
5562
* @author Christian Tzolov
63+
* @author Thomas Vitale
5664
* @since 0.8.1
5765
*/
5866
@SpringBootTest(classes = MistralAiTestConfiguration.class)
@@ -238,6 +246,63 @@ void streamFunctionCallTest() {
238246
assertThat(content).containsAnyOf("10.0", "10");
239247
}
240248

249+
@ParameterizedTest(name = "{0} : {displayName} ")
250+
@ValueSource(strings = { "pixtral-large-latest" })
251+
void multiModalityEmbeddedImage(String modelName) {
252+
var imageData = new ClassPathResource("/test.png");
253+
254+
var userMessage = new UserMessage("Explain what do you see on this picture?",
255+
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
256+
257+
var response = this.chatModel
258+
.call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build()));
259+
260+
logger.info(response.getResult().getOutput().getText());
261+
assertThat(response.getResult().getOutput().getText()).contains("bananas", "apple");
262+
assertThat(response.getResult().getOutput().getText()).containsAnyOf("bowl", "basket", "fruit stand");
263+
}
264+
265+
@ParameterizedTest(name = "{0} : {displayName} ")
266+
@ValueSource(strings = { "pixtral-large-latest" })
267+
void multiModalityImageUrl(String modelName) throws IOException {
268+
var userMessage = new UserMessage("Explain what do you see on this picture?",
269+
List.of(Media.builder()
270+
.mimeType(MimeTypeUtils.IMAGE_PNG)
271+
.data(new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
272+
.build()));
273+
274+
ChatResponse response = this.chatModel
275+
.call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build()));
276+
277+
logger.info(response.getResult().getOutput().getText());
278+
assertThat(response.getResult().getOutput().getText()).contains("bananas", "apple");
279+
assertThat(response.getResult().getOutput().getText()).containsAnyOf("bowl", "basket", "fruit stand");
280+
}
281+
282+
@Test
283+
void streamingMultiModalityImageUrl() throws IOException {
284+
var userMessage = new UserMessage("Explain what do you see on this picture?",
285+
List.of(Media.builder()
286+
.mimeType(MimeTypeUtils.IMAGE_PNG)
287+
.data(new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
288+
.build()));
289+
290+
Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(List.of(userMessage),
291+
ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
292+
293+
String content = response.collectList()
294+
.block()
295+
.stream()
296+
.map(ChatResponse::getResults)
297+
.flatMap(List::stream)
298+
.map(Generation::getOutput)
299+
.map(AssistantMessage::getText)
300+
.collect(Collectors.joining());
301+
logger.info("Response: {}", content);
302+
assertThat(content).contains("bananas", "apple");
303+
assertThat(content).containsAnyOf("bowl", "basket", "fruit stand");
304+
}
305+
241306
record ActorsFilmsRecord(String actor, List<String> movies) {
242307

243308
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,62 @@ You can register custom Java functions with the `MistralAiChatModel` and have th
140140
This is a powerful technique to connect the LLM capabilities with external tools and APIs.
141141
Read more about xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral AI Function Calling].
142142

143+
== Multimodal
144+
145+
Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats.
146+
Mistral AI supports text and vision modalities.
147+
148+
=== Vision
149+
150+
Mistral AI models that offer vision multimodal support include `pixtral-large-latest`.
151+
Refer to the link:https://docs.mistral.ai/capabilities/vision/[Vision] guide for more information.
152+
153+
The Mistral AI link:https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post[User Message API] can incorporate a list of base64-encoded images or image URLs with the message.
154+
Spring AI’s link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java[Message] interface facilitates multimodal AI models by introducing the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java[Media] type.
155+
This type encompasses data and details regarding media attachments in messages, utilizing Spring’s `org.springframework.util.MimeType` and a `org.springframework.core.io.Resource` for the raw media data.
156+
157+
Below is a code example excerpted from `MistralAiChatModelIT.java`, illustrating the fusion of user text with an image.
158+
159+
[source,java]
160+
----
161+
var imageResource = new ClassPathResource("/multimodal.test.png");
162+
163+
var userMessage = new UserMessage("Explain what do you see on this picture?",
164+
new Media(MimeTypeUtils.IMAGE_PNG, imageResource));
165+
166+
ChatResponse response = chatModel.call(new Prompt(userMessage,
167+
ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
168+
----
169+
170+
or the image URL equivalent:
171+
172+
[source,java]
173+
----
174+
var userMessage = new UserMessage("Explain what do you see on this picture?",
175+
new Media(MimeTypeUtils.IMAGE_PNG,
176+
"https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"));
177+
178+
ChatResponse response = chatModel.call(new Prompt(userMessage,
179+
ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
180+
----
181+
182+
TIP: You can pass multiple images as well.
183+
184+
The example shows a model taking as an input the `multimodal.test.png` image:
185+
186+
image::multimodal.test.png[Multimodal Test Image, 200, 200, align="left"]
187+
188+
along with the text message "Explain what do you see on this picture?", and generating a response like this:
189+
190+
----
191+
This is an image of a fruit bowl with a simple design. The bowl is made of metal with curved wire edges that
192+
create an open structure, allowing the fruit to be visible from all angles. Inside the bowl, there are two
193+
yellow bananas resting on top of what appears to be a red apple. The bananas are slightly overripe, as
194+
indicated by the brown spots on their peels. The bowl has a metal ring at the top, likely to serve as a handle
195+
for carrying. The bowl is placed on a flat surface with a neutral-colored background that provides a clear
196+
view of the fruit inside.
197+
----
198+
143199
== OpenAI API Compatibility
144200

145201
Mistral is OpenAI API-compatible and you can use the xref:api/chat/openai-chat.adoc[Spring AI OpenAI] client to talk to Mistral.

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,10 @@ and produce a response like:
6565

6666
Spring AI provides multimodal support for the following chat models:
6767

68-
* xref:api/chat/openai-chat.adoc#_multimodal[OpenAI (e.g. GPT-4 and GPT-4o models)]
69-
* xref:api/chat/ollama-chat.adoc#_multimodal[Ollama (e.g. LlaVa, Baklava, Llama3.2 models)]
70-
* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini (e.g. gemini-1.5-pro-001, gemini-1.5-flash-001 models)]
7168
* xref:api/chat/anthropic-chat.adoc#_multimodal[Anthropic Claude 3]
7269
* xref:api/chat/bedrock-converse.adoc#_multimodal[AWS Bedrock Converse]
7370
* xref:api/chat/azure-openai-chat.adoc#_multimodal[Azure Open AI (e.g. GPT-4o models)]
71+
* xref:api/chat/mistralai-chat.adoc#_multimodal[Mistral AI (e.g. Mistral Pixtral models)]
72+
* xref:api/chat/ollama-chat.adoc#_multimodal[Ollama (e.g. LlaVa, Baklava, Llama3.2 models)]
73+
* xref:api/chat/openai-chat.adoc#_multimodal[OpenAI (e.g. GPT-4 and GPT-4o models)]
74+
* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini (e.g. gemini-1.5-pro-001, gemini-1.5-flash-001 models)]

0 commit comments

Comments
 (0)