From b65a3343a9d49a8adff03061a84a96693dc394ec Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Wed, 11 Dec 2024 12:33:36 +0000 Subject: [PATCH] Fix Mistral AI Chat model function call usage calculation - Fix the chat model's call() to calculate the cumulative usage - Use an explicit internalCall to pass the previous chat response so that accumulation can be done - Fix the chat model's stream() to calculate the cumulative usage - Fix MistralAi API to include usage in ChatCompletionChunk - Use internalStream() to accumulate the usage Add/update tests --- .../ai/mistralai/MistralAiChatModel.java | 37 ++++++++++++++++--- .../ai/mistralai/api/MistralAiApi.java | 7 +++- .../MistralAiStreamFunctionCallingHelper.java | 4 +- .../ai/mistralai/MistralAiChatModelIT.java | 29 +++++++++++++++ .../ai/mistralai/MistralAiRetryTests.java | 2 +- 5 files changed, 70 insertions(+), 9 deletions(-) diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index c7b967a5673..be3461863c3 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -36,6 +36,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -75,6 +77,7 @@ * @author Grogdunn * @author Thomas Vitale * @author luocongqiu + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ public 
class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { @@ -156,8 +159,22 @@ public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) { .build(); } + public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result, Usage usage) { + Assert.notNull(result, "Mistral AI ChatCompletion must not be null"); + return ChatResponseMetadata.builder() + .withId(result.id()) + .withModel(result.model()) + .withUsage(usage) + .withKeyValue("created", result.created()) + .build(); + } + @Override public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false); @@ -193,7 +210,10 @@ public ChatResponse call(Prompt prompt) { return buildGeneration(choice, metadata); }).toList(); - ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody())); + MistralAiUsage usage = MistralAiUsage.from(completionEntity.getBody().usage()); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + from(completionEntity.getBody(), cumulativeUsage)); observationContext.setResponse(chatResponse); @@ -206,7 +226,7 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. 
- return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); } return response; @@ -214,6 +234,10 @@ && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALL @Override public Flux stream(Prompt prompt) { + return this.internalStream(prompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { return Flux.deferContextual(contextView -> { var request = createRequest(prompt, true); @@ -259,7 +283,9 @@ public Flux stream(Prompt prompt) { // @formatter:on if (chatCompletion2.usage() != null) { - return new ChatResponse(generations, from(chatCompletion2)); + MistralAiUsage usage = MistralAiUsage.from(chatCompletion2.usage()); + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(usage, previousChatResponse); + return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); } else { return new ChatResponse(generations); @@ -277,7 +303,7 @@ public Flux stream(Prompt prompt) { var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. 
- return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); } else { return Flux.just(response); @@ -314,7 +340,8 @@ private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) { .map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs())) .toList(); - return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null); + return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, + chunk.usage()); } /** diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index bd7723676f4..4743b0e0387 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -209,7 +209,8 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return !isInsideTool.get(); }) .concatMapIterable(window -> { - Mono mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null), + Mono mono1 = window.reduce( + new ChatCompletionChunk(null, null, null, null, null, null), (previous, current) -> this.chunkMerger.merge(previous, current)); return List.of(mono1); }) @@ -934,6 +935,7 @@ public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("lo * @param model The model used for the chat completion. * @param choices A list of chat completion choices. Can be more than one if n is * greater than 1. + * @param usage usage metrics for the chat completion. 
*/ @JsonInclude(Include.NON_NULL) public record ChatCompletionChunk( @@ -942,7 +944,8 @@ public record ChatCompletionChunk( @JsonProperty("object") String object, @JsonProperty("created") Long created, @JsonProperty("model") String model, - @JsonProperty("choices") List choices) { + @JsonProperty("choices") List choices, + @JsonProperty("usage") Usage usage) { // @formatter:on /** diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java index c00249eead8..608e8d6ffbb 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java @@ -63,7 +63,9 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu ChunkChoice choice = merge(previousChoice0, currentChoice0); - return new ChatCompletionChunk(id, object, created, model, List.of(choice)); + MistralAiApi.Usage usage = (current.usage() != null ? 
current.usage() : previous.usage()); + + return new ChatCompletionChunk(id, object, created, model, List.of(choice), usage); } private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index b8a4ebdafa7..1d3b3ecd56b 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -205,6 +205,9 @@ void functionCallTest() { logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).containsAnyOf("30.0", "30"); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800); } @Test @@ -238,6 +241,32 @@ void streamFunctionCallTest() { assertThat(content).containsAnyOf("10.0", "10"); } + @Test + void streamFunctionCallUsageTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? 
Response in Celsius"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = MistralAiChatOptions.builder() + .withModel(MistralAiApi.ChatModel.SMALL.getValue()) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); + ChatResponse chatResponse = response.last().block(); + + logger.info("Response: {}", chatResponse); + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(1050).isGreaterThan(800); + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 71982a390a6..1c92676ae3d 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -124,7 +124,7 @@ public void mistralAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L, - "model", List.of(choice)); + "model", List.of(choice), null); given(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) .willThrow(new TransientAiException("Transient Error 1"))