diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 295d20c3876..1b9a0e161f9 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -15,15 +15,21 @@ */ package org.springframework.ai.mistralai; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -39,24 +45,16 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; import org.springframework.ai.mistralai.metadata.MistralAiUsage; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; + import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; - /** * @author Ricken Bazolo * @author Christian Tzolov @@ -67,7 +65,7 @@ */ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { - private final Logger log = LoggerFactory.getLogger(getClass()); + private final Logger logger = LoggerFactory.getLogger(getClass()); /** * The default options used for the chat completion requests. @@ -108,158 +106,133 @@ public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions option @Override public ChatResponse call(Prompt prompt) { - var request = createRequest(prompt, false); - - return retryTemplate.execute(ctx -> { - ResponseEntity completionEntity = this.mistralAiApi.chatCompletionEntity(request); + var request = createRequest(prompt, false); - if (this.isToolFunctionCall(completionEntity.getBody())) { - List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), - completionEntity.getBody()); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); - } + ResponseEntity completionEntity = retryTemplate + .execute(ctx -> this.mistralAiApi.chatCompletionEntity(request)); - var chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - log.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + ChatCompletion chatCompletion = completionEntity.getBody(); - List generations = chatCompletion.choices() - .stream() - .map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice)) - .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null))) - .toList(); + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - return new ChatResponse(generations, from(chatCompletion)); - }); - } + List generations = chatCompletion.choices().stream().map(choice -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletion.id() != null ? chatCompletion.id() : "", + "index", choice.index(), + "role", choice.message().role() != null ? choice.message().role().name() : "", + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); - public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { - Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); - MistralAiUsage usage = MistralAiUsage.from(result.usage()); - return ChatResponseMetadata.builder() - .withId(result.id()) - .withModel(result.model()) - .withUsage(usage) - .withKeyValue("created", result.created()) - .build(); - } + // // Non function calling. + // RateLimit rateLimit = + // OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); - private Map toMap(String id, ChatCompletion.Choice choice) { - Map map = new HashMap<>(); + ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); - var message = choice.message(); - if (message.role() != null) { - map.put("role", message.role().name()); - } - if (choice.finishReason() != null) { - map.put("finishReason", choice.finishReason().name()); + if (isToolCall(chatResponse, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, chatResponse); + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.call(new Prompt(toolCallConversation, prompt.getOptions())); } - map.put("id", id); - return map; + + return chatResponse; } @Override public Flux stream(Prompt prompt) { var request = createRequest(prompt, true); - return retryTemplate.execute(ctx -> { - - Flux completionChunks = this.mistralAiApi.chatCompletionStream(request); + Flux completionChunks = retryTemplate + .execute(ctx -> this.mistralAiApi.chatCompletionStream(request)); - // For chunked responses, only the first chunk contains the choice role. - // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); - return completionChunks.map(this::toChatCompletion).switchMap(chatCompletion -> { - if (this.isToolFunctionCall(chatCompletion)) { - var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), - chatCompletion); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); - } - - return Mono.just(chatCompletion).map(chatCompletion2 -> { + // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse + // the function call handling logic. + Flux chatResponse = completionChunks.map(this::toChatCompletion) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + try { @SuppressWarnings("null") String id = chatCompletion2.id(); + // @formatter:off List generations = chatCompletion2.choices().stream().map(choice -> { if (choice.message().role() != null) { roleMap.putIfAbsent(id, choice.message().role().name()); } - String finish = (choice.finishReason() != null ? choice.finishReason().name() : ""); - var generation = new Generation(choice.message().content(), - Map.of("id", id, "role", roleMap.get(id), "finishReason", finish)); - if (choice.finishReason() != null) { - generation = generation.withGenerationMetadata( - ChatGenerationMetadata.from(choice.finishReason().name(), null)); - } - return generation; - }).toList(); - - return new ChatResponse(generations); - }); - }); - }); - } - - private List handleToolCallRequests(List previousMessages, ChatCompletion chatCompletion) { - - ChatCompletionMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); - - List assistantToolCalls = nativeAssistantMessage.toolCalls() - .stream() - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", toolCall.function().name(), - toolCall.function().arguments())) - .toList(); - - AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.content(), Map.of(), - assistantToolCalls); - - ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage); - - // History - List messages = new ArrayList<>(previousMessages); - messages.add(assistantMessage); - messages.add(toolResponseMessage); - - return messages; - } - - private ChatCompletionMessage extractAssistantMessage(ChatCompletion chatCompletion) { - ChatCompletionMessage msg = chatCompletion.choices().iterator().next().message(); - if (msg.role() == null) { - // add missing role - msg = new ChatCompletionMessage(msg.content(), ChatCompletionMessage.Role.ASSISTANT, msg.name(), - msg.toolCalls()); - } - return msg; - } - - protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage) { - - List toolResponses = new ArrayList<>(); + Map metadata = Map.of( + "id", chatCompletion2.id(), + "role", roleMap.getOrDefault(id, ""), + "index", choice.index(), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + return buildGeneration(choice, metadata); + }).toList(); + // @formatter:on + + if (chatCompletion2.usage() != null) { + return new ChatResponse(generations, from(chatCompletion2)); + } + else { + return new ChatResponse(generations); + } + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + })); - var functionName = toolCall.name(); - String functionArguments = toolCall.arguments(); + return chatResponse.flatMap(response -> { - if (!this.functionCallbackRegister.containsKey(functionName)) { - throw new IllegalStateException("No function callback found for function name: " + functionName); + if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + MistralAiApi.ChatCompletionFinishReason.STOP.name()))) { + var toolCallConversation = handleToolCalls(prompt, response); + // Recursively call the stream method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + } + else { + return Flux.just(response); } + }); + } - String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); + private Generation buildGeneration(Choice choice, Map metadata) { + List toolCalls = choice.message().toolCalls() == null ? List.of() + : choice.message() + .toolCalls() + .stream() + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", + toolCall.function().name(), toolCall.function().arguments())) + .toList(); - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, functionResponse)); - } + var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : ""); + var generationMetadata = ChatGenerationMetadata.from(finishReason, null); + return new Generation(assistantMessage, generationMetadata); + } - return new ToolResponseMessage(toolResponses, Map.of()); + public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { + Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); + MistralAiUsage usage = MistralAiUsage.from(result.usage()); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("created", result.created()) + .build(); } private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { @@ -358,21 +331,6 @@ private List getFunctionTools(Set functionNam }).toList(); } - protected boolean isToolFunctionCall(ChatCompletion chatCompletion) { - - var body = chatCompletion; - if (body == null) { - return false; - } - - var choices = body.choices(); - if (CollectionUtils.isEmpty(choices)) { - return false; - } - - return !CollectionUtils.isEmpty(choices.get(0).message().toolCalls()); - } - @Override public ChatOptions getDefaultOptions() { return MistralAiChatOptions.fromOptions(this.defaultOptions); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 65646ea2918..3255e34ebaf 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -166,6 +166,7 @@ public ChatResponse call(Prompt prompt) { Map metadata = Map.of( "id", chatCompletion.id() != null ? chatCompletion.id() : "", "role", choice.message().role() != null ? choice.message().role().name() : "", + "index", choice.index(), "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); // @formatter:on return buildGeneration(choice, metadata); @@ -215,6 +216,7 @@ public Flux stream(Prompt prompt) { Map metadata = Map.of( "id", chatCompletion2.id(), "role", roleMap.getOrDefault(id, ""), + "index", choice.index(), "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); return buildGeneration(choice, metadata); }).toList(); @@ -236,7 +238,8 @@ public Flux stream(Prompt prompt) { return chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) { + if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name()))) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses.