Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
Expand Down Expand Up @@ -57,9 +56,12 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE;

/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax}
* backed by {@link MiniMaxApi}.
Expand Down Expand Up @@ -169,12 +171,21 @@ public ChatResponse call(Prompt prompt) {

List<Generation> generations = choices.stream().map(choice -> {
// @formatter:off
// if the choice is a web search tool call, return last message of choice.messages
ChatCompletionMessage message = null;
if(choice.message() != null) {
message = choice.message();
} else if(!CollectionUtils.isEmpty(choice.messages())){
// the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message']
// so the last message is the assistant message
message = choice.messages().get(choice.messages().size() - 1);
}
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"role", message != null && message.role() != null ? message.role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
return buildGeneration(message, choice.finishReason(), metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
Expand Down Expand Up @@ -224,7 +235,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
return buildGeneration(choice, metadata);
}).toList();
}).filter(Objects::nonNull).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
Expand All @@ -250,12 +261,28 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
return Flux.just(response);
}
return Flux.just(response);
});
}

/**
 * Decides whether a generation should be treated as a tool call that this model
 * handles. MiniMax's built-in web search tool reports the tool call type
 * 'web_search', so a generation only counts as a tool call here when the parent
 * check passes and at least one of its tool calls has the 'function' type.
 * @param generation the generation to check
 * @param toolCallFinishReasons the finish reasons that indicate a tool call
 * @return true if the generation contains at least one tool call of type 'function'
 */
@Override
protected boolean isToolCall(Generation generation, Set<String> toolCallFinishReasons) {
    boolean acceptedByParent = super.isToolCall(generation, toolCallFinishReasons);
    if (!acceptedByParent) {
        return false;
    }
    // Only 'function' tool calls qualify; any other type (e.g. web search) is skipped.
    for (AssistantMessage.ToolCall candidate : generation.getOutput().getToolCalls()) {
        if (TOOL_CALL_FUNCTION_TYPE.equals(candidate.type())) {
            return true;
        }
    }
    return false;
}

private ChatResponseMetadata from(ChatCompletion result, RateLimit rateLimit) {
Assert.notNull(result, "MiniMax ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
Expand All @@ -277,21 +304,28 @@ private ChatResponseMetadata from(ChatCompletion result) {
.build();
}

private static Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
: choice.message()
.toolCalls()
private Generation buildGeneration(ChatCompletionMessage message, ChatCompletionFinishReason completionFinishReason,
Map<String, Object> metadata) {
if (message == null || message.role() == Role.TOOL) {
return null;
}
List<AssistantMessage.ToolCall> toolCalls = message.toolCalls() == null ? List.of()
: message.toolCalls()
.stream()
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(),
toolCall.function().name(), toolCall.function().arguments()))
.toList();

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls);
String finishReason = (completionFinishReason != null ? completionFinishReason.name() : "");
var generationMetadata = ChatGenerationMetadata.from(finishReason, null);
return new Generation(assistantMessage, generationMetadata);
}

/**
 * Convenience overload that builds a {@link Generation} from a {@link Choice} by
 * handing its message and finish reason to the message-based overload.
 * @param choice the choice whose message and finish reason are used
 * @param metadata the metadata passed through unchanged to the message-based overload
 * @return the result of the message-based overload
 */
private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
    ChatCompletionMessage choiceMessage = choice.message();
    ChatCompletionFinishReason choiceFinishReason = choice.finishReason();
    return buildGeneration(choiceMessage, choiceFinishReason, metadata);
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ public final class MiniMaxApiConstants {

public static final String DEFAULT_BASE_URL = "https://api.minimax.chat";

public static final String TOOL_CALL_FUNCTION_TYPE = "function";

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.api.MiniMaxApi;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat;

/**
* @author Geng Rong
*/
@EnabledIfEnvironmentVariable(named = "MINIMAX_API_KEY", matches = ".+")
public class MiniMaxChatOptionsTests {

private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatOptionsTests.class);

private final MiniMaxChatModel chatModel = new MiniMaxChatModel(new MiniMaxApi(System.getenv("MINIMAX_API_KEY")));

@Test
Expand All @@ -46,4 +56,72 @@ void testMarkSensitiveInfo() {
assertThat(unmaskResponseContent).contains("133-12345678");
}

/**
 * Exercises the MiniMax web search tool through a blocking call.
 * <p>
 * This test can fail intermittently: the answer comes from a live web search, so
 * it depends on whatever the search engine returns at the time of the run. For
 * example, once the Paris Paralympic Games had started, a query about Paris
 * Olympics gold medals could be dominated by higher-ranked Paralympic results.
 * The reply may also be in Chinese even though the question is in English,
 * because MiniMax's training data and its search engine are primarily Chinese.
 */
@Test
void testWebSearch() {
    List<Message> conversation = new ArrayList<>(List.of(new UserMessage(
            "How many gold medals has the United States won in total at the 2024 Olympics?")));

    MiniMaxChatOptions webSearchOptions = MiniMaxChatOptions.builder()
        .withModel(ABAB_6_5_S_Chat.value)
        .withTools(List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()))
        .build();

    ChatResponse chatResponse = chatModel.call(new Prompt(conversation, webSearchOptions));
    String answer = chatResponse.getResult().getOutput().getContent();

    assertThat(answer).contains("40");
}

/**
 * Exercises the MiniMax web search tool through the streaming API and checks the
 * concatenated content of all streamed generations.
 * <p>
 * This test can fail intermittently: the answer comes from a live web search, so
 * it depends on whatever the search engine returns at the time of the run. For
 * example, once the Paris Paralympic Games had started, a query about Paris
 * Olympics gold medals could be dominated by higher-ranked Paralympic results.
 * The reply may also be in Chinese even though the question is in English,
 * because MiniMax's training data and its search engine are primarily Chinese.
 */
@Test
void testWebSearchStream() {
    List<Message> conversation = new ArrayList<>(List.of(new UserMessage(
            "How many gold medals has the United States won in total at the 2024 Olympics?")));

    MiniMaxChatOptions webSearchOptions = MiniMaxChatOptions.builder()
        .withModel(ABAB_6_5_S_Chat.value)
        .withTools(List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()))
        .build();

    Flux<ChatResponse> streamed = chatModel.stream(new Prompt(conversation, webSearchOptions));

    // Block until the stream completes, then stitch the non-null content pieces together.
    List<ChatResponse> chunks = Objects.requireNonNull(streamed.collectList().block());
    String content = chunks.stream()
        .flatMap(chunk -> chunk.getResults().stream())
        .map(Generation::getOutput)
        .map(AssistantMessage::getContent)
        .filter(piece -> piece != null)
        .collect(Collectors.joining());
    logger.info("Response: {}", content);

    assertThat(content).contains("40");
}

}