diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
index 380951c7265..79934457abb 100644
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java
@@ -16,18 +16,6 @@
package org.springframework.ai.bedrock.converse;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.URISyntaxException;
-import java.net.URL;
-import java.net.URLConnection;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
@@ -73,6 +61,8 @@
import software.amazon.awssdk.services.bedrockruntime.model.VideoBlock;
import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat;
import software.amazon.awssdk.services.bedrockruntime.model.VideoSource;
+import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
+import software.amazon.awssdk.services.bedrockruntime.model.CachePointType;
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
@@ -96,17 +86,33 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
-import org.springframework.ai.model.tool.ToolCallingChatOptions;
-import org.springframework.ai.model.tool.ToolCallingManager;
-import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
-import org.springframework.ai.model.tool.ToolExecutionResult;
+import org.springframework.ai.model.tool.*;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Sinks;
+import reactor.core.publisher.Sinks.EmitFailureHandler;
+import reactor.core.scheduler.Schedulers;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.core.document.Document;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
+import software.amazon.awssdk.services.bedrockruntime.model.*;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.net.URLConnection;
+import java.time.Duration;
+import java.util.*;
/**
* A {@link ChatModel} implementation that uses the Amazon Bedrock Converse API to
@@ -127,12 +133,15 @@
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
*
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
*
* @author Christian Tzolov
* @author Wei Jiang
* @author Alexandros Pappas
* @author Jihoon Kim
* @author Soby Chacko
+ * @author Brave Lin canhui_lin@fzzixun.com
* @since 1.0.0
*/
public class BedrockProxyChatModel implements ChatModel {
@@ -150,34 +159,34 @@ public class BedrockProxyChatModel implements ChatModel {
private ToolCallingChatOptions defaultOptions;
/**
- * Observation registry used for instrumentation.
- */
+ * Observation registry used for instrumentation.
+ */
private final ObservationRegistry observationRegistry;
private final ToolCallingManager toolCallingManager;
/**
- * The tool execution eligibility predicate used to determine if a tool can be
- * executed.
- */
+ * The tool execution eligibility predicate used to determine if a tool can be
+ * executed.
+ */
private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
/**
- * Conventions to use for generating observations.
- */
+ * Conventions to use for generating observations.
+ */
private ChatModelObservationConvention observationConvention;
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
- BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
- ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
+ BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
+ ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry, toolCallingManager,
new DefaultToolExecutionEligibilityPredicate());
}
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
- BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
- ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager,
- ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
+ BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
+ ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager,
+ ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
@@ -203,13 +212,14 @@ private static ToolCallingChatOptions from(ChatOptions options) {
}
/**
- * Invoke the model and return the response.
- *
- * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
- * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
- * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
- * @return The model invocation response.
- */
+ * Invoke the model and return the response.
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
+ *
+ * @return The model invocation response.
+ */
@Override
public ChatResponse call(Prompt prompt) {
Prompt requestPrompt = buildRequestPrompt(prompt);
@@ -384,7 +394,19 @@ else if (message.getMessageType() == MessageType.TOOL) {
List systemMessages = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
- .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build())
+ .map(sysMessage -> {
+ /*
+ * A system message whose metadata contains CACHE_POINT becomes a cache point block instead of a text block.
+ * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+ */
+ if(sysMessage.getMetadata()!=null&&sysMessage.getMetadata().get(ConverseApiUtils.CACHE_POINT)!=null){
+ return SystemContentBlock.fromCachePoint(CachePointBlock.builder()
+ .type(CachePointType.DEFAULT)
+ .build());
+ }else{
+ return SystemContentBlock.builder().text(sysMessage.getText()).build();
+ }
+ })
.toList();
ToolCallingChatOptions updatedRuntimeOptions = prompt.getOptions().copy();
@@ -551,12 +573,13 @@ else if (mediaData instanceof URL url) {
}
/**
- * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output,
- * stopReason, usage, metrics etc.
- * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
- * @param response The Bedrock Converse response.
- * @return The ChatResponse entity.
- */
+ * Convert {@link ConverseResponse} to {@link ChatResponse} includes model output,
+ * stopReason, usage, metrics etc.
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
+ *
+ * @param response The Bedrock Converse response.
+ * @return The ChatResponse entity.
+ */
private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) {
Assert.notNull(response, "'response' must not be null.");
@@ -630,13 +653,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
}
/**
- * Invoke the model and return the response stream.
- *
- * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
- * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
- * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
- * @return The model invocation response stream.
- */
+ * Invoke the model and return the response stream.
+ *
+ * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
+ *
+ * @return The model invocation response stream.
+ */
@Override
public Flux stream(Prompt prompt) {
Prompt requestPrompt = buildRequestPrompt(prompt);
diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
index a19de831a7e..d5e27b085d6 100644
--- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
+++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java
@@ -16,36 +16,8 @@
package org.springframework.ai.bedrock.converse.api;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.stream.Collectors;
-
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import software.amazon.awssdk.core.SdkField;
-import software.amazon.awssdk.core.document.Document;
-import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
-import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
-import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetrics;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor;
-import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamTrace;
-import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
-import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
-import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;
-
import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
@@ -56,6 +28,22 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import software.amazon.awssdk.core.SdkField;
+import software.amazon.awssdk.core.document.Document;
+import software.amazon.awssdk.services.bedrockruntime.model.*;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput.EventType;
+import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler.Visitor;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
/**
* Amazon Bedrock Converse API utils.
@@ -63,10 +51,14 @@
* @author Wei Jiang
* @author Christian Tzolov
* @author Alexandros Pappas
+ * @author Brave Lin
* @since 1.0.0
*/
public final class ConverseApiUtils {
+ // Message metadata key marking a cache point: https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+ public static final String CACHE_POINT="cachePoint";
+
public static final ChatResponse EMPTY_CHAT_RESPONSE = ChatResponse.builder()
.generations(List.of())
.metadata("empty", true)
@@ -76,6 +68,17 @@ private ConverseApiUtils() {
}
+ /**
+ * Builds a marker {@link SystemMessage} that represents an Amazon Bedrock prompt-caching cache point.
+ * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+ * @return a system message with text {@link #CACHE_POINT} and metadata entry {@link #CACHE_POINT} set to {@link CachePointType#DEFAULT}
+ */
+ public static SystemMessage buildCachePointMesssage(){
+ SystemMessage message = new SystemMessage(CACHE_POINT);
+ message.getMetadata().put(CACHE_POINT, CachePointType.DEFAULT);
+ return message;
+ }
+
public static boolean isToolUseStart(ConverseStreamOutput event) {
if (event == null || event.sdkEventType() == null || event.sdkEventType() != EventType.CONTENT_BLOCK_START) {
return false;
diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
index 65d85c1a2a4..17f69f09033 100644
--- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
+++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java
@@ -16,24 +16,17 @@
package org.springframework.ai.bedrock.converse;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import reactor.core.publisher.Flux;
-
+import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
@@ -55,6 +48,14 @@
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;
+import reactor.core.publisher.Flux;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
@@ -347,4 +348,37 @@ record ActorsFilmsRecord(String actor, List movies) {
}
+ /**
+ * Calls the model with a long system message followed by a cache-point message and checks the reply.
+ * See https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html (contributed by Brave Lin).
+ * @param modelName the model id passed to the chat options
+ */
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "us.anthropic.claude-3-7-sonnet-20250219-v1:0" })
+ void cachePointTest(String modelName) {
+ String systemMessageStr= """
+ You are a helpful AI assistant. Your name is spring.
+ You are an AI assistant that helps people find information.
+ Your name is spring
+ You should reply to the user's request with your name and also in the style of a pirate.
+ """;
+ // Repeat the system text 50 times so the token count reaches the minimum required for prompt caching
+ String newSystemMessage=systemMessageStr.repeat(50);
+
+ UserMessage userMessage = new UserMessage(
+ "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
+ Message systemMessage = new SystemMessage(newSystemMessage);
+ Prompt prompt = new Prompt(List.of(
+ systemMessage,
+ ConverseApiUtils.buildCachePointMesssage(),// add the cache point right after the system message
+ userMessage),
+ ToolCallingChatOptions.builder().model(modelName).build());
+ ChatResponse response = this.chatModel.call(prompt);
+ assertThat(response.getResults()).hasSize(1);
+ Generation generation = response.getResults().get(0);
+ assertThat(generation.getOutput().getText()).contains("Blackbeard");
+ assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
+ logger.info(response.toString());
+ }
+
}