Skip to content

Commit 81dfd3b

Browse files
ilayaperumalg authored and tzolov committed
Fix Mistral AI Chat model function call usage calculation
- Fix the chat model's call() to calculate the cumulative usage
- Use an explicit internalCall to pass the previous chat response so that accumulation can be done
- Fix the chat model's stream() to calculate the cumulative usage
- Fix MistralAi API to include usage in ChatCompletionChunk
- Use internalStream() to accumulate the usage
- Add/update tests
1 parent 187360d commit 81dfd3b

File tree

5 files changed

+70
-9
lines changed

5 files changed

+70
-9
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.springframework.ai.chat.messages.UserMessage;
3737
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3838
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
39+
import org.springframework.ai.chat.metadata.Usage;
40+
import org.springframework.ai.chat.metadata.UsageUtils;
3941
import org.springframework.ai.chat.model.AbstractToolCallSupport;
4042
import org.springframework.ai.chat.model.ChatModel;
4143
import org.springframework.ai.chat.model.ChatResponse;
@@ -74,6 +76,7 @@
7476
* @author Grogdunn
7577
* @author Thomas Vitale
7678
* @author luocongqiu
79+
* @author Ilayaperumal Gopinathan
7780
* @since 1.0.0
7881
*/
7982
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -155,8 +158,22 @@ public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
155158
.build();
156159
}
157160

161+
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result, Usage usage) {
162+
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
163+
return ChatResponseMetadata.builder()
164+
.withId(result.id())
165+
.withModel(result.model())
166+
.withUsage(usage)
167+
.withKeyValue("created", result.created())
168+
.build();
169+
}
170+
158171
@Override
159172
public ChatResponse call(Prompt prompt) {
173+
return this.internalCall(prompt, null);
174+
}
175+
176+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
160177

161178
MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false);
162179

@@ -192,7 +209,10 @@ public ChatResponse call(Prompt prompt) {
192209
return buildGeneration(choice, metadata);
193210
}).toList();
194211

195-
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
212+
MistralAiUsage usage = MistralAiUsage.from(completionEntity.getBody().usage());
213+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
214+
ChatResponse chatResponse = new ChatResponse(generations,
215+
from(completionEntity.getBody(), cumulativeUsage));
196216

197217
observationContext.setResponse(chatResponse);
198218

@@ -205,14 +225,18 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL
205225
var toolCallConversation = handleToolCalls(prompt, response);
206226
// Recursively call the call method with the tool call message
207227
// conversation that contains the call responses.
208-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
228+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
209229
}
210230

211231
return response;
212232
}
213233

214234
@Override
215235
public Flux<ChatResponse> stream(Prompt prompt) {
236+
return this.internalStream(prompt, null);
237+
}
238+
239+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
216240
return Flux.deferContextual(contextView -> {
217241
var request = createRequest(prompt, true);
218242

@@ -258,7 +282,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
258282
// @formatter:on
259283

260284
if (chatCompletion2.usage() != null) {
261-
return new ChatResponse(generations, from(chatCompletion2));
285+
MistralAiUsage usage = MistralAiUsage.from(chatCompletion2.usage());
286+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse);
287+
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
262288
}
263289
else {
264290
return new ChatResponse(generations);
@@ -276,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
276302
var toolCallConversation = handleToolCalls(prompt, response);
277303
// Recursively call the stream method with the tool call message
278304
// conversation that contains the call responses.
279-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
305+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
280306
}
281307
else {
282308
return Flux.just(response);
@@ -313,7 +339,8 @@ private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
313339
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
314340
.toList();
315341

316-
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
342+
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices,
343+
chunk.usage());
317344
}
318345

319346
/**

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
209209
return !isInsideTool.get();
210210
})
211211
.concatMapIterable(window -> {
212-
Mono<ChatCompletionChunk> mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null),
212+
Mono<ChatCompletionChunk> mono1 = window.reduce(
213+
new ChatCompletionChunk(null, null, null, null, null, null),
213214
(previous, current) -> this.chunkMerger.merge(previous, current));
214215
return List.of(mono1);
215216
})
@@ -934,6 +935,7 @@ public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("lo
934935
* @param model The model used for the chat completion.
935936
* @param choices A list of chat completion choices. Can be more than one if n is
936937
* greater than 1.
938+
* @param usage usage metrics for the chat completion.
937939
*/
938940
@JsonInclude(Include.NON_NULL)
939941
public record ChatCompletionChunk(
@@ -942,7 +944,8 @@ public record ChatCompletionChunk(
942944
@JsonProperty("object") String object,
943945
@JsonProperty("created") Long created,
944946
@JsonProperty("model") String model,
945-
@JsonProperty("choices") List<ChunkChoice> choices) {
947+
@JsonProperty("choices") List<ChunkChoice> choices,
948+
@JsonProperty("usage") Usage usage) {
946949
// @formatter:on
947950

948951
/**

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
6363

6464
ChunkChoice choice = merge(previousChoice0, currentChoice0);
6565

66-
return new ChatCompletionChunk(id, object, created, model, List.of(choice));
66+
MistralAiApi.Usage usage = (current.usage() != null ? current.usage() : previous.usage());
67+
68+
return new ChatCompletionChunk(id, object, created, model, List.of(choice), usage);
6769
}
6870

6971
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ void functionCallTest() {
205205
logger.info("Response: {}", response);
206206

207207
assertThat(response.getResult().getOutput().getText()).containsAnyOf("30.0", "30");
208+
assertThat(response.getMetadata()).isNotNull();
209+
assertThat(response.getMetadata().getUsage()).isNotNull();
210+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800);
208211
}
209212

210213
@Test
@@ -238,6 +241,32 @@ void streamFunctionCallTest() {
238241
assertThat(content).containsAnyOf("10.0", "10");
239242
}
240243

244+
@Test
245+
void streamFunctionCallUsageTest() {
246+
247+
UserMessage userMessage = new UserMessage(
248+
"What's the weather like in San Francisco, Tokyo, and Paris? Response in Celsius");
249+
250+
List<Message> messages = new ArrayList<>(List.of(userMessage));
251+
252+
var promptOptions = MistralAiChatOptions.builder()
253+
.withModel(MistralAiApi.ChatModel.SMALL.getValue())
254+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
255+
.function("getCurrentWeather", new MockWeatherService())
256+
.description("Get the weather in location")
257+
.inputType(MockWeatherService.Request.class)
258+
.build()))
259+
.build();
260+
261+
Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(messages, promptOptions));
262+
ChatResponse chatResponse = response.last().block();
263+
264+
logger.info("Response: {}", chatResponse);
265+
assertThat(chatResponse.getMetadata()).isNotNull();
266+
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
267+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800);
268+
}
269+
241270
record ActorsFilmsRecord(String actor, List<String> movies) {
242271

243272
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void mistralAiChatStreamTransientError() {
124124
var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
125125
ChatCompletionFinishReason.STOP, null);
126126
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L,
127-
"model", List.of(choice));
127+
"model", List.of(choice), null);
128128

129129
given(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
130130
.willThrow(new TransientAiException("Transient Error 1"))

0 commit comments

Comments (0)