Skip to content

Commit 65936a9

Browse files
committed
Enhance Bedrock Converse token handling and tool call processing
- Modify call method to support recursive tool call handling
- Add support for cumulative token tracking across tool call iterations
- Introduce internal call method to track and aggregate token usage
- Merge previous chat response tokens with current response tokens
1 parent 1560694 commit 65936a9

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,10 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
169169
*/
170170
@Override
171171
public ChatResponse call(Prompt prompt) {
172+
return this.internalCall(prompt, null);
173+
}
174+
175+
private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) {
172176

173177
ConverseRequest converseRequest = this.createRequest(prompt);
174178

@@ -185,7 +189,7 @@ public ChatResponse call(Prompt prompt) {
185189

186190
ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest);
187191

188-
var response = this.toChatResponse(converseResponse);
192+
var response = this.toChatResponse(converseResponse, perviousChatResponse);
189193

190194
observationContext.setResponse(response);
191195

@@ -195,7 +199,7 @@ public ChatResponse call(Prompt prompt) {
195199
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
196200
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
197201
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
198-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
202+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
199203
}
200204

201205
return chatResponse;
@@ -402,7 +406,7 @@ else if (mediaData instanceof URL url) {
402406
* @param response The Bedrock Converse response.
403407
* @return The ChatResponse entity.
404408
*/
405-
private ChatResponse toChatResponse(ConverseResponse response) {
409+
private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) {
406410

407411
Assert.notNull(response, "'response' must not be null.");
408412

@@ -448,8 +452,19 @@ private ChatResponse toChatResponse(ConverseResponse response) {
448452
allGenerations.add(toolCallGeneration);
449453
}
450454

451-
DefaultUsage usage = new DefaultUsage(response.usage().inputTokens().longValue(),
452-
response.usage().outputTokens().longValue(), response.usage().totalTokens().longValue());
455+
Long promptTokens = response.usage().inputTokens().longValue();
456+
Long generationTokens = response.usage().outputTokens().longValue();
457+
Long totalTokens = response.usage().totalTokens().longValue();
458+
459+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
460+
&& perviousChatResponse.getMetadata().getUsage() != null) {
461+
462+
promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens();
463+
generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
464+
totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens();
465+
}
466+
467+
DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens);
453468

454469
Document modelResponseFields = response.additionalModelResponseFields();
455470

0 commit comments

Comments
 (0)