From 0555cbceaf9ff6ef3beb9fa2759a62cbf250c757 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 19 Nov 2024 15:36:52 +0100 Subject: [PATCH] fix(bedrock-converse): Update tool use handling and add usage aggregation tests - Replace hardcoded "tool_use" with StopReason enum value - Add tests for token usage aggregation with tool calls - Add handling for null response metadata --- .../converse/BedrockProxyChatModel.java | 10 +- .../converse/BedrockConverseChatClientIT.java | 1 - .../BedrockConverseUsageAggregationTests.java | 167 ++++++++++++++++++ .../BedrockConverseChatModelMain.java | 2 +- .../BedrockConverseChatModelMain2.java | 2 +- .../BedrockConverseChatModelMain3.java | 71 ++++++++ .../FunctionCallingOptionsBuilder.java | 6 +- 7 files changed, 250 insertions(+), 9 deletions(-) create mode 100644 models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java rename models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/{experiements => experiments}/BedrockConverseChatModelMain.java (96%) rename models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/{experiements => experiments}/BedrockConverseChatModelMain2.java (97%) create mode 100644 models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 440a571630f..d2f200c5035 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -55,6 +55,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.ImageSource; import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration; import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; import software.amazon.awssdk.services.bedrockruntime.model.Tool; import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; @@ -189,6 +190,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest); + logger.debug("ConverseResponse: {}", converseResponse); + var response = this.toChatResponse(converseResponse, perviousChatResponse); observationContext.setResponse(response); @@ -197,7 +200,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon }); if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null - && this.isToolCall(chatResponse, Set.of("tool_use"))) { + && this.isToolCall(chatResponse, Set.of(StopReason.TOOL_USE.toString()))) { var toolCallConversation = this.handleToolCalls(prompt, chatResponse); return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); } @@ -471,7 +474,7 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv ConverseMetrics metrics = response.metrics(); var chatResponseMetaData = ChatResponseMetadata.builder() - .withId(response.responseMetadata().requestId()) + .withId(response.responseMetadata() != null ? response.responseMetadata().requestId() : "Unknown") .withUsage(usage) .build(); @@ -525,7 +528,8 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null - && this.isToolCall(chatResponse, Set.of("tool_use"))) { + && this.isToolCall(chatResponse, Set.of(StopReason.TOOL_USE.toString()))) { + var toolCallConversation = this.handleToolCalls(prompt, chatResponse); return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index de0ab4f1f43..f49cc87ebf0 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -49,7 +49,6 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.matches; @SpringBootTest(classes = BedrockConverseTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java new file mode 100644 index 00000000000..a1c301bd019 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java @@ -0,0 +1,167 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse; + +import java.util.List; + +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.core.document.internal.MapDocument; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.StopReason; +import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; +import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; + +/** + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +public class BedrockConverseUsageAggregationTests { + + private @Mock BedrockRuntimeClient bedrockRuntimeClient; + + private @Mock BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; + + private BedrockProxyChatModel chatModel; + + @BeforeEach + public void beforeEach() { + chatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, + FunctionCallingOptions.builder().build(), null, List.of(), ObservationRegistry.NOOP); + } + + @Test + public void call() { + ConverseResponse converseResponse = ConverseResponse.builder() + + .output(ConverseOutput.builder() + .message(Message.builder() + .role(ConversationRole.ASSISTANT) + .content(ContentBlock.fromText("Response Content Block")) + .build()) + .build()) + .usage(TokenUsage.builder().inputTokens(16).outputTokens(14).totalTokens(30).build()) + .build(); + + given(this.bedrockRuntimeClient.converse(isA(ConverseRequest.class))).willReturn(converseResponse); + + var result = this.chatModel.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isSameAs("Response Content Block"); + + assertThat(result.getMetadata().getUsage().getPromptTokens()).isEqualTo(16); + assertThat(result.getMetadata().getUsage().getGenerationTokens()).isEqualTo(14); + assertThat(result.getMetadata().getUsage().getTotalTokens()).isEqualTo(30); + } + + public record Request(String location, String unit) { + } + + @Test + public void callWithToolUse() { + + ConverseResponse converseResponseToolUse = ConverseResponse.builder() + .output(ConverseOutput.builder() + .message(Message.builder() + .role(ConversationRole.ASSISTANT) + .content(ContentBlock.fromText( + "Certainly! I'd be happy to check the current weather in Paris for you, with the temperature in Celsius. To get this information, I'll use the getCurrentWeather function. Let me fetch that for you right away."), + ContentBlock.fromToolUse(ToolUseBlock.builder() + .toolUseId("tooluse_2SZuiUDkRbeGysun8O2Wag") + .name("getCurrentWeather") + .input(MapDocument.mapBuilder() + .putString("location", "Paris, France") + .putString("unit", "C") + .build()) + .build())) + + .build()) + .build()) + .usage(TokenUsage.builder().inputTokens(445).outputTokens(119).totalTokens(564).build()) + .stopReason(StopReason.TOOL_USE) + .metrics(ConverseMetrics.builder().latencyMs(3435L).build()) + .build(); + + ConverseResponse converseResponseFinal = ConverseResponse.builder() + .output(ConverseOutput.builder() + .message(Message.builder() + .role(ConversationRole.ASSISTANT) + .content(ContentBlock.fromText( + """ + Based on the information from the weather tool, the current temperature in Paris, France is 15.0°C (Celsius). + + Please note that weather conditions can change throughout the day, so this temperature represents the current + reading at the time of the request. If you need more detailed information about the weather in Paris, such as + humidity, wind speed, or forecast for the coming days, please let me know, and I'll be happy to provide more + details if that information is available through our weather service. + """)) + .build()) + .build()) + .usage(TokenUsage.builder().inputTokens(540).outputTokens(106).totalTokens(646).build()) + .stopReason(StopReason.END_TURN) + .metrics(ConverseMetrics.builder().latencyMs(3435L).build()) + .build(); + + given(this.bedrockRuntimeClient.converse(isA(ConverseRequest.class))).willReturn(converseResponseToolUse) + .willReturn(converseResponseFinal); + + FunctionCallback functionCallback = FunctionCallback.builder() + .description("Gets the weather in location") + .function("getCurrentWeather", (Request request) -> "15.0°C") + .inputType(Request.class) + .build(); + + var result = this.chatModel.call(new Prompt("What is the weather in Paris?", + PortableFunctionCallingOptions.builder().withFunctionCallbacks(functionCallback).build())); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()) + .isSameAs(converseResponseFinal.output().message().content().get(0).text()); + + assertThat(result.getMetadata().getUsage().getPromptTokens()).isEqualTo(445 + 540); + assertThat(result.getMetadata().getUsage().getGenerationTokens()).isEqualTo(119 + 106); + assertThat(result.getMetadata().getUsage().getTotalTokens()).isEqualTo(564 + 646); + } + + @Test + public void streamWithToolUse() { + // TODO: Implement the test + } + +} diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java similarity index 96% rename from models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain.java rename to models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java index f7978fa1bb9..7404f3d4b26 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.bedrock.converse.experiements; +package org.springframework.ai.bedrock.converse.experiments; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java similarity index 97% rename from models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java rename to models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java index 54cd9833202..47e69dbd3f1 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.bedrock.converse.experiements; +package org.springframework.ai.bedrock.converse.experiments; import java.util.List; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java new file mode 100644 index 00000000000..078c3b2a645 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java @@ -0,0 +1,71 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse.experiments; + +import java.util.List; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; +import org.springframework.ai.bedrock.converse.MockWeatherService; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; + +/** + * Used for reverse engineering the protocol + */ +public final class BedrockConverseChatModelMain3 { + + private BedrockConverseChatModelMain3() { + + } + + public static void main(String[] args) { + + // String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + // String modelId = "ai21.jamba-1-5-large-v1:0"; + String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + + // var prompt = new Prompt("Tell me a joke?", + // ChatOptionsBuilder.builder().withModel(modelId).build()); + var prompt = new Prompt( + // "What's the weather like in San Francisco, Tokyo, and Paris? Return the + // temperature in Celsius.", + "What's the weather like in Paris? Return the temperature in Celsius.", + PortableFunctionCallingOptions.builder() + .withModel(modelId) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .description("Get the weather in location") + .function("getCurrentWeather", new MockWeatherService()) + .inputType(MockWeatherService.Request.class) + .build())) + .build()); + + BedrockProxyChatModel chatModel = BedrockProxyChatModel.builder() + .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .withRegion(Region.US_EAST_1) + .build(); + + var response = chatModel.call(prompt); + + System.out.println(response); + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index ba64c8e81eb..a4ceca490ef 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -51,9 +51,9 @@ public FunctionCallingOptionsBuilder withFunctionCallbacks(List