diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
index 57c832fd6de..a818c930145 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java
@@ -17,9 +17,11 @@
 package org.springframework.ai.bedrock.titan;
 
 import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
 
 import reactor.core.publisher.Flux;
 
+import org.springframework.ai.bedrock.BedrockUsage;
 import org.springframework.ai.bedrock.MessageToPromptConverter;
 import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
 import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
@@ -27,6 +29,8 @@
 import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -42,6 +46,7 @@
  * uses the Titan Chat API.
  *
  * @author Christian Tzolov
+ * @author Jihoon Kim
  * @since 0.8.0
  */
 public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
@@ -63,13 +68,19 @@ public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOption
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
+		AtomicLong generationTokenCount = new AtomicLong(0L);
 		TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt));
-		List<Generation> generations = response.results()
-			.stream()
-			.map(result -> new Generation(new AssistantMessage(result.outputText())))
-			.toList();
-
-		return new ChatResponse(generations);
+		List<Generation> generations = response.results().stream().map(result -> {
+			generationTokenCount.addAndGet(result.tokenCount());
+			return new Generation(new AssistantMessage(result.outputText()));
+		}).toList();
+
+		ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
+			.withModel(prompt.getOptions().getModel()) // NOTE(review): assumes prompt options are non-null — confirm
+			.withUsage(new DefaultUsage(Long.valueOf(response.inputTextTokenCount()),
+					generationTokenCount.get()))
+			.build();
+		return new ChatResponse(generations, chatResponseMetadata);
 	}
 
 	@Override
@@ -85,8 +96,15 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount(
 				String completionReason = chunk.completionReason().name();
 				chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, extractUsage(chunk));
 			}
+
+			ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
+				.withModel(prompt.getOptions().getModel()) // NOTE(review): assumes prompt options are non-null — confirm
+				.withUsage(BedrockUsage.from(chunk.amazonBedrockInvocationMetrics()))
+				.build();
+
 			return new ChatResponse(
-					List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)));
+					List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)),
+					chatResponseMetadata);
 		});
 	}
 