Skip to content

Commit aec9da0

Browse files
committed
Add Mistral AI Codestral Embed model
Signed-off-by: Nicolas Krier <[email protected]>
1 parent 72f7c63 commit aec9da0

File tree

3 files changed

+37
-22
lines changed

3 files changed

+37
-22
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.mistralai;
1818

1919
import java.util.List;
20+
import java.util.Map;
2021

2122
import io.micrometer.observation.ObservationRegistry;
2223
import org.slf4j.Logger;
@@ -41,6 +42,9 @@
4142
import org.springframework.retry.support.RetryTemplate;
4243
import org.springframework.util.Assert;
4344

45+
import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.CODESTRAL_EMBED;
46+
import static org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingModel.EMBED;
47+
4448
/**
4549
* Provides the Mistral AI Embedding Model.
4650
*
@@ -53,6 +57,9 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5357

5458
private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class);
5559

60+
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Map.of(EMBED.getValue(), 1024,
61+
CODESTRAL_EMBED.getValue(), 1536);
62+
5663
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
5764

5865
private final MistralAiEmbeddingOptions defaultOptions;
@@ -78,8 +85,7 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
7885
}
7986

8087
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
81-
this(mistralAiApi, metadataMode,
82-
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
88+
this(mistralAiApi, metadataMode, MistralAiEmbeddingOptions.builder().withModel(EMBED.getValue()).build(),
8389
RetryUtils.DEFAULT_RETRY_TEMPLATE);
8490
}
8591

@@ -179,6 +185,11 @@ public float[] embed(Document document) {
179185
return this.embed(document.getFormattedContent(this.metadataMode));
180186
}
181187

188+
@Override
189+
public int dimensions() {
190+
return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions());
191+
}
192+
182193
/**
183194
* Use the provided convention for reporting observation data
184195
* @param observationConvention The provided convention

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,8 @@ public String getName() {
321321
public enum EmbeddingModel {
322322

323323
// @formatter:off
324-
EMBED("mistral-embed");
324+
EMBED("mistral-embed"),
325+
CODESTRAL_EMBED("codestral-embed");
325326
// @formatter:on
326327

327328
private final String value;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,50 +16,53 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import java.util.List;
20-
2119
import org.junit.jupiter.api.Test;
2220
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23-
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.CsvSource;
2423
import org.springframework.ai.embedding.EmbeddingRequest;
24+
import org.springframework.ai.mistralai.api.MistralAiApi;
2525
import org.springframework.beans.factory.annotation.Autowired;
2626
import org.springframework.boot.test.context.SpringBootTest;
2727

28+
import java.util.List;
29+
2830
import static org.assertj.core.api.Assertions.assertThat;
2931

3032
@SpringBootTest(classes = MistralAiTestConfiguration.class)
3133
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
3234
class MistralAiEmbeddingIT {
3335

34-
@Autowired
35-
private MistralAiEmbeddingModel mistralAiEmbeddingModel;
36-
3736
@Test
38-
void defaultEmbedding() {
39-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
40-
var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
37+
void defaultEmbedding(@Autowired MistralAiEmbeddingModel mistralAiEmbeddingModel) {
38+
assertThat(mistralAiEmbeddingModel).isNotNull();
39+
var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
4140
assertThat(embeddingResponse.getResults()).hasSize(1);
4241
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
4342
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
4443
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
4544
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
4645
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
47-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
46+
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
4847
}
4948

50-
@Test
51-
void embeddingTest() {
52-
assertThat(this.mistralAiEmbeddingModel).isNotNull();
53-
var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest(
54-
List.of("Hello World", "World is big"),
55-
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
49+
@CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" })
50+
@ParameterizedTest
51+
void embeddingTest(String model, int dimensions, @Autowired MistralAiApi mistralAiApi) {
52+
var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build();
53+
var mistralAiEmbeddingModel = new MistralAiEmbeddingModel(mistralAiApi, mistralAiEmbeddingOptions);
54+
var embeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).withEncodingFormat("float").build();
55+
var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big"), embeddingOptions);
56+
var embeddingResponse = mistralAiEmbeddingModel.call(embeddingRequest);
5657
assertThat(embeddingResponse.getResults()).hasSize(2);
5758
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
58-
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
59-
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed");
59+
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(dimensions);
60+
assertThat(embeddingResponse.getResults().get(1)).isNotNull();
61+
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(dimensions);
62+
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model);
6063
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9);
6164
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9);
62-
assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
65+
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions);
6366
}
6467

6568
}

0 commit comments

Comments
 (0)