Skip to content

Commit b822107

Browse files
didalgolabtzolov
authored andcommitted
Fix Anthropic token usage handling in streaming
1 parent 997e01c commit b822107

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
*
6060
* @author Christian Tzolov
6161
* @author luocongqiu
62+
* @author Mariusz Bernacki
6263
* @since 1.0.0
6364
*/
6465
public class AnthropicChatModel extends
@@ -192,6 +193,11 @@ else if (chunk.type().equals("message_delta")) {
192193
ChatCompletion delta = ModelOptionsUtils.mapToClass(chunk.delta(), ChatCompletion.class);
193194

194195
chatCompletionReference.get().withType(chunk.type());
196+
if (chunk.usage() != null) {
197+
var totalUsage = new Usage(chatCompletionReference.get().usage.inputTokens(),
198+
chunk.usage().outputTokens());
199+
chatCompletionReference.get().withUsage(totalUsage);
200+
}
195201
if (delta.id() != null) {
196202
chatCompletionReference.get().withId(delta.id());
197203
}
@@ -201,9 +207,6 @@ else if (chunk.type().equals("message_delta")) {
201207
if (delta.model() != null) {
202208
chatCompletionReference.get().withModel(delta.model());
203209
}
204-
if (delta.usage() != null) {
205-
chatCompletionReference.get().withUsage(delta.usage());
206-
}
207210
if (delta.content() != null) {
208211
chatCompletionReference.get().withContent(delta.content());
209212
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
/**
4242
* @author Christian Tzolov
43+
* @author Mariusz Bernacki
4344
* @since 1.0.0
4445
*/
4546
public class AnthropicApi {
@@ -521,6 +522,17 @@ public record Usage( // @formatter:off
521522
// @formatter:off
522523
}
523524

525+
/**
526+
* Usage statistics with output only tokens for streamed completions.
527+
*
528+
* @param outputTokens The number of output tokens which were used in a completion.
529+
*/
530+
@JsonInclude(Include.NON_NULL)
531+
public record OutputUsage( // @formatter:off
532+
@JsonProperty("output_tokens") Integer outputTokens) {
533+
// @formatter:off
534+
}
535+
524536
/**
525537
* The role of the author of this message.
526538
*/
@@ -557,7 +569,8 @@ public record StreamResponse( // @formatter:off
557569
@JsonProperty("index") Integer index,
558570
@JsonProperty("message") ChatCompletion message,
559571
@JsonProperty("content_block") MediaContent contentBlock,
560-
@JsonProperty("delta") Map<String, Object> delta) {
572+
@JsonProperty("delta") Map<String, Object> delta,
573+
@JsonProperty("usage") OutputUsage usage) {
561574
// @formatter:on
562575
}
563576

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,24 @@ void roleTest(String modelName) {
9494
logger.info(response.toString());
9595
}
9696

97+
@Test
98+
void streamingWithTokenUsage() {
99+
var promptOptions = AnthropicChatOptions.builder().withTemperature(0f).build();
100+
101+
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
102+
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
103+
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
104+
105+
assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
106+
assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
107+
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
108+
109+
assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
110+
assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
111+
assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
112+
113+
}
114+
97115
@Test
98116
void listOutputConverter() {
99117
DefaultConversionService conversionService = new DefaultConversionService();
@@ -195,7 +213,7 @@ void multiModalityTest() throws IOException {
195213
var response = chatModel.call(new Prompt(List.of(userMessage)));
196214

197215
logger.info(response.getResult().getOutput().getContent());
198-
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
216+
assertThat(response.getResult().getOutput().getContent()).contains("banan", "apple", "basket");
199217
}
200218

201219
@Test

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,10 @@ void defaultFunctionCallTest() {
230230
String response = ChatClient.builder(chatModel)
231231
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
232232
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
233-
.build()
234-
.prompt().call().content();
233+
.build()
234+
.prompt()
235+
.call()
236+
.content();
235237
// @formatter:on
236238

237239
logger.info("Response: {}", response);
@@ -304,7 +306,7 @@ void streamingMultiModality() throws IOException {
304306

305307
// @formatter:off
306308
Flux<String> response = ChatClient.create(chatModel).prompt()
307-
.options(AnthropicChatOptions.builder().withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
309+
.options(AnthropicChatOptions.builder().withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET)
308310
.build())
309311
.user(u -> u.text("Explain what do you see on this picture?")
310312
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))

0 commit comments

Comments
 (0)