From ac32938dfcbb86dbf19039c1d4a0c86a85d2ff19 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Tue, 14 Jan 2025 17:54:02 +0000 Subject: [PATCH] Anthropic UsageAccessor - This PR introduces Usage accessor to retrieve the usage metadata fields from the ChatCompletion Usage response metadata - Having UsageAccessor would help enable the clients getting access to the entire metadata instead of pre-defined set of metadata - Add utility method to UsageUtils to parseLong value from the metadata value object - Add tests --- .../ai/anthropic/AnthropicChatModel.java | 12 +-- .../ai/anthropic/api/AnthropicApi.java | 2 +- .../ai/anthropic/api/StreamHelper.java | 16 ++-- .../ai/anthropic/metadata/AnthropicUsage.java | 66 ------------- .../metadata/AnthropicUsageAccessor.java | 70 ++++++++++++++ .../ai/anthropic/metadata/package-info.java | 22 +++++ .../metadata/AnthropicUsageAccessorTests.java | 94 +++++++++++++++++++ .../ai/chat/metadata/UsageUtils.java | 26 +++++ 8 files changed, 229 insertions(+), 79 deletions(-) delete mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java create mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessor.java create mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/package-info.java create mode 100644 models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessorTests.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index e9b5528eb33..69231b03985 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ 
-40,7 +40,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.anthropic.metadata.AnthropicUsage; +import org.springframework.ai.anthropic.metadata.AnthropicUsageAccessor; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -235,9 +235,9 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons .execute(ctx -> this.anthropicApi.chatCompletionEntity(request)); AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody(); - AnthropicApi.Usage usage = completionResponse.usage(); + Map usage = completionResponse.usage(); - Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage()) + Usage currentChatResponseUsage = usage != null ? new AnthropicUsageAccessor(completionResponse.usage()) : new EmptyUsage(); Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); @@ -281,8 +281,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = response.switchMap(chatCompletionResponse -> { - AnthropicApi.Usage usage = chatCompletionResponse.usage(); - Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage(); + Map usage = chatCompletionResponse.usage(); + Usage currentChatResponseUsage = usage != null ? 
new AnthropicUsageAccessor(chatCompletionResponse.usage()) : new EmptyUsage(); Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); @@ -352,7 +352,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage } private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { - return from(result, AnthropicUsage.from(result.usage())); + return from(result, new AnthropicUsageAccessor(result.usage())); } private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 8503201d21a..aa4afffe56b 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -870,7 +870,7 @@ public record ChatCompletionResponse( @JsonProperty("model") String model, @JsonProperty("stop_reason") String stopReason, @JsonProperty("stop_sequence") String stopSequence, - @JsonProperty("usage") Usage usage) { + @JsonProperty("usage") Map usage) { // @formatter:on } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index ae62eb0748c..f74c002d256 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -17,7 +17,9 @@ package org.springframework.ai.anthropic.api; import java.util.ArrayList; +import java.util.HashMap; import 
java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; @@ -35,7 +37,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent; import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent; -import org.springframework.ai.anthropic.api.AnthropicApi.Usage; +import org.springframework.ai.anthropic.metadata.AnthropicUsageAccessor; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -173,9 +175,11 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { } if (messageDeltaEvent.usage() != null) { - var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); - contentBlockReference.get().withUsage(totalUsage); + Map metadata = (contentBlockReference.get().usage != null) + ? 
contentBlockReference.get().usage : new HashMap<>(); + metadata.put(AnthropicUsageAccessor.OUTPUT_TOKENS, + String.valueOf(messageDeltaEvent.usage().outputTokens())); + contentBlockReference.get().withUsage(metadata); } } else if (event.type().equals(EventType.MESSAGE_STOP)) { @@ -204,7 +208,7 @@ public static class ChatCompletionResponseBuilder { private String stopSequence; - private Usage usage; + private Map usage; public ChatCompletionResponseBuilder() { } @@ -244,7 +248,7 @@ public ChatCompletionResponseBuilder withStopSequence(String stopSequence) { return this; } - public ChatCompletionResponseBuilder withUsage(Usage usage) { + public ChatCompletionResponseBuilder withUsage(Map usage) { this.usage = usage; return this; } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java deleted file mode 100644 index fbafc2297d3..00000000000 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.springframework.ai.anthropic.metadata; - -import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.chat.metadata.Usage; -import org.springframework.util.Assert; - -/** - * {@link Usage} implementation for {@literal AnthropicApi}. - * - * @author Christian Tzolov - * @since 1.0.0 - */ -public class AnthropicUsage implements Usage { - - private final AnthropicApi.Usage usage; - - protected AnthropicUsage(AnthropicApi.Usage usage) { - Assert.notNull(usage, "AnthropicApi Usage must not be null"); - this.usage = usage; - } - - public static AnthropicUsage from(AnthropicApi.Usage usage) { - return new AnthropicUsage(usage); - } - - protected AnthropicApi.Usage getUsage() { - return this.usage; - } - - @Override - public Long getPromptTokens() { - return getUsage().inputTokens().longValue(); - } - - @Override - public Long getGenerationTokens() { - return getUsage().outputTokens().longValue(); - } - - @Override - public Long getTotalTokens() { - return this.getPromptTokens() + this.getGenerationTokens(); - } - - @Override - public String toString() { - return getUsage().toString(); - } - -} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessor.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessor.java new file mode 100644 index 00000000000..fe83c89fb67 --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessor.java @@ -0,0 +1,70 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.metadata; + +import java.util.Map; + +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; +import org.springframework.util.Assert; + +/** + * Anthropic Usage accessor class which provides access to the usage metadata. + * + * @author Ilayaperumal Gopinathan + */ +public record AnthropicUsageAccessor(Map usage) implements Usage { + + public static final String INPUT_TOKENS = "input_tokens"; + + public static final String OUTPUT_TOKENS = "output_tokens"; + + public static final String CACHE_CREATION_INPUT_TOKENS = "cache_creation_input_tokens"; + + public static final String CACHE_READ_INPUT_TOKENS = "cache_read_input_tokens"; + + public AnthropicUsageAccessor { + Assert.notNull(usage, "usage must not be null"); + } + + @Override + public Long getPromptTokens() { + return UsageUtils.parseLong(this.usage.get(INPUT_TOKENS)); + } + + @Override + public Long getGenerationTokens() { + return UsageUtils.parseLong(this.usage.get(OUTPUT_TOKENS)); + } + + public Long getCacheCreationInputTokens() { + return UsageUtils.parseLong(this.usage.get(CACHE_CREATION_INPUT_TOKENS)); + } + + public Long getCacheReadInputTokens() { + return UsageUtils.parseLong(this.usage.get(CACHE_READ_INPUT_TOKENS)); + } + + public Map getUsage() { + return this.usage; + } + + @Override + public String toString() { + return this.usage.toString(); + } +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/package-info.java 
b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/package-info.java new file mode 100644 index 00000000000..31872c43078 --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.anthropic.metadata; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessorTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessorTests.java new file mode 100644 index 00000000000..99898feb45b --- /dev/null +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/metadata/AnthropicUsageAccessorTests.java @@ -0,0 +1,94 @@ +package org.springframework.ai.anthropic.metadata; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.HashMap; +import java.util.Map; + +class AnthropicUsageAccessorTests { + + @Test + @DisplayName("Should throw exception when usage map is null") + void 
constructorShouldThrowExceptionWhenUsageMapIsNull() { + assertThrows(IllegalArgumentException.class, () -> new AnthropicUsageAccessor(null)); + } + + @Test + @DisplayName("Should return correct token counts for all fields") + void shouldReturnCorrectTokenCounts() { + Map usageMap = new HashMap<>(); + usageMap.put("input_tokens", 100L); + usageMap.put("output_tokens", 50L); + usageMap.put("cache_creation_input_tokens", 25L); + usageMap.put("cache_read_input_tokens", 75L); + + AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap); + + assertThat(accessor.getPromptTokens()).isEqualTo(100L); + assertThat(accessor.getGenerationTokens()).isEqualTo(50L); + assertThat(accessor.getCacheCreationInputTokens()).isEqualTo(25L); + assertThat(accessor.getCacheReadInputTokens()).isEqualTo(75L); + } + + @Test + @DisplayName("Should handle missing values in usage map") + void shouldHandleMissingValues() { + Map usageMap = new HashMap<>(); + usageMap.put("input_tokens", 100L); + + AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap); + + assertThat(accessor.getPromptTokens()).isEqualTo(100L); + assertThat(accessor.getGenerationTokens()).isNull(); + assertThat(accessor.getCacheCreationInputTokens()).isNull(); + assertThat(accessor.getCacheReadInputTokens()).isNull(); + } + + @Test + @DisplayName("Should handle empty usage map") + void shouldHandleEmptyUsageMap() { + Map usageMap = new HashMap<>(); + + AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap); + + assertThat(accessor.getPromptTokens()).isNull(); + assertThat(accessor.getGenerationTokens()).isNull(); + assertThat(accessor.getCacheCreationInputTokens()).isNull(); + assertThat(accessor.getCacheReadInputTokens()).isNull(); + } + + @Test + @DisplayName("Should handle maximum token values") + void shouldHandleMaximumTokenValues() { + Map usageMap = new HashMap<>(); + usageMap.put("input_tokens", Long.MAX_VALUE); + usageMap.put("output_tokens", Long.MAX_VALUE); + 
usageMap.put("cache_creation_input_tokens", Long.MAX_VALUE); + usageMap.put("cache_read_input_tokens", Long.MAX_VALUE); + + AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap); + + assertThat(accessor.getPromptTokens()).isEqualTo(Long.MAX_VALUE); + assertThat(accessor.getGenerationTokens()).isEqualTo(Long.MAX_VALUE); + assertThat(accessor.getCacheCreationInputTokens()).isEqualTo(Long.MAX_VALUE); + assertThat(accessor.getCacheReadInputTokens()).isEqualTo(Long.MAX_VALUE); + } + + @Test + @DisplayName("Should return usage metadata") + void shouldReturnUsageMetadata() { + Map usageMap = new HashMap<>(); + usageMap.put("input_tokens", 100L); + usageMap.put("output_tokens", 50L); + usageMap.put("cache_creation_input_tokens", 25L); + usageMap.put("cache_read_input_tokens", 75L); + + AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap); + + assertThat(accessor.getUsage()).isEqualTo(usageMap); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java index a430d4cb5a3..5e1c89991e8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java @@ -79,4 +79,30 @@ else if (usage != null && usage.getTotalTokens() == 0L) { return false; } + /** + * Parse the usage metadata value object into Long. + * @param value the value object. + * @return the Long value. 
+	 */
+	public static Long parseLong(Object value) {
+		if (value == null) {
+			return null;
+		}
+		if (value instanceof Long) {
+			return (Long) value;
+		}
+		else if (value instanceof String) {
+			// Throws NumberFormatException if the text is not a valid long.
+			return Long.parseLong((String) value);
+		}
+		else if (value instanceof Number) {
+			// Covers Integer and all other Number subtypes, which is why no
+			// separate Integer branch follows (it would be unreachable).
+			return ((Number) value).longValue();
+		}
+		else {
+			throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
+		}
+	}
+
 }