Skip to content

Commit ea48e17

Browse files
nicolaskrier and markpollack
authored and committed
Add Mistral AI Codestral Embed model
- Add Mistral AI Codestral Embed model
- Add unit tests for embedding dimension mappings and fallback behavior
- Add JavaDoc comments for KNOWN_EMBEDDING_DIMENSIONS map
- Add detailed JavaDoc for EMBED and CODESTRAL_EMBED enum values
- Update AsciiDoc documentation with model comparison table
- Add usage examples for both mistral-embed and codestral-embed models
- Include test for sync validation between enum and dimensions map

Signed-off-by: Nicolas Krier <[email protected]>
Signed-off-by: Mark Pollack <[email protected]>
1 parent e91eda9 commit ea48e17

File tree

6 files changed

+263
-28
lines changed

6 files changed

+263
-28
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.mistralai;
1818

1919
import java.util.List;
20+
import java.util.Map;
2021

2122
import io.micrometer.observation.ObservationRegistry;
2223
import org.slf4j.Logger;
@@ -48,12 +49,22 @@
4849
* @author Ricken Bazolo
4950
* @author Thomas Vitale
5051
* @author Jason Smith
52+
* @author Nicolas Krier
5153
* @since 1.0.0
5254
*/
5355
public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5456

5557
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5658

59+
/**
60+
* Known embedding dimensions for Mistral AI models. Maps model names to their
61+
* respective embedding vector dimensions. This allows the dimensions() method to
62+
* return the correct value without making an API call.
63+
*/
64+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(
65+
MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(),
66+
1536);
67+
5768
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
5869

5970
private final MistralAiEmbeddingOptions defaultOptions;
@@ -184,6 +195,11 @@ public float[] embed(Document document) {
184195
return this.embed(document.getFormattedContent(this.metadataMode));
185196
}
186197

198+
@Override
199+
public int dimensions() {
200+
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
201+
}
202+
187203
/**
188204
* Use the provided convention for reporting observation data
189205
* @param observationConvention The provided convention

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
* @author Christian Tzolov
6363
* @author Thomas Vitale
6464
* @author Jason Smith
65+
* @author Nicolas Krier
6566
* @since 1.0.0
6667
*/
6768
public class MistralAiApi {
@@ -330,7 +331,20 @@ public String getName() {
330331
public enum EmbeddingModel {
331332

332333
// @formatter:off
333-
EMBED("mistral-embed");
334+
/**
335+
* Mistral Embed model for general text embeddings.
336+
* Produces 1024-dimensional embeddings suitable for semantic search,
337+
* clustering, and other text similarity tasks.
338+
*/
339+
EMBED("mistral-embed"),
340+
341+
/**
342+
* Codestral Embed model optimized for code embeddings.
343+
* Produces 1536-dimensional embeddings specifically designed for
344+
* code similarity, code search, and retrieval-augmented generation (RAG)
345+
* with code repositories.
346+
*/
347+
CODESTRAL_EMBED("codestral-embed");
334348
// @formatter:on
335349

336350
private final String value;
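models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java

Lines changed: 47 additions & 14 deletions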
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,46 +20,79 @@
2020

2121
import org.junit.jupiter.api.Test;
2222
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import org.junit.jupiter.params.ParameterizedTest;
24+
import org.junit.jupiter.params.provider.CsvSource;
2325

2426
import org.springframework.ai.embedding.EmbeddingRequest;
27+
import org.springframework.ai.mistralai.api.MistralAiApi;
2528
import org.springframework.beans.factory.annotation.Autowired;
2629
import org.springframework.boot.test.context.SpringBootTest;
2730

2831
import static org.assertj.core.api.Assertions.assertThat;
2932

33+
/**
34+
* @author Nicolas Krier
35+
*/
3036
@SpringBootTest(classes = MistralAiTestConfiguration.class)
3137
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
3238
class MistralAiEmbeddingIT {
3339

40+
private static final int MISTRAL_EMBED_DIMENSIONS = 1024;
41+
42+
@Autowired
43+
private MistralAiApi mistralAiApi;
44+
3445
@Autowired
3546
private MistralAiEmbeddingModel mistralAiEmbeddingModel;
3647

3748
@Test
3849
void defaultEmbedding() {
39-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
4050
var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
4151
assertThat(embeddingResponse.getResults()).hasSize(1);
4252
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
43-
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
53+
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(MISTRAL_EMBED_DIMENSIONS);
4454
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
4555
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
4656
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
47-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
57+
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
4858
}
4959

50-
@Test
51-
void embeddingTest() {
52-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
53-
var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest(
54-
List.of("Hello World", "World is big"),
55-
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
60+
@ParameterizedTest
61+
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
62+
void defaultOptionsEmbedding(String model, int dimensions) {
63+
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
64+
var anotherMistralAiEmbeddingModel = MistralAiEmbeddingModel.builder()
65+
.mistralAiApi(this.mistralAiApi)
66+
.options(mistralAiEmbeddingOptions)
67+
.build();
68+
var embeddingResponse = anotherMistralAiEmbeddingModel.embedForResponse(List.of("Hello World", "World is big"));
5669
assertThat(embeddingResponse.getResults()).hasSize(2);
57-
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
58-
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
59-
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
70+
embeddingResponse.getResults().forEach(result -> {
71+
assertThat(result).isNotNull();
72+
assertThat(result.getOutput()).hasSize(dimensions);
73+
});
74+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
6075
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9);
6176
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9);
62-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
77+
assertThat(anotherMistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions);
78+
}
79+
80+
@ParameterizedTest
81+
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
82+
void calledOptionsEmbedding(String model, int dimensions) {
83+
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
84+
var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big", "We are small"),
85+
mistralAiEmbeddingOptions);
86+
var embeddingResponse = this.mistralAiEmbeddingModel.call(embeddingRequest);
87+
assertThat(embeddingResponse.getResults()).hasSize(3);
88+
embeddingResponse.getResults().forEach(result -> {
89+
assertThat(result).isNotNull();
90+
assertThat(result.getOutput()).hasSize(dimensions);
91+
});
92+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
93+
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(14);
94+
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(14);
95+
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(MISTRAL_EMBED_DIMENSIONS);
6396
}
6497

6598
}
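models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java

Lines changed: 141 additions & 0 deletions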
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mistralai;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.mockito.Mockito;
23+
24+
import org.springframework.ai.document.MetadataMode;
25+
import org.springframework.ai.mistralai.api.MistralAiApi;
26+
import org.springframework.ai.retry.RetryUtils;
27+
import org.springframework.http.ResponseEntity;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.mockito.ArgumentMatchers.any;
31+
import static org.mockito.Mockito.when;
32+
33+
/**
34+
* Unit tests for {@link MistralAiEmbeddingModel}.
35+
*
36+
* @author Nicolas Krier
37+
*/
38+
class MistralAiEmbeddingModelTests {
39+
40+
@Test
41+
void testDimensionsForMistralEmbedModel() {
42+
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1024);
43+
44+
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder()
45+
.withModel(MistralAiApi.EmbeddingModel.EMBED.getValue())
46+
.build();
47+
48+
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
49+
RetryUtils.DEFAULT_RETRY_TEMPLATE);
50+
51+
assertThat(model.dimensions()).isEqualTo(1024);
52+
}
53+
54+
@Test
55+
void testDimensionsForCodestralEmbedModel() {
56+
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1536);
57+
58+
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder()
59+
.withModel(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue())
60+
.build();
61+
62+
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
63+
RetryUtils.DEFAULT_RETRY_TEMPLATE);
64+
65+
assertThat(model.dimensions()).isEqualTo(1536);
66+
}
67+
68+
@Test
69+
void testDimensionsFallbackForUnknownModel() {
70+
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(512);
71+
72+
// Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS
73+
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build();
74+
75+
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
76+
RetryUtils.DEFAULT_RETRY_TEMPLATE);
77+
78+
// Should fall back to super.dimensions() which detects dimensions from the API
79+
// response
80+
assertThat(model.dimensions()).isEqualTo(512);
81+
}
82+
83+
@Test
84+
void testAllEmbeddingModelsHaveDimensionMapping() {
85+
// This test is intended to keep the KNOWN_EMBEDDING_DIMENSIONS map in sync with the
86+
// EmbeddingModel enum.
87+
// NOTE(review): the mocked API below returns a 1024-element embedding, so a model that
88+
// is missing from the map would presumably still pass via the fallback - confirm.
89+
90+
for (MistralAiApi.EmbeddingModel embeddingModel : MistralAiApi.EmbeddingModel.values()) {
91+
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1024);
92+
MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder()
93+
.withModel(embeddingModel.getValue())
94+
.build();
95+
96+
MistralAiEmbeddingModel model = new MistralAiEmbeddingModel(mockApi, MetadataMode.EMBED, options,
97+
RetryUtils.DEFAULT_RETRY_TEMPLATE);
98+
99+
// Each model should report a positive dimension (NOTE(review): the mocked fallback would also yield 1024, not -1)
100+
assertThat(model.dimensions()).as("Model %s should have a dimension mapping", embeddingModel.getValue())
101+
.isGreaterThan(0);
102+
}
103+
}
104+
105+
@Test
106+
void testBuilderCreatesValidModel() {
107+
MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1536);
108+
109+
MistralAiEmbeddingModel model = MistralAiEmbeddingModel.builder()
110+
.mistralAiApi(mockApi)
111+
.options(MistralAiEmbeddingOptions.builder()
112+
.withModel(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue())
113+
.build())
114+
.build();
115+
116+
assertThat(model).isNotNull();
117+
assertThat(model.dimensions()).isEqualTo(1536);
118+
}
119+
120+
private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
121+
MistralAiApi mockApi = Mockito.mock(MistralAiApi.class);
122+
123+
// Create a mock embedding response with the specified dimensions
124+
float[] embedding = new float[dimensions];
125+
for (int i = 0; i < dimensions; i++) {
126+
embedding[i] = 0.1f;
127+
}
128+
129+
MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding");
130+
131+
MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10);
132+
133+
MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData),
134+
"model", usage);
135+
136+
when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));
137+
138+
return mockApi;
139+
}
140+
141+
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import org.springframework.ai.embedding.EmbeddingModel;
2019
import org.springframework.ai.mistralai.api.MistralAiApi;
2120
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
2221
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
@@ -26,32 +25,32 @@
2625

2726
/**
2827
* @author Jason Smith
28+
* @author Nicolas Krier
2929
*/
3030
@SpringBootConfiguration
3131
public class MistralAiTestConfiguration {
3232

33-
@Bean
34-
public MistralAiApi mistralAiApi() {
33+
private static String retrieveApiKey() {
3534
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
3635
if (!StringUtils.hasText(apiKey)) {
3736
throw new IllegalArgumentException(
3837
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
3938
}
40-
return MistralAiApi.builder().apiKey(apiKey).build();
39+
return apiKey;
40+
}
41+
42+
@Bean
43+
public MistralAiApi mistralAiApi() {
44+
return MistralAiApi.builder().apiKey(retrieveApiKey()).build();
4145
}
4246

4347
@Bean
4448
public MistralAiModerationApi mistralAiModerationApi() {
45-
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
46-
if (!StringUtils.hasText(apiKey)) {
47-
throw new IllegalArgumentException(
48-
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
49-
}
50-
return MistralAiModerationApi.builder().apiKey(apiKey).build();
49+
return MistralAiModerationApi.builder().apiKey(retrieveApiKey()).build();
5150
}
5251

5352
@Bean
54-
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
53+
public MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
5554
return MistralAiEmbeddingModel.builder().mistralAiApi(api).build();
5655
}
5756

0 commit comments

Comments
 (0)