From be14940c0b58d0e2deff0c7763979b63a67fe225 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 19 Nov 2024 15:36:52 +0100 Subject: [PATCH 1/3] Fix Bedrock Converse streaming and token handling - Modify stream method to support recursive tool call handling - Update token tracking and metadata merging for streamed responses - Improve token usage calculation for tool use events - Update test cases to handle new response processing Resolves #1743 --- .../converse/BedrockProxyChatModel.java | 11 +-- .../converse/api/ConverseApiUtils.java | 69 +++++++++++++++---- .../converse/BedrockConverseChatClientIT.java | 16 ++++- .../BedrockConverseChatModelMain2.java | 2 + 4 files changed, 79 insertions(+), 19 deletions(-) diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index c437786dd5e..6e1300d6e29 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -473,14 +473,16 @@ private ChatResponse toChatResponse(ConverseResponse response) { */ @Override public Flux stream(Prompt prompt) { + return this.internalStream(prompt, null); + } + + private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse) { Assert.notNull(prompt, "'prompt' must not be null"); return Flux.deferContextual(contextView -> { ConverseRequest converseRequest = this.createRequest(prompt); - // System.out.println(">>>>> CONVERSE REQUEST: " + converseRequest); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.BEDROCK_CONVERSE.value()) @@ -504,13 +506,14 @@ public Flux stream(Prompt prompt) { Flux response = converseStream(converseStreamRequest); // @formatter:off - Flux chatResponses = ConverseApiUtils.toChatResponse(response); + Flux chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse); Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null && this.isToolCall(chatResponse, Set.of("tool_use"))) { + var toolCallConversation = this.handleToolCalls(prompt, chatResponse); - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); } return Mono.just(chatResponse); }) diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index 29038d972f4..fa8da7dbe4b 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -90,7 +90,8 @@ public static boolean isToolUseFinish(ConverseStreamOutput event) { return true; } - public static Flux toChatResponse(Flux responses) { + public static Flux toChatResponse(Flux responses, + ChatResponse perviousChatResponse) { AtomicBoolean isInsideTool = new AtomicBoolean(false); @@ -120,12 +121,22 @@ public static Flux toChatResponse(Flux respo List toolCalls = new ArrayList<>(); + Long promptTokens = 0L; + Long generationTokens = 0L; + Long totalTokens = 0L; + for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) { var functionCallId = toolUseEntry.id(); var functionName = toolUseEntry.name(); var functionArguments = toolUseEntry.input(); toolCalls.add( new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); + + if (toolUseEntry.usage() != null) { + promptTokens += toolUseEntry.usage().getPromptTokens(); + generationTokens += toolUseEntry.usage().getGenerationTokens(); + totalTokens += toolUseEntry.usage().getTotalTokens(); + } } AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); @@ -133,7 +144,7 @@ public static Flux toChatResponse(Flux respo ChatGenerationMetadata.from("tool_use", null)); var chatResponseMetaData = ChatResponseMetadata.builder() - .withUsage(toolUseAggregationEvent.usage) + .withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens)) .build(); return new Aggregation( @@ -181,7 +192,7 @@ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) { return new Aggregation(); } else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) { - // return new Aggregation(); + var newMeta = MetadataAggregation.builder() .copy(lastAggregation.metadataAggregation()) .withTokenUsage(metadataEvent.usage()) @@ -189,14 +200,14 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) { .withTrace(metadataEvent.trace()) .build(); - DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(), - metadataEvent.usage().outputTokens().longValue(), - metadataEvent.usage().totalTokens().longValue()); - // TODO Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields(); ConverseStreamMetrics metrics = metadataEvent.metrics(); + DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(), + metadataEvent.usage().outputTokens().longValue(), + metadataEvent.usage().totalTokens().longValue()); + var chatResponseMetaData = ChatResponseMetadata.builder().withUsage(usage).build(); return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData)); @@ -206,8 +217,42 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) { } }) // .skip(1) - .map(aggregation -> aggregation.chatResponse()) - .filter(chatResponse -> chatResponse != ConverseApiUtils.EMPTY_CHAT_RESPONSE); + .filter(aggregation -> aggregation.chatResponse() != ConverseApiUtils.EMPTY_CHAT_RESPONSE) + .map(aggregation -> { + + var chatResponse = aggregation.chatResponse(); + + // Merge the previous chat response metadata with the current one. + if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null + && perviousChatResponse.getMetadata().getUsage() != null) { + + var metadataBuilder = ChatResponseMetadata.builder(); + + Long promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens(); + Long generationTokens = perviousChatResponse.getMetadata().getUsage().getGenerationTokens(); + Long totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens(); + + if (chatResponse.getMetadata() != null) { + metadataBuilder.withId(chatResponse.getMetadata().getId()); + metadataBuilder.withModel(chatResponse.getMetadata().getModel()); + metadataBuilder.withRateLimit(chatResponse.getMetadata().getRateLimit()); + metadataBuilder.withPromptMetadata(chatResponse.getMetadata().getPromptMetadata()); + + if (chatResponse.getMetadata().getUsage() != null) { + promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens(); + generationTokens = generationTokens + + chatResponse.getMetadata().getUsage().getGenerationTokens(); + totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens(); + } + } + + metadataBuilder.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens)); + + return new ChatResponse(chatResponse.getResults(), metadataBuilder.build()); + } + + return aggregation.chatResponse(); + }); } public static ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent, @@ -245,7 +290,7 @@ else if (event.sdkEventType() == EventType.METADATA) { DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(), metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue()); toolUseEventAggregator.withUsage(usage); - // TODO + if (!toolUseEventAggregator.isEmpty()) { toolUseEventAggregator.squashIntoContentBlock(); return toolUseEventAggregator; @@ -400,7 +445,7 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) { } void squashIntoContentBlock() { - this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson)); + this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson, this.usage)); this.index = null; this.id = null; this.name = null; @@ -424,7 +469,7 @@ public void accept(Visitor visitor) { throw new UnsupportedOperationException(); } - public record ToolUseEntry(Integer index, String id, String name, String input) { + public record ToolUseEntry(Integer index, String id, String name, String input, DefaultUsage usage) { } } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 58e658164f2..8a5839ed238 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -274,7 +274,7 @@ void defaultFunctionCallTest() { void streamFunctionCallTest() { // @formatter:off - Flux response = ChatClient.create(this.chatModel).prompt() + Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") .functions(FunctionCallback.builder() .description("Get the weather in location") @@ -282,10 +282,20 @@ void streamFunctionCallTest() { .inputType(MockWeatherService.Request.class) .build()) .stream() - .content(); + .chatResponse(); // @formatter:on - String content = response.collectList().block().stream().collect(Collectors.joining()); + List chatResponses = response.collectList().block(); + + chatResponses.forEach(cr -> logger.info("Response: {}", cr)); + + List chatResponses2 = chatResponses.stream() + .filter(cr -> cr.getResult() != null) + .collect(Collectors.toList()); + + String content = chatResponses2.stream() + .map(cr -> cr.getResult().getOutput().getContent()) + .collect(Collectors.joining()); logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java index d0cd9320d74..54cd9833202 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java @@ -69,6 +69,8 @@ public static void main(String[] args) { Flux responses = chatModel.converseStream(streamRequest); List responseList = responses.collectList().block(); System.out.println(responseList); + System.out.println("Response count: " + responseList.size()); + responseList.forEach(System.out::println); } } From dfb4b479ea8f0882f46e268d87d5248ef6b3f0e8 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 19 Nov 2024 16:02:30 +0100 Subject: [PATCH 2/3] Enhance Bedrock Converse token handling and tool call processing - Modify call method to support recursive tool call handling - Add support for cumulative token tracking across tool call iterations - Introduce internal call method to track and aggregate token usage - Merge previous chat response tokens with current response tokens --- .../converse/BedrockProxyChatModel.java | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 6e1300d6e29..e7fdffb1c11 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -169,6 +169,10 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, */ @Override public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) { ConverseRequest converseRequest = this.createRequest(prompt); @@ -185,7 +189,7 @@ public ChatResponse call(Prompt prompt) { ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest); - var response = this.toChatResponse(converseResponse); + var response = this.toChatResponse(converseResponse, perviousChatResponse); observationContext.setResponse(response); @@ -195,7 +199,7 @@ public ChatResponse call(Prompt prompt) { if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null && this.isToolCall(chatResponse, Set.of("tool_use"))) { var toolCallConversation = this.handleToolCalls(prompt, chatResponse); - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); } return chatResponse; @@ -402,7 +406,7 @@ else if (mediaData instanceof URL url) { * @param response The Bedrock Converse response. * @return The ChatResponse entity. */ - private ChatResponse toChatResponse(ConverseResponse response) { + private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) { Assert.notNull(response, "'response' must not be null."); @@ -448,8 +452,19 @@ private ChatResponse toChatResponse(ConverseResponse response) { allGenerations.add(toolCallGeneration); } - DefaultUsage usage = new DefaultUsage(response.usage().inputTokens().longValue(), - response.usage().outputTokens().longValue(), response.usage().totalTokens().longValue()); + Long promptTokens = response.usage().inputTokens().longValue(); + Long generationTokens = response.usage().outputTokens().longValue(); + Long totalTokens = response.usage().totalTokens().longValue(); + + if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null + && perviousChatResponse.getMetadata().getUsage() != null) { + + promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens(); + generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens(); + totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens(); + } + + DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens); Document modelResponseFields = response.additionalModelResponseFields(); From 625a55736c4fb62b409eeca6773e33ad4a6e07c1 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 20 Nov 2024 09:12:13 +0100 Subject: [PATCH 3/3] minor tests adjustments --- .../converse/BedrockConverseChatClientIT.java | 57 +++++++++++++++++-- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 8a5839ed238..de0ab4f1f43 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -49,6 +49,7 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.matches; @SpringBootTest(classes = BedrockConverseTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -227,6 +228,41 @@ void functionCallTest() { assertThat(response).contains("30", "10", "15"); } + @Test + void functionCallWithUsageMetadataTest() { + + // @formatter:off + ChatResponse response = ChatClient.create(this.chatModel) + .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") + .functions(FunctionCallback.builder() + .description("Get the weather in location") + .function("getCurrentWeather", new MockWeatherService()) + .inputType(MockWeatherService.Request.class) + .build()) + .call() + .chatResponse(); + // @formatter:on + + var metadata = response.getMetadata(); + + assertThat(metadata.getUsage()).isNotNull(); + + logger.info(metadata.getUsage().toString()); + + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(500); + assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500); + + assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500); + + assertThat(metadata.getUsage().getTotalTokens()) + .isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens()); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + @Test void functionCallWithAdvisorTest() { @@ -287,13 +323,24 @@ void streamFunctionCallTest() { List chatResponses = response.collectList().block(); - chatResponses.forEach(cr -> logger.info("Response: {}", cr)); + // chatResponses.forEach(cr -> logger.info("Response: {}", cr)); + var lastChatResponse = chatResponses.get(chatResponses.size() - 1); + var metadata = lastChatResponse.getMetadata(); + assertThat(metadata.getUsage()).isNotNull(); - List chatResponses2 = chatResponses.stream() - .filter(cr -> cr.getResult() != null) - .collect(Collectors.toList()); + logger.info(metadata.getUsage().toString()); + + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500); + assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500); - String content = chatResponses2.stream() + assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500); + + assertThat(metadata.getUsage().getTotalTokens()) + .isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens()); + + String content = chatResponses.stream() + .filter(cr -> cr.getResult() != null) .map(cr -> cr.getResult().getOutput().getContent()) .collect(Collectors.joining()); logger.info("Response: {}", content);