Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -211,6 +214,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul

@Override
public ChatResponse call(Prompt prompt) {
	// First call of a conversation: there is no previous response yet, so no
	// usage has been accumulated — pass null as the previous ChatResponse.
	ChatResponse response = internalCall(prompt, null);
	return response;
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand All @@ -227,8 +234,14 @@ public ChatResponse call(Prompt prompt) {
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
AnthropicApi.Usage usage = completionResponse.usage();

Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
: new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
observationContext.setResponse(chatResponse);

return chatResponse;
Expand All @@ -237,14 +250,18 @@ public ChatResponse call(Prompt prompt) {
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
&& this.isToolCall(response, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, response);
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
	// Entry point for streaming: no prior response exists at this point, so the
	// cumulative-usage chain starts from null.
	return internalStream(prompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand All @@ -264,11 +281,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {

// @formatter:off
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
AnthropicApi.Usage usage = chatCompletionResponse.usage();
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
}

return Mono.just(chatResponse);
Expand All @@ -282,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage usage) {

if (chatCompletion == null) {
logger.warn("Null chat completion returned");
Expand Down Expand Up @@ -328,12 +348,15 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
allGenerations.add(toolCallGeneration);
}

return new ChatResponse(allGenerations, this.from(chatCompletion));
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
}

/**
 * Builds {@link ChatResponseMetadata} from the raw Anthropic completion, deriving
 * the token usage from the response itself.
 * @param result the Anthropic chat completion response; must not be null
 * @return metadata for the chat response, including token usage
 */
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
	// Guard against a missing usage payload, mirroring the null handling in
	// internalCall/internalStream, rather than passing null into AnthropicUsage.from.
	Usage usage = result.usage() != null ? AnthropicUsage.from(result.usage()) : new EmptyUsage();
	return from(result, usage);
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand Down Expand Up @@ -288,7 +289,12 @@ void functionCallTest() {
logger.info("Response: {}", response);

Generation generation = response.getResult();
assertThat(generation).isNotNull();
assertThat(generation.getOutput()).isNotNull();
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
assertThat(response.getMetadata()).isNotNull();
assertThat(response.getMetadata().getUsage()).isNotNull();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
}

@Test
Expand Down Expand Up @@ -324,6 +330,35 @@ void streamFunctionCallTest() {
assertThat(content).contains("30", "10", "15");
}

// Verifies that token usage is accumulated across the multi-turn tool-call
// exchange and surfaced on the final streamed response.
@Test
void streamFunctionCallUsageTest() {

	UserMessage weatherQuestion = new UserMessage(
			"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");

	List<Message> conversation = new ArrayList<>(List.of(weatherQuestion));

	// Register the mock weather function so the model performs tool calls.
	FunctionCallback weatherCallback = FunctionCallback.builder()
		.function("getCurrentWeather", new MockWeatherService())
		.description(
				"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
		.inputType(MockWeatherService.Request.class)
		.build();

	var options = AnthropicChatOptions.builder()
		.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
		.withFunctionCallbacks(List.of(weatherCallback))
		.build();

	ChatResponse lastResponse = this.chatModel.stream(new Prompt(conversation, options)).last().block();

	logger.info("Response: {}", lastResponse);

	Usage usage = lastResponse.getMetadata().getUsage();

	assertThat(usage).isNotNull();
	// Cumulative usage over the whole tool-call conversation should land in this band.
	assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
}

@Test
void validateCallResponseMetadata() {
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName();
Expand Down
Loading