|
16 | 16 |
|
17 | 17 | package org.springframework.ai.mistralai; |
18 | 18 |
|
19 | | -import java.util.List; |
20 | | - |
21 | 19 | import org.junit.jupiter.api.Test; |
22 | 20 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; |
23 | | - |
| 21 | +import org.junit.jupiter.params.ParameterizedTest; |
| 22 | +import org.junit.jupiter.params.provider.CsvSource; |
24 | 23 | import org.springframework.ai.embedding.EmbeddingRequest; |
| 24 | +import org.springframework.ai.mistralai.api.MistralAiApi; |
25 | 25 | import org.springframework.beans.factory.annotation.Autowired; |
26 | 26 | import org.springframework.boot.test.context.SpringBootTest; |
27 | 27 |
|
| 28 | +import java.util.List; |
| 29 | + |
28 | 30 | import static org.assertj.core.api.Assertions.assertThat; |
29 | 31 |
|
30 | 32 | @SpringBootTest(classes = MistralAiTestConfiguration.class) |
31 | 33 | @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") |
32 | 34 | class MistralAiEmbeddingIT { |
33 | 35 |
|
34 | | - @Autowired |
35 | | - private MistralAiEmbeddingModel mistralAiEmbeddingModel; |
36 | | - |
37 | 36 | @Test |
38 | | - void defaultEmbedding() { |
39 | | - assertThat(this.mistralAiEmbeddingModel).isNotNull(); |
40 | | - var embeddingResponse = this.mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); |
| 37 | + void defaultEmbedding(@Autowired MistralAiEmbeddingModel mistralAiEmbeddingModel) { |
| 38 | + assertThat(mistralAiEmbeddingModel).isNotNull(); |
| 39 | + var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World")); |
41 | 40 | assertThat(embeddingResponse.getResults()).hasSize(1); |
42 | 41 | assertThat(embeddingResponse.getResults().get(0)).isNotNull(); |
43 | 42 | assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); |
44 | 43 | assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); |
45 | 44 | assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4); |
46 | 45 | assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4); |
47 | | - assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); |
| 46 | + assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); |
48 | 47 | } |
49 | 48 |
|
50 | | - @Test |
51 | | - void embeddingTest() { |
52 | | - assertThat(this.mistralAiEmbeddingModel).isNotNull(); |
53 | | - var embeddingResponse = this.mistralAiEmbeddingModel.call(new EmbeddingRequest( |
54 | | - List.of("Hello World", "World is big"), |
55 | | - MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build())); |
| 49 | + @CsvSource({ "mistral-embed, 1024", "codestral-embed, 1536" }) |
| 50 | + @ParameterizedTest |
| 51 | + void embeddingTest(String model, int dimensions, @Autowired MistralAiApi mistralAiApi) { |
| 52 | + var mistralAiEmbeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).build(); |
| 53 | + var mistralAiEmbeddingModel = new MistralAiEmbeddingModel(mistralAiApi, mistralAiEmbeddingOptions); |
| 54 | + var embeddingOptions = MistralAiEmbeddingOptions.builder().withModel(model).withEncodingFormat("float").build(); |
| 55 | + var embeddingRequest = new EmbeddingRequest(List.of("Hello World", "World is big"), embeddingOptions); |
| 56 | + var embeddingResponse = mistralAiEmbeddingModel.call(embeddingRequest); |
56 | 57 | assertThat(embeddingResponse.getResults()).hasSize(2); |
57 | 58 | assertThat(embeddingResponse.getResults().get(0)).isNotNull(); |
58 | | - assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024); |
59 | | - assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("mistral-embed"); |
| 59 | + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(dimensions); |
| 60 | + assertThat(embeddingResponse.getResults().get(1)).isNotNull(); |
| 61 | + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(dimensions); |
| 62 | + assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(model); |
60 | 63 | assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(9); |
61 | 64 | assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(9); |
62 | | - assertThat(this.mistralAiEmbeddingModel.dimensions()).isEqualTo(1024); |
| 65 | + assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(dimensions); |
63 | 66 | } |
64 | 67 |
|
65 | 68 | } |
0 commit comments