diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index 95a74f933c1..4ae4643ffdc 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -18,11 +18,13 @@ import java.util.ArrayList; import java.util.Base64; import java.util.List; -import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.Usage; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; @@ -82,11 +84,17 @@ public ChatResponse call(Prompt prompt) { AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); List generations = response.content().stream().map(content -> { - return new Generation(content.text(), Map.of()) - .withGenerationMetadata(ChatGenerationMetadata.from(response.stopReason(), null)); + return new Generation(new AssistantMessage(content.text()), + ChatGenerationMetadata.from(response.stopReason(), null)); }).toList(); - return new ChatResponse(generations); + ChatResponseMetadata metadata = ChatResponseMetadata.builder() + .withId(response.id()) + .withModel(response.model()) + .withUsage(extractUsage(response.usage())) + .build(); + + return new ChatResponse(generations, metadata); } @Override @@ -116,6 +124,21 @@ public Flux stream(Prompt prompt) { }); } + private Usage extractUsage(Anthropic3ChatBedrockApi.AnthropicUsage usage) { + return new Usage() { + + @Override + public Long getPromptTokens() { + return usage.inputTokens().longValue(); + } + + @Override + public Long getGenerationTokens() { + return usage.outputTokens().longValue(); + } + }; + } + /** * Accessible for testing. */