Skip to content

Commit 81e5e54

Browse files
committed
Fix Anthropic chat model functioncalling token usage
- Accumulate the token usage when functioncalling is used - Fix both call() as well as stream() operations - Add/update tests
1 parent 85580f8 commit 81e5e54

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
import org.springframework.ai.chat.messages.UserMessage;
4848
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4949
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
50+
import org.springframework.ai.chat.metadata.EmptyUsage;
51+
import org.springframework.ai.chat.metadata.Usage;
52+
import org.springframework.ai.chat.metadata.UsageUtils;
5053
import org.springframework.ai.chat.model.AbstractToolCallSupport;
5154
import org.springframework.ai.chat.model.ChatModel;
5255
import org.springframework.ai.chat.model.ChatResponse;
@@ -211,6 +214,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
211214

212215
@Override
213216
public ChatResponse call(Prompt prompt) {
217+
return this.internalCall(prompt, null);
218+
}
219+
220+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
214221
ChatCompletionRequest request = createRequest(prompt, false);
215222

216223
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -227,8 +234,14 @@ public ChatResponse call(Prompt prompt) {
227234
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
228235
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
229236

230-
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
237+
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
238+
AnthropicApi.Usage usage = completionResponse.usage();
231239

240+
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
241+
: new EmptyUsage();
242+
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
243+
244+
ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
232245
observationContext.setResponse(chatResponse);
233246

234247
return chatResponse;
@@ -237,14 +250,18 @@ public ChatResponse call(Prompt prompt) {
237250
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
238251
&& this.isToolCall(response, Set.of("tool_use"))) {
239252
var toolCallConversation = handleToolCalls(prompt, response);
240-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
253+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
241254
}
242255

243256
return response;
244257
}
245258

246259
@Override
247260
public Flux<ChatResponse> stream(Prompt prompt) {
261+
return this.internalStream(prompt, null);
262+
}
263+
264+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
248265
return Flux.deferContextual(contextView -> {
249266
ChatCompletionRequest request = createRequest(prompt, true);
250267

@@ -264,11 +281,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
264281

265282
// @formatter:off
266283
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
267-
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
284+
AnthropicApi.Usage usage = chatCompletionResponse.usage();
285+
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
286+
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
287+
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
268288

269289
if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
270290
var toolCallConversation = handleToolCalls(prompt, chatResponse);
271-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
291+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
272292
}
273293

274294
return Mono.just(chatResponse);
@@ -282,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
282302
});
283303
}
284304

285-
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
305+
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage usage) {
286306

287307
if (chatCompletion == null) {
288308
logger.warn("Null chat completion returned");
@@ -328,12 +348,15 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
328348
allGenerations.add(toolCallGeneration);
329349
}
330350

331-
return new ChatResponse(allGenerations, this.from(chatCompletion));
351+
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
332352
}
333353

334354
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
355+
return from(result, AnthropicUsage.from(result.usage()));
356+
}
357+
358+
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
335359
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
336-
AnthropicUsage usage = AnthropicUsage.from(result.usage());
337360
return ChatResponseMetadata.builder()
338361
.withId(result.id())
339362
.withModel(result.model())

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.chat.messages.AssistantMessage;
3838
import org.springframework.ai.chat.messages.Message;
3939
import org.springframework.ai.chat.messages.UserMessage;
40+
import org.springframework.ai.chat.metadata.Usage;
4041
import org.springframework.ai.chat.model.ChatModel;
4142
import org.springframework.ai.chat.model.ChatResponse;
4243
import org.springframework.ai.chat.model.Generation;
@@ -288,7 +289,12 @@ void functionCallTest() {
288289
logger.info("Response: {}", response);
289290

290291
Generation generation = response.getResult();
292+
assertThat(generation).isNotNull();
293+
assertThat(generation.getOutput()).isNotNull();
291294
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
295+
assertThat(response.getMetadata()).isNotNull();
296+
assertThat(response.getMetadata().getUsage()).isNotNull();
297+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
292298
}
293299

294300
@Test
@@ -324,6 +330,35 @@ void streamFunctionCallTest() {
324330
assertThat(content).contains("30", "10", "15");
325331
}
326332

333+
@Test
334+
void streamFunctionCallUsageTest() {
335+
336+
UserMessage userMessage = new UserMessage(
337+
"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
338+
339+
List<Message> messages = new ArrayList<>(List.of(userMessage));
340+
341+
var promptOptions = AnthropicChatOptions.builder()
342+
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
343+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
344+
.function("getCurrentWeather", new MockWeatherService())
345+
.description(
346+
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
347+
.inputType(MockWeatherService.Request.class)
348+
.build()))
349+
.build();
350+
351+
Flux<ChatResponse> responseFlux = this.chatModel.stream(new Prompt(messages, promptOptions));
352+
353+
ChatResponse chatResponse = responseFlux.last().block();
354+
355+
logger.info("Response: {}", chatResponse);
356+
Usage usage = chatResponse.getMetadata().getUsage();
357+
358+
assertThat(usage).isNotNull();
359+
assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
360+
}
361+
327362
@Test
328363
void validateCallResponseMetadata() {
329364
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName();

0 commit comments

Comments
 (0)