From 81e5e54f80ca186b44478043636dcf5e911cd8d3 Mon Sep 17 00:00:00 2001
From: Ilayaperumal Gopinathan
Date: Thu, 12 Dec 2024 14:52:44 +0000
Subject: [PATCH] Fix Anthropic chat model function-calling token usage

- Accumulate the token usage when function calling is used
- Fix both the call() and stream() operations
- Add/update tests
---
 .../ai/anthropic/AnthropicChatModel.java   | 37 +++++++++++++++----
 .../ai/anthropic/AnthropicChatModelIT.java | 35 ++++++++++++++++++
 2 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
index c68fb6d29bc..c343fd5e977 100644
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
+++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
@@ -47,6 +47,9 @@ import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.EmptyUsage;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -211,6 +214,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
+		return this.internalCall(prompt, null);
+	}
+
+	public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
 		ChatCompletionRequest request = createRequest(prompt, false);
 
 		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -227,8 +234,14 @@ public ChatResponse call(Prompt prompt) {
 			ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
 				.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
 
-			ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
+			AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
+			AnthropicApi.Usage usage = completionResponse.usage();
+			Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
+					: new EmptyUsage();
+			Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
+
+			ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
 			observationContext.setResponse(chatResponse);
 
 			return chatResponse;
@@ -237,7 +250,7 @@ public ChatResponse call(Prompt prompt) {
 		if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
 				&& this.isToolCall(response, Set.of("tool_use"))) {
 			var toolCallConversation = handleToolCalls(prompt, response);
-			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+			return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
 		}
 
 		return response;
@@ -245,6 +258,10 @@ public ChatResponse call(Prompt prompt) {
 
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+		return this.internalStream(prompt, null);
+	}
+
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 		return Flux.deferContextual(contextView -> {
 			ChatCompletionRequest request = createRequest(prompt, true);
 
@@ -264,11 +281,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 
 			// @formatter:off
 			Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
-				ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
+				AnthropicApi.Usage usage = chatCompletionResponse.usage();
+				Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
+				Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
+				ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
 
 				if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
 					var toolCallConversation = handleToolCalls(prompt, chatResponse);
-					return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+					return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
 				}
 
 				return Mono.just(chatResponse);
@@ -282,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		});
 	}
 
-	private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
+	private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage usage) {
 
 		if (chatCompletion == null) {
 			logger.warn("Null chat completion returned");
@@ -328,12 +348,15 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
 			allGenerations.add(toolCallGeneration);
 		}
 
-		return new ChatResponse(allGenerations, this.from(chatCompletion));
+		return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
 	}
 
 	private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
+		return from(result, AnthropicUsage.from(result.usage()));
+	}
+
+	private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
 		Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
-		AnthropicUsage usage = AnthropicUsage.from(result.usage());
 		return ChatResponseMetadata.builder()
 			.withId(result.id())
 			.withModel(result.model())
diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
index aa62746bc7d..824e1c42c1a 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java
@@ -37,6 +37,7 @@ import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
@@ -288,7 +289,12 @@ void functionCallTest() {
 		logger.info("Response: {}", response);
 
 		Generation generation = response.getResult();
+		assertThat(generation).isNotNull();
+		assertThat(generation.getOutput()).isNotNull();
 		assertThat(generation.getOutput().getText()).contains("30", "10", "15");
+		assertThat(response.getMetadata()).isNotNull();
+		assertThat(response.getMetadata().getUsage()).isNotNull();
+		assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
 	}
 
 	@Test
@@ -324,6 +330,35 @@ void streamFunctionCallTest() {
 		assertThat(content).contains("30", "10", "15");
 	}
 
+	@Test
+	void streamFunctionCallUsageTest() {
+
+		UserMessage userMessage = new UserMessage(
+				"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		var promptOptions = AnthropicChatOptions.builder()
+			.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
+			.withFunctionCallbacks(List.of(FunctionCallback.builder()
+				.function("getCurrentWeather", new MockWeatherService())
+				.description(
+						"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
+				.inputType(MockWeatherService.Request.class)
+				.build()))
+			.build();
+
+		Flux<ChatResponse> responseFlux = this.chatModel.stream(new Prompt(messages, promptOptions));
+
+		ChatResponse chatResponse = responseFlux.last().block();
+
+		logger.info("Response: {}", chatResponse);
+		Usage usage = chatResponse.getMetadata().getUsage();
+
+		assertThat(usage).isNotNull();
+		assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
+	}
+
 	@Test
 	void validateCallResponseMetadata() {
 		String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName();
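
Reviewer note (not part of the patch): the fix hinges on UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse), whose implementation lies outside this diff. The standalone sketch below is only an assumption of the accumulation it performs: when a previous ChatResponse exists (an earlier tool-call round trip), its prompt and generation token counts are added to the current ones, so the final response reports usage for the whole conversation turn. The TokenUsage record and the cumulative(...) method are hypothetical names for illustration, not Spring AI API.

// Standalone sketch of cumulative token-usage accumulation across tool-call round trips.
// TokenUsage and cumulative(...) are illustrative stand-ins, not the real UsageUtils API.
public final class CumulativeUsageSketch {

	// Minimal usage value object: prompt tokens in, generation tokens out.
	record TokenUsage(long promptTokens, long generationTokens) {
		long totalTokens() {
			return promptTokens + generationTokens;
		}
	}

	// First round trip (previous == null): return the current usage unchanged.
	// Later round trips: add the previously accumulated counts to the current ones.
	static TokenUsage cumulative(TokenUsage current, TokenUsage previous) {
		if (previous == null) {
			return current;
		}
		return new TokenUsage(current.promptTokens() + previous.promptTokens(),
				current.generationTokens() + previous.generationTokens());
	}

	public static void main(String[] args) {
		// Illustrative numbers only: one tool-use round trip followed by the final answer.
		TokenUsage toolCallRoundTrip = new TokenUsage(900, 150);
		TokenUsage finalAnswer = new TokenUsage(1100, 200);
		TokenUsage accumulated = cumulative(finalAnswer, toolCallRoundTrip);
		System.out.println(accumulated.totalTokens()); // 2350, within the 1800..4000 band the IT asserts
	}
}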