Skip to content

Commit ac32938

Browse files
committed
Anthropic UsageAccessor
- This PR introduces a Usage accessor to retrieve the usage metadata fields from the ChatCompletion Usage response metadata - Having a UsageAccessor gives clients access to the entire usage metadata instead of only a pre-defined set of fields - Add a utility method to UsageUtils to parse a Long value from a metadata value object - Add tests
1 parent 329e6c0 commit ac32938

File tree

8 files changed

+229
-79
lines changed

8 files changed

+229
-79
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4141
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4242
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
43-
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
43+
import org.springframework.ai.anthropic.metadata.AnthropicUsageAccessor;
4444
import org.springframework.ai.chat.messages.AssistantMessage;
4545
import org.springframework.ai.chat.messages.MessageType;
4646
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -235,9 +235,9 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
235235
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
236236

237237
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
238-
AnthropicApi.Usage usage = completionResponse.usage();
238+
Map<String, Object> usage = completionResponse.usage();
239239

240-
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
240+
Usage currentChatResponseUsage = usage != null ? new AnthropicUsageAccessor(completionResponse.usage())
241241
: new EmptyUsage();
242242
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
243243

@@ -281,8 +281,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
281281

282282
// @formatter:off
283283
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
284-
AnthropicApi.Usage usage = chatCompletionResponse.usage();
285-
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
284+
Map<String, Object> usage = chatCompletionResponse.usage();
285+
Usage currentChatResponseUsage = usage != null ? new AnthropicUsageAccessor(chatCompletionResponse.usage()) : new EmptyUsage();
286286
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
287287
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
288288

@@ -352,7 +352,7 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
352352
}
353353

354354
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
355-
return from(result, AnthropicUsage.from(result.usage()));
355+
return from(result, new AnthropicUsageAccessor(result.usage()));
356356
}
357357

358358
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ public record ChatCompletionResponse(
870870
@JsonProperty("model") String model,
871871
@JsonProperty("stop_reason") String stopReason,
872872
@JsonProperty("stop_sequence") String stopSequence,
873-
@JsonProperty("usage") Usage usage) {
873+
@JsonProperty("usage") Map<String, Object> usage) {
874874
// @formatter:on
875875
}
876876

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.ai.anthropic.api;
1818

1919
import java.util.ArrayList;
20+
import java.util.HashMap;
2021
import java.util.List;
22+
import java.util.Map;
2123
import java.util.concurrent.atomic.AtomicReference;
2224

2325
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
@@ -35,7 +37,7 @@
3537
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
3638
import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent;
3739
import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent;
38-
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
40+
import org.springframework.ai.anthropic.metadata.AnthropicUsageAccessor;
3941
import org.springframework.util.Assert;
4042
import org.springframework.util.CollectionUtils;
4143
import org.springframework.util.StringUtils;
@@ -173,9 +175,11 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
173175
}
174176

175177
if (messageDeltaEvent.usage() != null) {
176-
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177-
messageDeltaEvent.usage().outputTokens());
178-
contentBlockReference.get().withUsage(totalUsage);
178+
Map<String, Object> metadata = (contentBlockReference.get().usage != null)
179+
? contentBlockReference.get().usage : new HashMap<>();
180+
metadata.put(AnthropicUsageAccessor.OUTPUT_TOKENS,
181+
String.valueOf(messageDeltaEvent.usage().outputTokens()));
182+
contentBlockReference.get().withUsage(metadata);
179183
}
180184
}
181185
else if (event.type().equals(EventType.MESSAGE_STOP)) {
@@ -204,7 +208,7 @@ public static class ChatCompletionResponseBuilder {
204208

205209
private String stopSequence;
206210

207-
private Usage usage;
211+
private Map<String, Object> usage;
208212

209213
public ChatCompletionResponseBuilder() {
210214
}
@@ -244,7 +248,7 @@ public ChatCompletionResponseBuilder withStopSequence(String stopSequence) {
244248
return this;
245249
}
246250

247-
public ChatCompletionResponseBuilder withUsage(Usage usage) {
251+
public ChatCompletionResponseBuilder withUsage(Map<String, Object> usage) {
248252
this.usage = usage;
249253
return this;
250254
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicUsage.java

Lines changed: 0 additions & 66 deletions
This file was deleted.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.anthropic.metadata;
18+
19+
import java.util.Map;
20+
21+
import org.springframework.ai.chat.metadata.Usage;
22+
import org.springframework.ai.chat.metadata.UsageUtils;
23+
import org.springframework.util.Assert;
24+
25+
/**
26+
* Anthropic Usage accessor class which provides access to the usage metadata.
27+
*
28+
* @author Ilayaperumal Gopinathan
29+
*/
30+
public record AnthropicUsageAccessor(Map<String, Object> usage) implements Usage {
31+
32+
public static final String INPUT_TOKENS = "input_tokens";
33+
34+
public static final String OUTPUT_TOKENS = "output_tokens";
35+
36+
public static final String CACHE_CREATION_INPUT_TOKENS = "cache_creation_input_tokens";
37+
38+
public static final String CACHE_READ_INPUT_TOKENS = "cache_read_input_tokens";
39+
40+
public AnthropicUsageAccessor {
41+
Assert.notNull(usage, "usage must not be null");
42+
}
43+
44+
@Override
45+
public Long getPromptTokens() {
46+
return UsageUtils.parseLong(this.usage.get(INPUT_TOKENS));
47+
}
48+
49+
@Override
50+
public Long getGenerationTokens() {
51+
return UsageUtils.parseLong(this.usage.get(OUTPUT_TOKENS));
52+
}
53+
54+
public Long getCacheCreationInputTokens() {
55+
return UsageUtils.parseLong(this.usage.get(CACHE_CREATION_INPUT_TOKENS));
56+
}
57+
58+
public Long getCacheReadInputTokens() {
59+
return UsageUtils.parseLong(this.usage.get(CACHE_READ_INPUT_TOKENS));
60+
}
61+
62+
public Map<String, Object> getUsage() {
63+
return this.usage;
64+
}
65+
66+
@Override
67+
public String toString() {
68+
return this.usage.toString();
69+
}
70+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
@NonNullApi
18+
@NonNullFields
19+
package org.springframework.ai.anthropic.metadata;
20+
21+
import org.springframework.lang.NonNullApi;
22+
import org.springframework.lang.NonNullFields;
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package org.springframework.ai.anthropic.metadata;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.junit.jupiter.api.DisplayName;
5+
import static org.assertj.core.api.Assertions.assertThat;
6+
import static org.junit.jupiter.api.Assertions.assertThrows;
7+
8+
import java.util.HashMap;
9+
import java.util.Map;
10+
11+
class AnthropicUsageAccessorTests {
12+
13+
@Test
14+
@DisplayName("Should throw exception when usage map is null")
15+
void constructorShouldThrowExceptionWhenUsageMapIsNull() {
16+
assertThrows(IllegalArgumentException.class, () -> new AnthropicUsageAccessor(null));
17+
}
18+
19+
@Test
20+
@DisplayName("Should return correct token counts for all fields")
21+
void shouldReturnCorrectTokenCounts() {
22+
Map<String, Object> usageMap = new HashMap<>();
23+
usageMap.put("input_tokens", 100L);
24+
usageMap.put("output_tokens", 50L);
25+
usageMap.put("cache_creation_input_tokens", 25L);
26+
usageMap.put("cache_read_input_tokens", 75L);
27+
28+
AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap);
29+
30+
assertThat(accessor.getPromptTokens()).isEqualTo(100L);
31+
assertThat(accessor.getGenerationTokens()).isEqualTo(50L);
32+
assertThat(accessor.getCacheCreationInputTokens()).isEqualTo(25L);
33+
assertThat(accessor.getCacheReadInputTokens()).isEqualTo(75L);
34+
}
35+
36+
@Test
37+
@DisplayName("Should handle missing values in usage map")
38+
void shouldHandleMissingValues() {
39+
Map<String, Object> usageMap = new HashMap<>();
40+
usageMap.put("input_tokens", 100L);
41+
42+
AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap);
43+
44+
assertThat(accessor.getPromptTokens()).isEqualTo(100L);
45+
assertThat(accessor.getGenerationTokens()).isNull();
46+
assertThat(accessor.getCacheCreationInputTokens()).isNull();
47+
assertThat(accessor.getCacheReadInputTokens()).isNull();
48+
}
49+
50+
@Test
51+
@DisplayName("Should handle empty usage map")
52+
void shouldHandleEmptyUsageMap() {
53+
Map<String, Object> usageMap = new HashMap<>();
54+
55+
AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap);
56+
57+
assertThat(accessor.getPromptTokens()).isNull();
58+
assertThat(accessor.getGenerationTokens()).isNull();
59+
assertThat(accessor.getCacheCreationInputTokens()).isNull();
60+
assertThat(accessor.getCacheReadInputTokens()).isNull();
61+
}
62+
63+
@Test
64+
@DisplayName("Should handle maximum token values")
65+
void shouldHandleMaximumTokenValues() {
66+
Map<String, Object> usageMap = new HashMap<>();
67+
usageMap.put("input_tokens", Long.MAX_VALUE);
68+
usageMap.put("output_tokens", Long.MAX_VALUE);
69+
usageMap.put("cache_creation_input_tokens", Long.MAX_VALUE);
70+
usageMap.put("cache_read_input_tokens", Long.MAX_VALUE);
71+
72+
AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap);
73+
74+
assertThat(accessor.getPromptTokens()).isEqualTo(Long.MAX_VALUE);
75+
assertThat(accessor.getGenerationTokens()).isEqualTo(Long.MAX_VALUE);
76+
assertThat(accessor.getCacheCreationInputTokens()).isEqualTo(Long.MAX_VALUE);
77+
assertThat(accessor.getCacheReadInputTokens()).isEqualTo(Long.MAX_VALUE);
78+
}
79+
80+
@Test
81+
@DisplayName("Should return usage metadata")
82+
void shouldReturnUsageMetadata() {
83+
Map<String, Object> usageMap = new HashMap<>();
84+
usageMap.put("input_tokens", 100L);
85+
usageMap.put("output_tokens", 50L);
86+
usageMap.put("cache_creation_input_tokens", 25L);
87+
usageMap.put("cache_read_input_tokens", 75L);
88+
89+
AnthropicUsageAccessor accessor = new AnthropicUsageAccessor(usageMap);
90+
91+
assertThat(accessor.getUsage()).isEqualTo(usageMap);
92+
}
93+
94+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,30 @@ else if (usage != null && usage.getTotalTokens() == 0L) {
7979
return false;
8080
}
8181

82+
/**
83+
* Parse the usage metadata value object into Long.
84+
* @param value the value object.
85+
* @return the Long value.
86+
*/
87+
public static Long parseLong(Object value) {
88+
if (value == null) {
89+
return null;
90+
}
91+
if (value instanceof Long) {
92+
return (Long) value;
93+
}
94+
else if (value instanceof String) {
95+
return Long.parseLong((String) value);
96+
}
97+
else if (value instanceof Number) {
98+
return ((Number) value).longValue();
99+
}
100+
else if (value instanceof Integer) {
101+
return ((Integer) value).longValue();
102+
}
103+
else {
104+
throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
105+
}
106+
}
107+
82108
}

0 commit comments

Comments
 (0)