diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 98376720f2d..d9441bacaec 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -380,7 +380,7 @@ else if (message instanceof AssistantMessage assistantMessage) { if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments()); - return new ToolCall(toolCall.id(), toolCall.type(), function); + return new ToolCall(toolCall.id(), toolCall.type(), function, null); }).toList(); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index 433ace72b7f..2b14a1db181 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -824,10 +824,11 @@ public enum Role { * @param type The type of tool call the output is required for. For now, this is * always function. * @param function The function definition. + * @param index The index of the tool call in the list of tool calls. */ @JsonInclude(Include.NON_NULL) public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type, - @JsonProperty("function") ChatCompletionFunction function) { + @JsonProperty("function") ChatCompletionFunction function, @JsonProperty("index") Integer index) { } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java index 608e8d6ffbb..079edfa988b 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java @@ -16,10 +16,7 @@ package org.springframework.ai.mistralai.api; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.UUID; +import java.util.*; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk.ChunkChoice; @@ -74,16 +71,16 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { Optional id = current.delta() .toolCalls() .stream() - .filter(tool -> tool.id() != null) - .map(tool -> tool.id()) + .map(ToolCall::id) + .filter(Objects::nonNull) .findFirst(); - if (!id.isPresent()) { + if (id.isEmpty()) { var newId = UUID.randomUUID().toString(); var toolCallsWithID = current.delta() .toolCalls() .stream() - .map(toolCall -> new ToolCall(newId, "function", toolCall.function())) + .map(toolCall -> new ToolCall(newId, "function", toolCall.function(), toolCall.index())) .toList(); var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT; @@ -151,7 +148,8 @@ private ToolCall merge(ToolCall previous, ToolCall current) { String id = (current.id() != null ? current.id() : previous.id()); String type = (current.type() != null ? current.type() : previous.type()); ChatCompletionFunction function = merge(previous.function(), current.function()); - return new ToolCall(id, type, function); + Integer index = (current.index() != null ? current.index() : previous.index()); + return new ToolCall(id, type, function, index); } private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {