Skip to content

Commit cadea8b

Browse files
ilayaperumalg
authored and tzolov committed
Fix Anthropic chat model function calling token usage
- Accumulate the token usage when function calling is used - Fix both call() as well as stream() operations - Add/update tests
1 parent 81dfd3b commit cadea8b

File tree

4 files changed

+68
-10
lines changed

4 files changed

+68
-10
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
import org.springframework.ai.chat.messages.UserMessage;
4848
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4949
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
50+
import org.springframework.ai.chat.metadata.EmptyUsage;
51+
import org.springframework.ai.chat.metadata.Usage;
52+
import org.springframework.ai.chat.metadata.UsageUtils;
5053
import org.springframework.ai.chat.model.AbstractToolCallSupport;
5154
import org.springframework.ai.chat.model.ChatModel;
5255
import org.springframework.ai.chat.model.ChatResponse;
@@ -210,6 +213,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
210213

211214
@Override
212215
public ChatResponse call(Prompt prompt) {
216+
return this.internalCall(prompt, null);
217+
}
218+
219+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
213220
ChatCompletionRequest request = createRequest(prompt, false);
214221

215222
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -226,8 +233,14 @@ public ChatResponse call(Prompt prompt) {
226233
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
227234
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
228235

229-
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
236+
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
237+
AnthropicApi.Usage usage = completionResponse.usage();
230238

239+
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
240+
: new EmptyUsage();
241+
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
242+
243+
ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
231244
observationContext.setResponse(chatResponse);
232245

233246
return chatResponse;
@@ -236,14 +249,18 @@ public ChatResponse call(Prompt prompt) {
236249
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
237250
&& this.isToolCall(response, Set.of("tool_use"))) {
238251
var toolCallConversation = handleToolCalls(prompt, response);
239-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
252+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
240253
}
241254

242255
return response;
243256
}
244257

245258
@Override
246259
public Flux<ChatResponse> stream(Prompt prompt) {
260+
return this.internalStream(prompt, null);
261+
}
262+
263+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
247264
return Flux.deferContextual(contextView -> {
248265
ChatCompletionRequest request = createRequest(prompt, true);
249266

@@ -263,11 +280,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
263280

264281
// @formatter:off
265282
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
266-
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
283+
AnthropicApi.Usage usage = chatCompletionResponse.usage();
284+
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
285+
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
286+
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
267287

268288
if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
269289
var toolCallConversation = handleToolCalls(prompt, chatResponse);
270-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
290+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
271291
}
272292

273293
return Mono.just(chatResponse);
@@ -281,7 +301,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
281301
});
282302
}
283303

284-
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
304+
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage usage) {
285305

286306
if (chatCompletion == null) {
287307
logger.warn("Null chat completion returned");
@@ -327,12 +347,15 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
327347
allGenerations.add(toolCallGeneration);
328348
}
329349

330-
return new ChatResponse(allGenerations, this.from(chatCompletion));
350+
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
331351
}
332352

333353
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
354+
return from(result, AnthropicUsage.from(result.usage()));
355+
}
356+
357+
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
334358
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
335-
AnthropicUsage usage = AnthropicUsage.from(result.usage());
336359
return ChatResponseMetadata.builder()
337360
.withId(result.id())
338361
.withModel(result.model())

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.chat.messages.AssistantMessage;
3838
import org.springframework.ai.chat.messages.Message;
3939
import org.springframework.ai.chat.messages.UserMessage;
40+
import org.springframework.ai.chat.metadata.Usage;
4041
import org.springframework.ai.chat.model.ChatModel;
4142
import org.springframework.ai.chat.model.ChatResponse;
4243
import org.springframework.ai.chat.model.Generation;
@@ -288,7 +289,12 @@ void functionCallTest() {
288289
logger.info("Response: {}", response);
289290

290291
Generation generation = response.getResult();
292+
assertThat(generation).isNotNull();
293+
assertThat(generation.getOutput()).isNotNull();
291294
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
295+
assertThat(response.getMetadata()).isNotNull();
296+
assertThat(response.getMetadata().getUsage()).isNotNull();
297+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
292298
}
293299

294300
@Test
@@ -324,6 +330,35 @@ void streamFunctionCallTest() {
324330
assertThat(content).contains("30", "10", "15");
325331
}
326332

333+
@Test
334+
void streamFunctionCallUsageTest() {
335+
336+
UserMessage userMessage = new UserMessage(
337+
"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
338+
339+
List<Message> messages = new ArrayList<>(List.of(userMessage));
340+
341+
var promptOptions = AnthropicChatOptions.builder()
342+
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
343+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
344+
.function("getCurrentWeather", new MockWeatherService())
345+
.description(
346+
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
347+
.inputType(MockWeatherService.Request.class)
348+
.build()))
349+
.build();
350+
351+
Flux<ChatResponse> responseFlux = this.chatModel.stream(new Prompt(messages, promptOptions));
352+
353+
ChatResponse chatResponse = responseFlux.last().block();
354+
355+
logger.info("Response: {}", chatResponse);
356+
Usage usage = chatResponse.getMetadata().getUsage();
357+
358+
assertThat(usage).isNotNull();
359+
assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
360+
}
361+
327362
@Test
328363
void validateCallResponseMetadata() {
329364
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName();

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void functionCallTest() {
5757

5858
this.contextRunner
5959
.withPropertyValues(
60-
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
60+
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU.getValue())
6161
.run(context -> {
6262

6363
AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class);
@@ -87,7 +87,7 @@ void functionCallWithPortableFunctionCallingOptions() {
8787

8888
this.contextRunner
8989
.withPropertyValues(
90-
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
90+
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU.getValue())
9191
.run(context -> {
9292

9393
AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class);

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class FunctionCallWithPromptFunctionIT {
4949
void functionCallTest() {
5050
this.contextRunner
5151
.withPropertyValues(
52-
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
52+
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU.getValue())
5353
.run(context -> {
5454

5555
AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class);

0 commit comments

Comments
 (0)