Skip to content

Commit 1560694

Browse files
committed
Fix Bedrock Converse streaming and token handling
- Modify stream method to support recursive tool call handling
- Update token tracking and metadata merging for streamed responses
- Improve token usage calculation for tool use events
- Update test cases to handle new response processing

Resolves #1743
1 parent 6261ce0 commit 1560694

File tree

4 files changed

+79
-19
lines changed

4 files changed

+79
-19
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -473,14 +473,16 @@ private ChatResponse toChatResponse(ConverseResponse response) {
473473
*/
474474
@Override
475475
public Flux<ChatResponse> stream(Prompt prompt) {
476+
return this.internalStream(prompt, null);
477+
}
478+
479+
private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousChatResponse) {
476480
Assert.notNull(prompt, "'prompt' must not be null");
477481

478482
return Flux.deferContextual(contextView -> {
479483

480484
ConverseRequest converseRequest = this.createRequest(prompt);
481485

482-
// System.out.println(">>>>> CONVERSE REQUEST: " + converseRequest);
483-
484486
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
485487
.prompt(prompt)
486488
.provider(AiProvider.BEDROCK_CONVERSE.value())
@@ -504,13 +506,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
504506
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);
505507

506508
// @formatter:off
507-
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response);
509+
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);
508510

509511
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
510512
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
511513
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
514+
512515
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
513-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
516+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
514517
}
515518
return Mono.just(chatResponse);
516519
})

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java

Lines changed: 57 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,8 @@ public static boolean isToolUseFinish(ConverseStreamOutput event) {
9090
return true;
9191
}
9292

93-
public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses) {
93+
public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses,
94+
ChatResponse perviousChatResponse) {
9495

9596
AtomicBoolean isInsideTool = new AtomicBoolean(false);
9697

@@ -120,20 +121,30 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
120121

121122
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
122123

124+
Long promptTokens = 0L;
125+
Long generationTokens = 0L;
126+
Long totalTokens = 0L;
127+
123128
for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) {
124129
var functionCallId = toolUseEntry.id();
125130
var functionName = toolUseEntry.name();
126131
var functionArguments = toolUseEntry.input();
127132
toolCalls.add(
128133
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
134+
135+
if (toolUseEntry.usage() != null) {
136+
promptTokens += toolUseEntry.usage().getPromptTokens();
137+
generationTokens += toolUseEntry.usage().getGenerationTokens();
138+
totalTokens += toolUseEntry.usage().getTotalTokens();
139+
}
129140
}
130141

131142
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
132143
Generation toolCallGeneration = new Generation(assistantMessage,
133144
ChatGenerationMetadata.from("tool_use", null));
134145

135146
var chatResponseMetaData = ChatResponseMetadata.builder()
136-
.withUsage(toolUseAggregationEvent.usage)
147+
.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens))
137148
.build();
138149

139150
return new Aggregation(
@@ -181,22 +192,22 @@ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) {
181192
return new Aggregation();
182193
}
183194
else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
184-
// return new Aggregation();
195+
185196
var newMeta = MetadataAggregation.builder()
186197
.copy(lastAggregation.metadataAggregation())
187198
.withTokenUsage(metadataEvent.usage())
188199
.withMetrics(metadataEvent.metrics())
189200
.withTrace(metadataEvent.trace())
190201
.build();
191202

192-
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
193-
metadataEvent.usage().outputTokens().longValue(),
194-
metadataEvent.usage().totalTokens().longValue());
195-
196203
// TODO
197204
Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields();
198205
ConverseStreamMetrics metrics = metadataEvent.metrics();
199206

207+
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
208+
metadataEvent.usage().outputTokens().longValue(),
209+
metadataEvent.usage().totalTokens().longValue());
210+
200211
var chatResponseMetaData = ChatResponseMetadata.builder().withUsage(usage).build();
201212

202213
return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData));
@@ -206,8 +217,42 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
206217
}
207218
})
208219
// .skip(1)
209-
.map(aggregation -> aggregation.chatResponse())
210-
.filter(chatResponse -> chatResponse != ConverseApiUtils.EMPTY_CHAT_RESPONSE);
220+
.filter(aggregation -> aggregation.chatResponse() != ConverseApiUtils.EMPTY_CHAT_RESPONSE)
221+
.map(aggregation -> {
222+
223+
var chatResponse = aggregation.chatResponse();
224+
225+
// Merge the previous chat response metadata with the current one.
226+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
227+
&& perviousChatResponse.getMetadata().getUsage() != null) {
228+
229+
var metadataBuilder = ChatResponseMetadata.builder();
230+
231+
Long promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens();
232+
Long generationTokens = perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
233+
Long totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens();
234+
235+
if (chatResponse.getMetadata() != null) {
236+
metadataBuilder.withId(chatResponse.getMetadata().getId());
237+
metadataBuilder.withModel(chatResponse.getMetadata().getModel());
238+
metadataBuilder.withRateLimit(chatResponse.getMetadata().getRateLimit());
239+
metadataBuilder.withPromptMetadata(chatResponse.getMetadata().getPromptMetadata());
240+
241+
if (chatResponse.getMetadata().getUsage() != null) {
242+
promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens();
243+
generationTokens = generationTokens
244+
+ chatResponse.getMetadata().getUsage().getGenerationTokens();
245+
totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens();
246+
}
247+
}
248+
249+
metadataBuilder.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens));
250+
251+
return new ChatResponse(chatResponse.getResults(), metadataBuilder.build());
252+
}
253+
254+
return aggregation.chatResponse();
255+
});
211256
}
212257

213258
public static ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent,
@@ -245,7 +290,7 @@ else if (event.sdkEventType() == EventType.METADATA) {
245290
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
246291
metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue());
247292
toolUseEventAggregator.withUsage(usage);
248-
// TODO
293+
249294
if (!toolUseEventAggregator.isEmpty()) {
250295
toolUseEventAggregator.squashIntoContentBlock();
251296
return toolUseEventAggregator;
@@ -400,7 +445,7 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
400445
}
401446

402447
void squashIntoContentBlock() {
403-
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson));
448+
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson, this.usage));
404449
this.index = null;
405450
this.id = null;
406451
this.name = null;
@@ -424,7 +469,7 @@ public void accept(Visitor visitor) {
424469
throw new UnsupportedOperationException();
425470
}
426471

427-
public record ToolUseEntry(Integer index, String id, String name, String input) {
472+
public record ToolUseEntry(Integer index, String id, String name, String input, DefaultUsage usage) {
428473
}
429474

430475
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -274,18 +274,28 @@ void defaultFunctionCallTest() {
274274
void streamFunctionCallTest() {
275275

276276
// @formatter:off
277-
Flux<String> response = ChatClient.create(this.chatModel).prompt()
277+
Flux<ChatResponse> response = ChatClient.create(this.chatModel).prompt()
278278
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
279279
.functions(FunctionCallback.builder()
280280
.description("Get the weather in location")
281281
.function("getCurrentWeather", new MockWeatherService())
282282
.inputType(MockWeatherService.Request.class)
283283
.build())
284284
.stream()
285-
.content();
285+
.chatResponse();
286286
// @formatter:on
287287

288-
String content = response.collectList().block().stream().collect(Collectors.joining());
288+
List<ChatResponse> chatResponses = response.collectList().block();
289+
290+
chatResponses.forEach(cr -> logger.info("Response: {}", cr));
291+
292+
List<ChatResponse> chatResponses2 = chatResponses.stream()
293+
.filter(cr -> cr.getResult() != null)
294+
.collect(Collectors.toList());
295+
296+
String content = chatResponses2.stream()
297+
.map(cr -> cr.getResult().getOutput().getContent())
298+
.collect(Collectors.joining());
289299
logger.info("Response: {}", content);
290300

291301
assertThat(content).contains("30", "10", "15");

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,8 @@ public static void main(String[] args) {
6969
Flux<ConverseStreamOutput> responses = chatModel.converseStream(streamRequest);
7070
List<ConverseStreamOutput> responseList = responses.collectList().block();
7171
System.out.println(responseList);
72+
System.out.println("Response count: " + responseList.size());
73+
responseList.forEach(System.out::println);
7274
}
7375

7476
}

0 commit comments

Comments
 (0)