Skip to content

Commit b104bee

Browse files
committed
Add Mistral AI Medium and Saba models
- Add Mistral Medium Latest model to ChatModel - Add Mistral Saba Latest model to ChatModel - Rename Ministraux models from ChatModel - Improve javadoc and update documentation links - Add small code enhancements Signed-off-by: Nicolas Krier <[email protected]>
1 parent fa5fb53 commit b104bee

File tree

4 files changed

+50
-30
lines changed

4 files changed

+50
-30
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
153153

154154
// The input must not an empty string, and any array must be 1024 dimensions or
155155
// less.
156-
if (embeddingRequest.input() instanceof List list) {
156+
if (embeddingRequest.input() instanceof List<?> list) {
157157
Assert.isTrue(!CollectionUtils.isEmpty(list), "The input list can not be empty.");
158158
Assert.isTrue(list.size() <= 1024, "The list must be 1024 dimensions or less");
159159
Assert.isTrue(
@@ -269,28 +269,28 @@ public enum ChatCompletionFinishReason {
269269

270270
/**
271271
* List of well-known Mistral chat models.
272-
* https://docs.mistral.ai/platform/endpoints/#mistral-ai-generative-models
273272
*
274-
* <p>
275-
* Mistral AI provides two types of models: open-weights models (Mistral 7B, Mixtral
276-
* 8x7B, Mixtral 8x22B) and optimized commercial models (Mistral Small, Mistral
277-
* Medium, Mistral Large, and Mistral Embeddings).
273+
* @see <a href=
274+
* "https://docs.mistral.ai/getting-started/models/models_overview/">Mistral AI Models
275+
* Overview</a>
278276
*/
279277
public enum ChatModel implements ChatModelDescription {
280278

281279
// @formatter:off
282280
// Premier Models
283281
CODESTRAL("codestral-latest"),
284282
LARGE("mistral-large-latest"),
283+
MEDIUM("mistral-medium-latest"),
284+
MINISTRAL_3B("ministral-3b-latest"),
285+
MINISTRAL_8B("ministral-8b-latest"),
285286
PIXTRAL_LARGE("pixtral-large-latest"),
286-
MINISTRAL_3B_LATEST("ministral-3b-latest"),
287-
MINISTRAL_8B_LATEST("ministral-8b-latest"),
287+
SABA("mistral-saba-latest"),
288288
// Free Models
289-
SMALL("mistral-small-latest"),
290289
PIXTRAL("pixtral-12b-2409"),
290+
SMALL("mistral-small-latest"),
291291
// Free Models - Research
292-
OPEN_MISTRAL_NEMO("open-mistral-nemo"),
293-
OPEN_CODESTRAL_MAMBA("open-codestral-mamba");
292+
OPEN_CODESTRAL_MAMBA("open-codestral-mamba"),
293+
OPEN_MISTRAL_NEMO("open-mistral-nemo");
294294
// @formatter:on
295295

296296
private final String value;
@@ -312,11 +312,15 @@ public String getName() {
312312

313313
/**
314314
* List of well-known Mistral embedding models.
315-
* https://docs.mistral.ai/platform/endpoints/#mistral-ai-embedding-model
315+
*
316+
* @see <a href=
317+
* "https://docs.mistral.ai/getting-started/models/models_overview/">Mistral AI Models
318+
* Overview</a>
316319
*/
317320
public enum EmbeddingModel {
318321

319322
// @formatter:off
323+
// Premier Models
320324
EMBED("mistral-embed");
321325
// @formatter:on
322326

@@ -820,8 +824,7 @@ public String content() {
820824
}
821825

822826
/**
823-
* The role of the author of this message.
824-
*
827+
* The role of the author of this message. <br/>
825828
* NOTE: Mistral expects the system message to be before the user message or will
826829
* fail with 400 error.
827830
*/

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
* MistralAI Moderation API.
3434
*
3535
* @author Ricken Bazolo
36-
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a>
36+
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/">Moderation</a>
3737
*/
3838
public class MistralAiModerationApi {
3939

@@ -71,9 +71,17 @@ public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationR
7171
.toEntity(MistralAiModerationResponse.class);
7272
}
7373

74+
/**
75+
* List of well-known Mistral moderation models.
76+
*
77+
* @see <a href=
78+
* "https://docs.mistral.ai/getting-started/models/models_overview/">Mistral AI Models
79+
* Overview</a>
80+
*/
7481
public enum Model {
7582

7683
// @formatter:off
84+
// Premier Models
7785
MISTRAL_MODERATION("mistral-moderation-latest");
7886
// @formatter:on
7987

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,10 @@
1616

1717
package org.springframework.ai.mistralai;
1818

19-
import java.util.Arrays;
20-
import java.util.List;
21-
import java.util.Map;
22-
import java.util.stream.Collectors;
23-
2419
import org.junit.jupiter.api.Test;
2520
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2621
import org.slf4j.Logger;
2722
import org.slf4j.LoggerFactory;
28-
import reactor.core.publisher.Flux;
29-
3023
import org.springframework.ai.chat.client.ChatClient;
3124
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
3225
import org.springframework.ai.chat.messages.UserMessage;
@@ -43,6 +36,13 @@
4336
import org.springframework.core.ParameterizedTypeReference;
4437
import org.springframework.core.convert.support.DefaultConversionService;
4538
import org.springframework.core.io.Resource;
39+
import reactor.core.publisher.Flux;
40+
41+
import java.util.Arrays;
42+
import java.util.Collections;
43+
import java.util.List;
44+
import java.util.Map;
45+
import java.util.stream.Collectors;
4646

4747
import static org.assertj.core.api.Assertions.assertThat;
4848

@@ -53,7 +53,7 @@ class MistralAiChatClientIT {
5353
private static final Logger logger = LoggerFactory.getLogger(MistralAiChatClientIT.class);
5454

5555
@Autowired
56-
MistralAiChatModel chatModel;
56+
private MistralAiChatModel chatModel;
5757

5858
@Value("classpath:/prompts/system-message.st")
5959
private Resource systemTextResource;
@@ -290,9 +290,7 @@ void streamFunctionCallTest() {
290290

291291
@Test
292292
void validateCallResponseMetadata() {
293-
// String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getName();
294-
String model = MistralAiApi.ChatModel.PIXTRAL.getName();
295-
// String model = MistralAiApi.ChatModel.PIXTRAL_LARGE.getName();
293+
String model = selectChatModelName();
296294
// @formatter:off
297295
ChatResponse response = ChatClient.create(this.chatModel).prompt()
298296
.options(MistralAiChatOptions.builder().model(model).build())
@@ -301,6 +299,7 @@ void validateCallResponseMetadata() {
301299
.chatResponse();
302300
// @formatter:on
303301

302+
assertThat(response).isNotNull();
304303
logger.info(response.toString());
305304
assertThat(response.getMetadata().getId()).isNotEmpty();
306305
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
@@ -309,6 +308,15 @@ void validateCallResponseMetadata() {
309308
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
310309
}
311310

311+
private static String selectChatModelName() {
312+
var chatModels = Arrays.asList(MistralAiApi.ChatModel.values());
313+
Collections.shuffle(chatModels);
314+
var chatModelName = chatModels.get(0).getName();
315+
logger.info("Selected chat model name: {}", chatModelName);
316+
317+
return chatModelName;
318+
}
319+
312320
record ActorsFilms(String actor, List<String> movies) {
313321

314322
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
* @since 0.8.1
4040
*/
4141
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
42-
public class MistralAiApiIT {
42+
class MistralAiApiIT {
4343

44-
MistralAiApi mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY"));
44+
private final MistralAiApi mistralAiApi = new MistralAiApi(System.getenv("MISTRAL_AI_API_KEY"));
4545

4646
@Test
4747
void chatCompletionEntity() {
@@ -61,7 +61,7 @@ void chatCompletionEntityWithSystemMessage() {
6161
You are an AI assistant that helps people find information.
6262
Your name is Bob.
6363
You should reply to the user's request with your name and also in the style of a pirate.
64-
""", Role.SYSTEM);
64+
""", Role.SYSTEM);
6565

6666
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
6767
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
@@ -83,9 +83,10 @@ void chatCompletionStream() {
8383
@Test
8484
void embeddings() {
8585
ResponseEntity<EmbeddingList<Embedding>> response = this.mistralAiApi
86-
.embeddings(new MistralAiApi.EmbeddingRequest<String>("Hello world"));
86+
.embeddings(new MistralAiApi.EmbeddingRequest<>("Hello world"));
8787

8888
assertThat(response).isNotNull();
89+
assertThat(response.getBody()).isNotNull();
8990
assertThat(response.getBody().data()).hasSize(1);
9091
assertThat(response.getBody().data().get(0).embedding()).hasSize(1024);
9192
}

0 commit comments

Comments
 (0)