Skip to content

Commit eb96cfc

Browse files
committed
Include metadata when creating ChatResponse in Bedrock Titan ChatModel
Signed-off-by: jitokim <[email protected]>
1 parent e5410d7 commit eb96cfc

File tree

1 file changed: +25 -7 lines changed

1 file changed: +25 -7 lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -17,16 +17,20 @@
1717
package org.springframework.ai.bedrock.titan;
1818

1919
import java.util.List;
20+
import java.util.concurrent.atomic.AtomicLong;
2021

2122
import reactor.core.publisher.Flux;
2223

24+
import org.springframework.ai.bedrock.BedrockUsage;
2325
import org.springframework.ai.bedrock.MessageToPromptConverter;
2426
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
2527
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
2628
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
2729
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
2830
import org.springframework.ai.chat.messages.AssistantMessage;
2931
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
32+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
33+
import org.springframework.ai.chat.metadata.DefaultUsage;
3034
import org.springframework.ai.chat.metadata.Usage;
3135
import org.springframework.ai.chat.model.ChatModel;
3236
import org.springframework.ai.chat.model.ChatResponse;
@@ -42,6 +46,7 @@
4246
* uses the Titan Chat API.
4347
*
4448
* @author Christian Tzolov
49+
* @author Jihoon Kim
4550
* @since 0.8.0
4651
*/
4752
public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
@@ -63,13 +68,19 @@ public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOption
6368

6469
@Override
6570
public ChatResponse call(Prompt prompt) {
71+
AtomicLong generationTokenCount = new AtomicLong(0L);
6672
TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt));
67-
List<Generation> generations = response.results()
68-
.stream()
69-
.map(result -> new Generation(new AssistantMessage(result.outputText())))
70-
.toList();
71-
72-
return new ChatResponse(generations);
73+
List<Generation> generations = response.results().stream().map(result -> {
74+
generationTokenCount.addAndGet(result.tokenCount());
75+
return new Generation(new AssistantMessage(result.outputText()));
76+
}).toList();
77+
78+
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
79+
.withModel(prompt.getOptions().getModel())
80+
.withUsage(new DefaultUsage(Long.parseLong(String.valueOf(response.inputTextTokenCount())),
81+
generationTokenCount.get()))
82+
.build();
83+
return new ChatResponse(generations, chatResponseMetadata);
7384
}
7485

7586
@Override
@@ -85,8 +96,15 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount(
8596
String completionReason = chunk.completionReason().name();
8697
chatGenerationMetadata = ChatGenerationMetadata.from(completionReason, extractUsage(chunk));
8798
}
99+
100+
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
101+
.withModel(prompt.getOptions().getModel())
102+
.withUsage(BedrockUsage.from(chunk.amazonBedrockInvocationMetrics()))
103+
.build();
104+
88105
return new ChatResponse(
89-
List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)));
106+
List.of(new Generation(new AssistantMessage(chunk.outputText()), chatGenerationMetadata)),
107+
chatResponseMetadata);
90108
});
91109
}
92110

0 commit comments

Comments
 (0)