diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 288e5d7c1fe..da23e253fbf 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -17,7 +17,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -57,9 +56,12 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE; + /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax} * backed by {@link MiniMaxApi}. @@ -169,12 +171,21 @@ public ChatResponse call(Prompt prompt) { List generations = choices.stream().map(choice -> { // @formatter:off + // if the choice is a web search tool call, return last message of choice.messages + ChatCompletionMessage message = null; + if(choice.message() != null) { + message = choice.message(); + } else if(!CollectionUtils.isEmpty(choice.messages())){ + // the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message'] + // so the last message is the assistant message + message = choice.messages().get(choice.messages().size() - 1); + } Map metadata = Map.of( "id", chatCompletion.id(), - "role", choice.message().role() != null ? choice.message().role().name() : "", + "role", message != null && message.role() != null ? message.role().name() : "", "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); // @formatter:on - return buildGeneration(choice, metadata); + return buildGeneration(message, choice.finishReason(), metadata); }).toList(); ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); @@ -224,7 +235,7 @@ public Flux stream(Prompt prompt) { "role", roleMap.getOrDefault(id, ""), "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); return buildGeneration(choice, metadata); - }).toList(); + }).filter(Objects::nonNull).toList(); // @formatter:on if (chatCompletion2.usage() != null) { @@ -250,12 +261,28 @@ public Flux stream(Prompt prompt) { // conversation that contains the call responses. return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); } - else { - return Flux.just(response); - } + return Flux.just(response); }); } + /** + * The MimiMax web search function tool type is 'web_search', so we need to filter out + * the tool calls whose type is not 'function' + * @param generation the generation to check + * @param toolCallFinishReasons the tool call finish reasons + * @return true if the generation is a tool call + */ + @Override + protected boolean isToolCall(Generation generation, Set toolCallFinishReasons) { + if (!super.isToolCall(generation, toolCallFinishReasons)) { + return false; + } + return generation.getOutput() + .getToolCalls() + .stream() + .anyMatch(toolCall -> TOOL_CALL_FUNCTION_TYPE.equals(toolCall.type())); + } + private ChatResponseMetadata from(ChatCompletion result, RateLimit rateLimit) { Assert.notNull(result, "MiniMax ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() @@ -277,21 +304,28 @@ private ChatResponseMetadata from(ChatCompletion result) { .build(); } - private static Generation buildGeneration(Choice choice, Map metadata) { - List toolCalls = choice.message().toolCalls() == null ? List.of() - : choice.message() - .toolCalls() + private Generation buildGeneration(ChatCompletionMessage message, ChatCompletionFinishReason completionFinishReason, + Map metadata) { + if (message == null || message.role() == Role.TOOL) { + return null; + } + List toolCalls = message.toolCalls() == null ? List.of() + : message.toolCalls() .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); - String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls); + String finishReason = (completionFinishReason != null ? completionFinishReason.name() : ""); var generationMetadata = ChatGenerationMetadata.from(finishReason, null); return new Generation(assistantMessage, generationMetadata); } + private Generation buildGeneration(Choice choice, Map metadata) { + return buildGeneration(choice.message(), choice.finishReason(), metadata); + } + /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java index bc1b8d13985..1b2e797d83e 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java @@ -10,4 +10,6 @@ public final class MiniMaxApiConstants { public static final String DEFAULT_BASE_URL = "https://api.minimax.chat"; + public static final String TOOL_CALL_FUNCTION_TYPE = "function"; + } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index 35326fa6efa..23daaf5cbc6 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -2,18 +2,26 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; +import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat; /** * @author Geng Rong @@ -21,6 +29,8 @@ @EnabledIfEnvironmentVariable(named = "MINIMAX_API_KEY", matches = ".+") public class MiniMaxChatOptionsTests { + private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatOptionsTests.class); + private final MiniMaxChatModel chatModel = new MiniMaxChatModel(new MiniMaxApi(System.getenv("MINIMAX_API_KEY"))); @Test @@ -46,4 +56,72 @@ void testMarkSensitiveInfo() { assertThat(unmaskResponseContent).contains("133-12345678"); } + /** + * There is a certain probability of failure, because it needs to be searched through + * the network, which may cause the test to fail due to different search results. And + * the search results are related to time. For example, after the start of the Paris + * Paralympic Games, searching for the number of gold medals in the Paris Olympics may + * be affected by the search results of the number of gold medals in the Paris + * Paralympic Games with higher priority by the search engine. Even if the input is an + * English question, there may be get Chinese content, because the main training + * content of MiniMax and search engine are Chinese + */ + @Test + void testWebSearch() { + UserMessage userMessage = new UserMessage( + "How many gold medals has the United States won in total at the 2024 Olympics?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); + + MiniMaxChatOptions options = MiniMaxChatOptions.builder() + .withModel(ABAB_6_5_S_Chat.value) + .withTools(functionTool) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, options)); + String responseContent = response.getResult().getOutput().getContent(); + + assertThat(responseContent).contains("40"); + } + + /** + * There is a certain probability of failure, because it needs to be searched through + * the network, which may cause the test to fail due to different search results. And + * the search results are related to time. For example, after the start of the Paris + * Paralympic Games, searching for the number of gold medals in the Paris Olympics may + * be affected by the search results of the number of gold medals in the Paris + * Paralympic Games with higher priority by the search engine. Even if the input is an + * English question, there may be get Chinese content, because the main training + * content of MiniMax and search engine of MiniMax are Chinese + */ + @Test + void testWebSearchStream() { + UserMessage userMessage = new UserMessage( + "How many gold medals has the United States won in total at the 2024 Olympics?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); + + MiniMaxChatOptions options = MiniMaxChatOptions.builder() + .withModel(ABAB_6_5_S_Chat.value) + .withTools(functionTool) + .build(); + + Flux response = chatModel.stream(new Prompt(messages, options)); + String content = Objects.requireNonNull(response.collectList().block()) + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("40"); + } + }