From af870b39268fa83ffe27a1dd6836df28a10ca457 Mon Sep 17 00:00:00 2001 From: jitokim Date: Fri, 22 Nov 2024 03:15:07 +0900 Subject: [PATCH] Include metadata when creating ChatResponse in Bedrock Titan ChatModel Signed-off-by: jitokim --- .../bedrock/titan/BedrockTitanChatModel.java | 32 ++++++++--- .../titan/BedrockTitanChatModelTests.java | 53 +++++++++++++++++++ 2 files changed, 78 insertions(+), 7 deletions(-) create mode 100644 models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelTests.java diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index 57c832fd6de..34f0d7554a5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -17,9 +17,11 @@ package org.springframework.ai.bedrock.titan; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import reactor.core.publisher.Flux; +import org.springframework.ai.bedrock.BedrockUsage; import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; @@ -27,6 +29,8 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -42,6 +46,7 @@ * uses the Titan Chat API. * * @author Christian Tzolov + * @author Jihoon Kim * @since 0.8.0 */ public class BedrockTitanChatModel implements ChatModel, StreamingChatModel { @@ -63,13 +68,19 @@ public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOption @Override public ChatResponse call(Prompt prompt) { + AtomicLong generationTokenCount = new AtomicLong(0L); TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); - List generations = response.results() - .stream() - .map(result -> new Generation(new AssistantMessage(result.outputText()))) - .toList(); - - return new ChatResponse(generations); + List generations = response.results().stream().map(result -> { + generationTokenCount.addAndGet(result.tokenCount()); + return new Generation(new AssistantMessage(result.outputText())); + }).toList(); + + ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder() + .withModel(this.chatApi.getModelId()) + .withUsage(new DefaultUsage(Long.parseLong(String.valueOf(response.inputTextTokenCount())), + generationTokenCount.get())) + .build(); + return new ChatResponse(generations, chatResponseMetadata); } @Override @@ -85,8 +96,15 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount( String completionReason = chunk.completionReason().name(); chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, extractUsage(chunk)); } + + ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder() + .withModel(this.chatApi.getModelId()) + .withUsage(BedrockUsage.from(chunk.amazonBedrockInvocationMetrics())) + .build(); + return new ChatResponse( - List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata))); + List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)), + chatResponseMetadata); }); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelTests.java new file mode 100644 index 00000000000..2265f1d56c6 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelTests.java @@ -0,0 +1,53 @@ +package org.springframework.ai.bedrock.titan; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; + +import java.time.Duration; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.Result; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; + +/** + * @author Jihoon Kim + * @since 1.0.0 M4 + */ +@ExtendWith(MockitoExtension.class) +public class BedrockTitanChatModelTests { + + @Mock + TitanChatBedrockApi chatApi = new TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), ModelOptionsUtils.OBJECT_MAPPER, + Duration.ofMinutes(2)); + + @Test + public void call_test() { + given(chatApi.getModelId()).willReturn(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id()); + given(chatApi.chatCompletion(any())).willReturn(new TitanChatResponse(4, List + .of(new Result(3, "this is joke", null), new Result(4, "see you next time", CompletionReason.FINISH)))); + BedrockTitanChatModel chatModel = new BedrockTitanChatModel(chatApi); + ChatResponse response = chatModel.call(new Prompt("Tell me a joke", BedrockTitanChatOptions.builder().build())); + + Usage usage = response.getMetadata().getUsage(); + + assert null != response.getMetadata(); + assert usage.getPromptTokens() == 4; + assert usage.getGenerationTokens() == 7; + } + +}