diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 380951c7265..79934457abb 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -16,18 +16,6 @@ package org.springframework.ai.bedrock.converse; -import java.io.IOException; -import java.io.InputStream; -import java.net.URISyntaxException; -import java.net.URL; -import java.net.URLConnection; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.Set; - import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; @@ -73,6 +61,8 @@ import software.amazon.awssdk.services.bedrockruntime.model.VideoBlock; import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat; import software.amazon.awssdk.services.bedrockruntime.model.VideoSource; +import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock; +import software.amazon.awssdk.services.bedrockruntime.model.CachePointType; import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat; import org.springframework.ai.bedrock.converse.api.ConverseApiUtils; @@ -96,17 +86,33 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.model.tool.ToolCallingManager; 
-import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.model.tool.*; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StreamUtils; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.core.publisher.Sinks.EmitFailureHandler; +import reactor.core.scheduler.Schedulers; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.*; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.net.URL; +import java.net.URLConnection; +import java.time.Duration; +import java.util.*; /** * A {@link ChatModel} implementation that uses the Amazon Bedrock Converse API to @@ -127,12 +133,15 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html *

* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html * * @author Christian Tzolov * @author Wei Jiang * @author Alexandros Pappas * @author Jihoon Kim * @author Soby Chacko + * @author Brave Lin canhui_lin@fzzixun.com * @since 1.0.0 */ public class BedrockProxyChatModel implements ChatModel { @@ -150,34 +159,34 @@ public class BedrockProxyChatModel implements ChatModel { private ToolCallingChatOptions defaultOptions; /** - * Observation registry used for instrumentation. - */ + * Observation registry used for instrumentation. + */ private final ObservationRegistry observationRegistry; private final ToolCallingManager toolCallingManager; /** - * The tool execution eligibility predicate used to determine if a tool can be - * executed. - */ + * The tool execution eligibility predicate used to determine if a tool can be + * executed. + */ private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; /** - * Conventions to use for generating observations. - */ + * Conventions to use for generating observations. 
+ */ private ChatModelObservationConvention observationConvention; public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, - ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry, toolCallingManager, new DefaultToolExecutionEligibilityPredicate()); } public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, - ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager, - ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); @@ -203,13 +212,14 @@ private static ToolCallingChatOptions from(ChatOptions options) { } /** - * Invoke the model and return the response. - * - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse - * @return The model invocation response. - */ + * Invoke the model and return the response. + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse + * + * @return The model invocation response. + */ @Override public ChatResponse call(Prompt prompt) { Prompt requestPrompt = buildRequestPrompt(prompt); @@ -384,7 +394,19 @@ else if (message.getMessageType() == MessageType.TOOL) { List systemMessages = prompt.getInstructions() .stream() .filter(m -> m.getMessageType() == MessageType.SYSTEM) - .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build()) + .map(sysMessage -> { + /** + * add CachePointBlock support + * url: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + */ + if(sysMessage.getMetadata()!=null&&sysMessage.getMetadata().get(ConverseApiUtils.CACHE_POINT)!=null){ + return SystemContentBlock.fromCachePoint(CachePointBlock.builder() + .type(CachePointType.DEFAULT) + .build()); + }else{ + return SystemContentBlock.builder().text(sysMessage.getText()).build(); + } + }) .toList(); ToolCallingChatOptions updatedRuntimeOptions = prompt.getOptions().copy(); @@ -551,12 +573,13 @@ else if (mediaData instanceof URL url) { } /** - * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output, - * stopReason, usage, metrics etc. - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax - * @param response The Bedrock Converse response. - * @return The ChatResponse entity. - */ + * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output, + * stopReason, usage, metrics etc. + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax + * + * @param response The Bedrock Converse response. 
+ * @return The ChatResponse entity. + */ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) { Assert.notNull(response, "'response' must not be null."); @@ -630,13 +653,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv } /** - * Invoke the model and return the response stream. - * - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream - * @return The model invocation response stream. - */ + * Invoke the model and return the response stream. + *

+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream + * + * @return The model invocation response stream. + */ @Override public Flux stream(Prompt prompt) { Prompt requestPrompt = buildRequestPrompt(prompt); diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index a19de831a7e..d5e27b085d6 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -16,36 +16,8 @@ package org.springframework.ai.bedrock.converse.api; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.stream.Collectors; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import software.amazon.awssdk.core.SdkField; -import software.amazon.awssdk.core.document.Document; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent; -import 
software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamTrace; -import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent; -import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent; -import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; -import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart; - import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -56,6 +28,22 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.services.bedrockruntime.model.*; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; /** * Amazon Bedrock Converse API utils. 
@@ -63,10 +51,14 @@ * @author Wei Jiang * @author Christian Tzolov * @author Alexandros Pappas + * @author Brave Lin * @since 1.0.0 */ public final class ConverseApiUtils { + // Metadata key marking a prompt-caching cache point; see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + public static final String CACHE_POINT="cachePoint"; + public static final ChatResponse EMPTY_CHAT_RESPONSE = ChatResponse.builder() .generations(List.of()) .metadata("empty", true) @@ -76,6 +68,17 @@ private ConverseApiUtils() { } + /** + * Build a system message that marks an AWS Bedrock prompt-caching cache point. + * See: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + * @return a {@link SystemMessage} whose text is {@code CACHE_POINT} and whose metadata maps {@code CACHE_POINT} to {@code CachePointType.DEFAULT} + */ + public static SystemMessage buildCachePointMesssage(){ + SystemMessage message = new SystemMessage(CACHE_POINT); + message.getMetadata().put(CACHE_POINT, CachePointType.DEFAULT); + return message; + } + public static boolean isToolUseStart(ConverseStreamOutput event) { if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.CONTENT_BLOCK_START) { return false; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index 65d85c1a2a4..17f69f09033 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -16,24 +16,17 @@ package org.springframework.ai.bedrock.converse; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import 
org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; - +import org.springframework.ai.bedrock.converse.api.ConverseApiUtils; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -55,6 +48,14 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -347,4 +348,37 @@ record ActorsFilmsRecord(String actor, List movies) { } + /** + * Calls the chat model with a prompt that contains a cache-point system message; see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + * @author Brave Lin + * @param modelName the model id set on the request's chat options + */ + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "us.anthropic.claude-3-7-sonnet-20250219-v1:0" }) + void cachePointTest(String modelName) { + String systemMessageStr= """ + You are a helpful AI assistant. Your name is spring. + You are an AI assistant that helps people find information. + Your name is spring + You should reply to the user's request with your name and also in the style of a pirate. 
+ """; + // Repeat the text 50 times so the system prompt reaches the minimum token count required for using the cache + String newSystemMessage=systemMessageStr.repeat(50); + + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + Message systemMessage = new SystemMessage(newSystemMessage); + Prompt prompt = new Prompt(List.of( + systemMessage, + ConverseApiUtils.buildCachePointMesssage(),// add cache point + userMessage), + ToolCallingChatOptions.builder().model(modelName).build()); + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResults()).hasSize(1); + Generation generation = response.getResults().get(0); + assertThat(generation.getOutput().getText()).contains("Blackbeard"); + assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn"); + logger.info(response.toString()); + } + }