Skip to content

Commit c61fe60

Browse files
committed
added index of tool call in the list of tool calls
Signed-off-by: Ricken Bazolo <[email protected]>
1 parent b77e084 commit c61fe60

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ else if (message instanceof AssistantMessage assistantMessage) {
380380
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
381381
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
382382
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
383-
return new ToolCall(toolCall.id(), toolCall.type(), function);
383+
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
384384
}).toList();
385385
}
386386

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,10 +824,11 @@ public enum Role {
824824
* @param type The type of tool call the output is required for. For now, this is
825825
* always function.
826826
* @param function The function definition.
827+
* @param index The index of the tool call in the list of tool calls.
827828
*/
828829
@JsonInclude(Include.NON_NULL)
829830
public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type,
830-
@JsonProperty("function") ChatCompletionFunction function) {
831+
@JsonProperty("function") ChatCompletionFunction function, @JsonProperty("index") Integer index) {
831832

832833
}
833834

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616

1717
package org.springframework.ai.mistralai.api;
1818

19-
import java.util.ArrayList;
20-
import java.util.List;
21-
import java.util.Optional;
22-
import java.util.UUID;
19+
import java.util.*;
2320

2421
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk;
2522
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk.ChunkChoice;
@@ -74,16 +71,16 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
7471
Optional<String> id = current.delta()
7572
.toolCalls()
7673
.stream()
77-
.filter(tool -> tool.id() != null)
78-
.map(tool -> tool.id())
74+
.map(ToolCall::id)
75+
.filter(Objects::nonNull)
7976
.findFirst();
80-
if (!id.isPresent()) {
77+
if (id.isEmpty()) {
8178
var newId = UUID.randomUUID().toString();
8279

8380
var toolCallsWithID = current.delta()
8481
.toolCalls()
8582
.stream()
86-
.map(toolCall -> new ToolCall(newId, "function", toolCall.function()))
83+
.map(toolCall -> new ToolCall(newId, "function", toolCall.function(), toolCall.index()))
8784
.toList();
8885

8986
var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
@@ -151,7 +148,8 @@ private ToolCall merge(ToolCall previous, ToolCall current) {
151148
String id = (current.id() != null ? current.id() : previous.id());
152149
String type = (current.type() != null ? current.type() : previous.type());
153150
ChatCompletionFunction function = merge(previous.function(), current.function());
154-
return new ToolCall(id, type, function);
151+
Integer index = (current.index() != null ? current.index() : previous.index());
152+
return new ToolCall(id, type, function, index);
155153
}
156154

157155
private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {

0 commit comments

Comments
 (0)