
Commit f2820cd

ThomasVitale authored and tzolov committed
Mistral: Support Vision Modality
Introduce multimodality support for Mistral AI, which currently supports text and vision modalities. Add integration tests and documentation for the new capability.

Signed-off-by: Thomas Vitale <[email protected]>
1 parent 2e5ee43 commit f2820cd

File tree

5 files changed: +246 -18 lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 38 additions & 2 deletions
@@ -16,6 +16,8 @@
 
 package org.springframework.ai.mistralai;
 
+import java.util.ArrayList;
+import java.util.Base64;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -27,6 +29,8 @@
 import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.springframework.ai.model.Media;
+import org.springframework.util.MimeType;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
@@ -353,8 +357,19 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
 
 		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
 			if (message instanceof UserMessage userMessage) {
-				return List.of(new MistralAiApi.ChatCompletionMessage(userMessage.getText(),
-						MistralAiApi.ChatCompletionMessage.Role.USER));
+				Object content = message.getText();
+
+				if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
+					List<ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
+							List.of(new ChatCompletionMessage.MediaContent(message.getText())));
+
+					contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
+
+					content = contentList;
+				}
+
+				return List
+					.of(new MistralAiApi.ChatCompletionMessage(content, MistralAiApi.ChatCompletionMessage.Role.USER));
 			}
 			else if (message instanceof SystemMessage systemMessage) {
 				return List.of(new MistralAiApi.ChatCompletionMessage(systemMessage.getText(),
@@ -424,6 +439,27 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
 		return request;
 	}
 
+	private ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
+		return new ChatCompletionMessage.MediaContent(new ChatCompletionMessage.MediaContent.ImageUrl(
+				this.fromMediaData(media.getMimeType(), media.getData())));
+	}
+
+	private String fromMediaData(MimeType mimeType, Object mediaContentData) {
+		if (mediaContentData instanceof byte[] bytes) {
+			// Assume the bytes are an image, so convert them to a base64 encoded string
+			// following the data URL prefix pattern.
+			return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
+		}
+		else if (mediaContentData instanceof String text) {
+			// Assume the text is a URL or a base64 encoded image prefixed by the user.
+			return text;
+		}
+		else {
+			throw new IllegalArgumentException(
+					"Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
+		}
+	}
+
 	private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
 		return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
 			var function = new MistralAiApi.FunctionTool.Function(functionCallback.getDescription(),
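For reference, a minimal standalone sketch (not part of the commit) of the data URL shape that the new fromMediaData(..) helper produces when the Media payload is a raw byte[]; the class name and placeholder bytes below are illustrative only.

import java.util.Base64;

import org.springframework.util.MimeTypeUtils;

// Standalone sketch: the "data:{mimetype};base64,{data}" format used for embedded images.
class DataUrlSketch {

	public static void main(String[] args) {
		// Placeholder bytes standing in for a real PNG payload.
		byte[] imageBytes = { (byte) 0x89, 'P', 'N', 'G' };

		String dataUrl = String.format("data:%s;base64,%s", MimeTypeUtils.IMAGE_PNG,
				Base64.getEncoder().encodeToString(imageBytes));

		// Prints: data:image/png;base64,iVBORw==
		System.out.println(dataUrl);
	}

}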

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 75 additions & 4 deletions
@@ -740,7 +740,8 @@ public record ResponseFormat(@JsonProperty("type") String type) {
 	/**
 	 * Message comprising the conversation.
 	 *
-	 * @param content The contents of the message.
+	 * @param rawContent The contents of the message. Can be either a {@link MediaContent}
+	 * or a {@link String}. The response message content is always a {@link String}.
 	 * @param role The role of the messages author. Could be one of the {@link Role}
 	 * types.
 	 * @param name The name of the author of the message.
@@ -752,7 +753,7 @@ public record ResponseFormat(@JsonProperty("type") String type) {
 	@JsonInclude(Include.NON_NULL)
 	public record ChatCompletionMessage(
 	// @formatter:off
-		@JsonProperty("content") String content,
+		@JsonProperty("content") Object rawContent,
		@JsonProperty("role") Role role,
		@JsonProperty("name") String name,
		@JsonProperty("tool_calls") List<ToolCall> toolCalls,
@@ -767,7 +768,7 @@ public record ChatCompletionMessage(
 	 * @param toolCalls The tool calls generated by the model, such as function calls.
 	 * Applicable only for {@link Role#ASSISTANT} role and null otherwise.
 	 */
-	public ChatCompletionMessage(String content, Role role, String name, List<ToolCall> toolCalls) {
+	public ChatCompletionMessage(Object content, Role role, String name, List<ToolCall> toolCalls) {
 		this(content, role, name, toolCalls, null);
 	}
 
@@ -777,10 +778,23 @@ public ChatCompletionMessage(String content, Role role, String name, List<ToolCa
 	 * @param content The contents of the message.
 	 * @param role The role of the author of this message.
 	 */
-	public ChatCompletionMessage(String content, Role role) {
+	public ChatCompletionMessage(Object content, Role role) {
 		this(content, role, null, null, null);
 	}
 
+	/**
+	 * Get message content as String.
+	 */
+	public String content() {
+		if (this.rawContent == null) {
+			return null;
+		}
+		if (this.rawContent instanceof String text) {
+			return text;
+		}
+		throw new IllegalStateException("The content is not a string!");
+	}
+
 	/**
 	 * The role of the author of this message.
 	 *
@@ -830,6 +844,63 @@ public record ChatCompletionFunction(@JsonProperty("name") String name,
 
 	}
 
+	/**
+	 * An array of content parts with a defined type. Each MediaContent can be of
+	 * either "text" or "image_url" type. Only one option allowed.
+	 *
+	 * @param type Content type, each can be of type text or image_url.
+	 * @param text The text content of the message.
+	 * @param imageUrl The image content of the message.
+	 */
+	@JsonInclude(Include.NON_NULL)
+	public record MediaContent(
+	// @formatter:off
+		@JsonProperty("type") String type,
+		@JsonProperty("text") String text,
+		@JsonProperty("image_url") ImageUrl imageUrl
+	// @formatter:on
+	) {
+
+		/**
+		 * Shortcut constructor for a text content.
+		 * @param text The text content of the message.
+		 */
+		public MediaContent(String text) {
+			this("text", text, null);
+		}
+
+		/**
+		 * Shortcut constructor for an image content.
+		 * @param imageUrl The image content of the message.
+		 */
+		public MediaContent(ImageUrl imageUrl) {
+			this("image_url", null, imageUrl);
+		}
+
+		/**
+		 * Shortcut constructor for an image content.
+		 *
+		 * @param url Either a URL of the image or the base64 encoded image data. The
+		 * base64 encoded image data must have a special prefix in the following
+		 * format: "data:{mimetype};base64,{base64-encoded-image-data}".
+		 * @param detail Specifies the detail level of the image.
+		 */
+		@JsonInclude(Include.NON_NULL)
+		public record ImageUrl(
+		// @formatter:off
+			@JsonProperty("url") String url,
+			@JsonProperty("detail") String detail
+		// @formatter:on
+		) {
+
+			public ImageUrl(String url) {
+				this(url, null);
+			}
+
+		}
+
+	}
+
 	}
 
 	/**
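A sketch (not part of the commit) of how the new MediaContent and ImageUrl records compose into a multimodal user message, mirroring what MistralAiChatModel.createRequest(..) builds from a UserMessage with attached media; the class and variable names are illustrative.

import java.util.List;

import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.MediaContent;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;

// Sketch: a user message whose content mixes a "text" part and an "image_url" part.
class MultimodalMessageSketch {

	public static void main(String[] args) {
		var parts = List.of(
				new MediaContent("Explain what do you see on this picture?"),
				new MediaContent(new MediaContent.ImageUrl(
						"https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")));

		var message = new ChatCompletionMessage(parts, Role.USER);

		// When Jackson serializes the request, the "content" field becomes an array like:
		// [ {"type":"text","text":"..."},
		//   {"type":"image_url","image_url":{"url":"https://..."}} ]
		System.out.println(message);
	}

}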

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 73 additions & 9 deletions
@@ -16,44 +16,52 @@
 
 package org.springframework.ai.mistralai;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Flux;
-
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.chat.model.StreamingChatModel;
+import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.chat.prompt.PromptTemplate;
 import org.springframework.ai.chat.prompt.SystemPromptTemplate;
 import org.springframework.ai.converter.BeanOutputConverter;
 import org.springframework.ai.converter.ListOutputConverter;
 import org.springframework.ai.converter.MapOutputConverter;
 import org.springframework.ai.mistralai.api.MistralAiApi;
+import org.springframework.ai.model.Media;
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.boot.test.context.SpringBootTest;
 import org.springframework.core.convert.support.DefaultConversionService;
+import org.springframework.core.io.ClassPathResource;
 import org.springframework.core.io.Resource;
+import org.springframework.util.MimeTypeUtils;
+import reactor.core.publisher.Flux;
+
+import java.io.IOException;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * @author Christian Tzolov
  * @author Alexandros Pappas
+ * @author Thomas Vitale
  * @since 0.8.1
  */
 @SpringBootTest(classes = MistralAiTestConfiguration.class)
@@ -242,9 +250,65 @@ void streamFunctionCallTest() {
 		assertThat(content).containsAnyOf("10.0", "10");
 	}
 
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = { "pixtral-large-latest" })
+	void multiModalityEmbeddedImage(String modelName) {
+		var imageData = new ClassPathResource("/test.png");
+
+		var userMessage = new UserMessage("Explain what do you see on this picture?",
+				List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
+
+		var response = this.chatModel
+			.call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build()));
+
+		logger.info(response.getResult().getOutput().getText());
+		assertThat(response.getResult().getOutput().getText()).contains("bananas", "apple");
+		assertThat(response.getResult().getOutput().getText()).containsAnyOf("bowl", "basket", "fruit stand");
+	}
+
+	@ParameterizedTest(name = "{0} : {displayName} ")
+	@ValueSource(strings = { "pixtral-large-latest" })
+	void multiModalityImageUrl(String modelName) throws IOException {
+		var userMessage = new UserMessage("Explain what do you see on this picture?",
+				List.of(Media.builder()
+					.mimeType(MimeTypeUtils.IMAGE_PNG)
+					.data(new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
+					.build()));
+
+		ChatResponse response = this.chatModel
+			.call(new Prompt(List.of(userMessage), ChatOptions.builder().model(modelName).build()));
+
+		logger.info(response.getResult().getOutput().getText());
+		assertThat(response.getResult().getOutput().getText()).contains("bananas", "apple");
+		assertThat(response.getResult().getOutput().getText()).containsAnyOf("bowl", "basket", "fruit stand");
+	}
+
 	@Test
-	void streamFunctionCallUsageTest() {
+	void streamingMultiModalityImageUrl() throws IOException {
+		var userMessage = new UserMessage("Explain what do you see on this picture?",
+				List.of(Media.builder()
+					.mimeType(MimeTypeUtils.IMAGE_PNG)
+					.data(new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
+					.build()));
+
+		Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(List.of(userMessage),
+				ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
+
+		String content = response.collectList()
+			.block()
+			.stream()
+			.map(ChatResponse::getResults)
+			.flatMap(List::stream)
+			.map(Generation::getOutput)
+			.map(AssistantMessage::getText)
+			.collect(Collectors.joining());
+		logger.info("Response: {}", content);
+		assertThat(content).contains("bananas", "apple");
+		assertThat(content).containsAnyOf("bowl", "basket", "fruit stand");
+	}
 
+	@Test
+	void streamFunctionCallUsageTest() {
 		UserMessage userMessage = new UserMessage(
 				"What's the weather like in San Francisco, Tokyo, and Paris? Response in Celsius");
 
spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/mistralai-chat.adoc

Lines changed: 56 additions & 0 deletions
@@ -140,6 +140,62 @@ You can register custom Java functions with the `MistralAiChatModel` and have th
 This is a powerful technique to connect the LLM capabilities with external tools and APIs.
 Read more about xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral AI Function Calling].
 
+== Multimodal
+
+Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats.
+Mistral AI supports text and vision modalities.
+
+=== Vision
+
+Mistral AI models that offer vision multimodal support include `pixtral-large-latest`.
+Refer to the link:https://docs.mistral.ai/capabilities/vision/[Vision] guide for more information.
+
+The Mistral AI link:https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post[User Message API] can incorporate a list of base64-encoded images or image URLs with the message.
+Spring AI’s link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java[Message] interface facilitates multimodal AI models by introducing the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/Media.java[Media] type.
+This type encompasses data and details regarding media attachments in messages, utilizing Spring’s `org.springframework.util.MimeType` and a `org.springframework.core.io.Resource` for the raw media data.
+
+Below is a code example excerpted from `MistralAiChatModelIT.java`, illustrating the fusion of user text with an image.
+
+[source,java]
+----
+var imageResource = new ClassPathResource("/multimodal.test.png");
+
+var userMessage = new UserMessage("Explain what do you see on this picture?",
+        new Media(MimeTypeUtils.IMAGE_PNG, this.imageResource));
+
+ChatResponse response = chatModel.call(new Prompt(this.userMessage,
+        ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
+----
+
+or the image URL equivalent:
+
+[source,java]
+----
+var userMessage = new UserMessage("Explain what do you see on this picture?",
+        new Media(MimeTypeUtils.IMAGE_PNG,
+                "https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"));
+
+ChatResponse response = chatModel.call(new Prompt(this.userMessage,
+        ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));
+----
+
+TIP: You can pass multiple images as well.
+
+The example shows a model taking as an input the `multimodal.test.png` image:
+
+image::multimodal.test.png[Multimodal Test Image, 200, 200, align="left"]
+
+along with the text message "Explain what do you see on this picture?", and generating a response like this:
+
+----
+This is an image of a fruit bowl with a simple design. The bowl is made of metal with curved wire edges that
+create an open structure, allowing the fruit to be visible from all angles. Inside the bowl, there are two
+yellow bananas resting on top of what appears to be a red apple. The bananas are slightly overripe, as
+indicated by the brown spots on their peels. The bowl has a metal ring at the top, likely to serve as a handle
+for carrying. The bowl is placed on a flat surface with a neutral-colored background that provides a clear
+view of the fruit inside.
+----
+
 == OpenAI API Compatibility
 
 Mistral is OpenAI API-compatible and you can use the xref:api/chat/openai-chat.adoc[Spring AI OpenAI] client to talk to Mistral.
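The TIP above notes that multiple images can be attached to one message. A hedged sketch of what that could look like with the documented Media and UserMessage API (the second classpath resource name is hypothetical, not part of the commit):

// Sketch: attaching two images to a single user message.
var userMessage = new UserMessage("Compare the two pictures.",
        List.of(new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/multimodal.test.png")),
                new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/second.test.png"))));

ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
        ChatOptions.builder().model(MistralAiApi.ChatModel.PIXTRAL_LARGE.getValue()).build()));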

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/multimodality.adoc

Lines changed: 4 additions & 3 deletions
@@ -65,9 +65,10 @@ and produce a response like:
 
 Spring AI provides multimodal support for the following chat models:
 
-* xref:api/chat/openai-chat.adoc#_multimodal[OpenAI (e.g. GPT-4 and GPT-4o models)]
-* xref:api/chat/ollama-chat.adoc#_multimodal[Ollama (e.g. LlaVa, Baklava, Llama3.2 models)]
-* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini (e.g. gemini-1.5-pro-001, gemini-1.5-flash-001 models)]
 * xref:api/chat/anthropic-chat.adoc#_multimodal[Anthropic Claude 3]
 * xref:api/chat/bedrock-converse.adoc#_multimodal[AWS Bedrock Converse]
 * xref:api/chat/azure-openai-chat.adoc#_multimodal[Azure Open AI (e.g. GPT-4o models)]
+* xref:api/chat/mistralai-chat.adoc#_multimodal[Mistral AI (e.g. Mistral Pixtral models)]
+* xref:api/chat/ollama-chat.adoc#_multimodal[Ollama (e.g. LlaVa, Baklava, Llama3.2 models)]
+* xref:api/chat/openai-chat.adoc#_multimodal[OpenAI (e.g. GPT-4 and GPT-4o models)]
+* xref:api/chat/vertexai-gemini-chat.adoc#_multimodal[Vertex AI Gemini (e.g. gemini-1.5-pro-001, gemini-1.5-flash-001 models)]
