Skip to content

Commit 58cf35e

Browse files
sobychacko authored and markpollack committed
Add prompt caching support for AWS Bedrock Converse API
Implements prompt caching to reduce costs on repeated content and improve response times. Applications with large system prompts, extensive tool definitions, or multi-turn conversations can see significant savings, as cached content costs ~90% less to process than uncached content.

Adds five caching strategies to address different use cases:
- `SYSTEM_ONLY`: Cache system messages (most common - stable instructions)
- `TOOLS_ONLY`: Cache tool definitions (when tools are stable but system varies)
- `SYSTEM_AND_TOOLS`: Cache both (when both are large and stable)
- `CONVERSATION_HISTORY`: Cache conversation history (for chatbots and assistants)
- `NONE`: Default, no caching

Implementation:
- `BedrockCacheStrategy` enum with `BedrockCacheOptions` configuration class
- Integrated with `BedrockChatOptions` (`equals`/`hashCode`/`copy` support)
- Cache points applied as separate blocks to satisfy AWS SDK UNION type constraints where each block can only contain one field type
- Boolean flags derived from strategy to improve code readability and avoid repetitive conditional checks throughout request building
- Last user message pattern for `CONVERSATION_HISTORY` enables incremental caching where each turn builds on the previous cached prefix
- Cache metrics exposed via metadata `Map` to maintain provider independence without adding Bedrock-specific fields to shared interfaces
- Cache hierarchy respects AWS cascade invalidation (tools → system → messages) to prevent stale cache combinations
- Debug logging for troubleshooting cache point application

Model compatibility:
- Claude 3.x/4.x: All strategies supported
- Amazon Nova: `SYSTEM_ONLY` and `CONVERSATION_HISTORY` only (AWS limitation on tool caching for Nova models)

Testing:
- Integration tests for all strategies using Claude 3.7 Sonnet
- Tests handle cache TTL overlap between runs to avoid flakiness in CI

Documentation includes usage examples, real-world use cases (legal document analysis, code review, customer support, multi-tenant SaaS), best practices, cache invalidation behavior, and cost considerations. Break-even occurs after one cache hit since cache reads cost ~90% less than base input tokens while cache writes cost ~25% more.

Signed-off-by: Soby Chacko <[email protected]>
1 parent 38ea4ff commit 58cf35e

File tree

6 files changed

+1544
-57
lines changed

6 files changed

+1544
-57
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.fasterxml.jackson.annotation.JsonInclude;
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131

32+
import org.springframework.ai.bedrock.converse.api.BedrockCacheOptions;
3233
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3334
import org.springframework.ai.tool.ToolCallback;
3435
import org.springframework.lang.Nullable;
@@ -81,6 +82,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions {
8182
@JsonIgnore
8283
private Boolean internalToolExecutionEnabled;
8384

85+
@JsonIgnore
86+
private BedrockCacheOptions cacheOptions;
87+
8488
/**
 * Creates a new, empty {@link Builder}.
 * @return a fresh builder instance
 */
public static Builder builder() {
8589
return new Builder();
8690
}
@@ -101,6 +105,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) {
101105
.toolNames(new HashSet<>(fromOptions.getToolNames()))
102106
.toolContext(new HashMap<>(fromOptions.getToolContext()))
103107
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
108+
.cacheOptions(fromOptions.getCacheOptions())
104109
.build();
105110
}
106111

@@ -237,6 +242,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
237242
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
238243
}
239244

245+
/**
 * Returns the prompt-caching configuration attached to these options.
 * The accessor is marked {@code @JsonIgnore}.
 * @return the cache options, or {@code null} when none have been set
 */
@JsonIgnore
246+
public BedrockCacheOptions getCacheOptions() {
247+
return this.cacheOptions;
248+
}
249+
250+
/**
 * Replaces the prompt-caching configuration held by these options.
 * The reference is stored as given; no copy is made.
 * @param cacheOptions the cache options to store
 */
@JsonIgnore
251+
public void setCacheOptions(BedrockCacheOptions cacheOptions) {
252+
this.cacheOptions = cacheOptions;
253+
}
254+
240255
@Override
241256
@SuppressWarnings("unchecked")
242257
public BedrockChatOptions copy() {
@@ -259,14 +274,15 @@ public boolean equals(Object o) {
259274
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK)
260275
&& Objects.equals(this.topP, that.topP) && Objects.equals(this.toolCallbacks, that.toolCallbacks)
261276
&& Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext)
262-
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled);
277+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
278+
&& Objects.equals(this.cacheOptions, that.cacheOptions);
263279
}
264280

265281
@Override
266282
public int hashCode() {
267283
return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty,
268284
this.requestParameters, this.stopSequences, this.temperature, this.topK, this.topP, this.toolCallbacks,
269-
this.toolNames, this.toolContext, this.internalToolExecutionEnabled);
285+
this.toolNames, this.toolContext, this.internalToolExecutionEnabled, this.cacheOptions);
270286
}
271287

272288
public static final class Builder {
@@ -356,6 +372,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
356372
return this;
357373
}
358374

375+
/**
 * Sets the prompt-caching configuration on the options being built.
 * @param cacheOptions the cache options to apply
 * @return this builder, for call chaining
 */
public Builder cacheOptions(BedrockCacheOptions cacheOptions) {
376+
this.options.setCacheOptions(cacheOptions);
377+
return this;
378+
}
379+
359380
/**
 * Returns the options instance this builder has been configuring.
 * The same instance is handed back rather than a copy.
 * @return the configured {@code BedrockChatOptions}
 */
public BedrockChatOptions build() {
360381
return this.options;
361382
}

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 162 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.time.Duration;
2525
import java.util.ArrayList;
2626
import java.util.Base64;
27+
import java.util.HashMap;
2728
import java.util.List;
2829
import java.util.Map;
2930
import java.util.Set;
@@ -44,6 +45,7 @@
4445
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
4546
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
4647
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
48+
import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
4749
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
4850
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
4951
import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
@@ -70,6 +72,8 @@
7072
import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat;
7173
import software.amazon.awssdk.services.bedrockruntime.model.VideoSource;
7274

75+
import org.springframework.ai.bedrock.converse.api.BedrockCacheOptions;
76+
import org.springframework.ai.bedrock.converse.api.BedrockCacheStrategy;
7377
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
7478
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
7579
import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream;
@@ -314,6 +318,8 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp
314318
.internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null
315319
? runtimeOptions.getInternalToolExecutionEnabled()
316320
: this.defaultOptions.getInternalToolExecutionEnabled())
321+
.cacheOptions(runtimeOptions.getCacheOptions() != null ? runtimeOptions.getCacheOptions()
322+
: this.defaultOptions.getCacheOptions())
317323
.build();
318324
}
319325

@@ -324,93 +330,183 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp
324330

325331
ConverseRequest createRequest(Prompt prompt) {
326332

327-
List<Message> instructionMessages = prompt.getInstructions()
333+
BedrockChatOptions updatedRuntimeOptions = prompt.getOptions().copy();
334+
335+
// Get cache options to determine strategy
336+
BedrockCacheOptions cacheOptions = updatedRuntimeOptions.getCacheOptions();
337+
boolean shouldCacheConversationHistory = cacheOptions != null
338+
&& cacheOptions.getStrategy() == BedrockCacheStrategy.CONVERSATION_HISTORY;
339+
340+
// Get all non-system messages
341+
List<org.springframework.ai.chat.messages.Message> allNonSystemMessages = prompt.getInstructions()
328342
.stream()
329343
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
330-
.map(message -> {
331-
if (message.getMessageType() == MessageType.USER) {
332-
List<ContentBlock> contents = new ArrayList<>();
333-
if (message instanceof UserMessage userMessage) {
334-
contents.add(ContentBlock.fromText(userMessage.getText()));
335-
336-
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
337-
List<ContentBlock> mediaContent = userMessage.getMedia()
338-
.stream()
339-
.map(this::mapMediaToContentBlock)
340-
.toList();
341-
contents.addAll(mediaContent);
342-
}
343-
}
344-
return Message.builder().content(contents).role(ConversationRole.USER).build();
344+
.toList();
345+
346+
// Find the last user message index for CONVERSATION_HISTORY caching
347+
int lastUserMessageIndex = -1;
348+
if (shouldCacheConversationHistory) {
349+
for (int i = allNonSystemMessages.size() - 1; i >= 0; i--) {
350+
if (allNonSystemMessages.get(i).getMessageType() == MessageType.USER) {
351+
lastUserMessageIndex = i;
352+
break;
345353
}
346-
else if (message.getMessageType() == MessageType.ASSISTANT) {
347-
AssistantMessage assistantMessage = (AssistantMessage) message;
348-
List<ContentBlock> contentBlocks = new ArrayList<>();
349-
if (StringUtils.hasText(message.getText())) {
350-
contentBlocks.add(ContentBlock.fromText(message.getText()));
354+
}
355+
if (logger.isDebugEnabled()) {
356+
logger.debug("CONVERSATION_HISTORY caching: lastUserMessageIndex={}, totalMessages={}",
357+
lastUserMessageIndex, allNonSystemMessages.size());
358+
}
359+
}
360+
361+
// Build instruction messages with potential caching
362+
List<Message> instructionMessages = new ArrayList<>();
363+
for (int i = 0; i < allNonSystemMessages.size(); i++) {
364+
org.springframework.ai.chat.messages.Message message = allNonSystemMessages.get(i);
365+
366+
// Determine if this message should have a cache point
367+
// For CONVERSATION_HISTORY: cache point goes on the last user message
368+
boolean shouldApplyCachePoint = shouldCacheConversationHistory && i == lastUserMessageIndex;
369+
370+
if (message.getMessageType() == MessageType.USER) {
371+
List<ContentBlock> contents = new ArrayList<>();
372+
if (message instanceof UserMessage) {
373+
var userMessage = (UserMessage) message;
374+
contents.add(ContentBlock.fromText(userMessage.getText()));
375+
376+
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
377+
List<ContentBlock> mediaContent = userMessage.getMedia()
378+
.stream()
379+
.map(this::mapMediaToContentBlock)
380+
.toList();
381+
contents.addAll(mediaContent);
351382
}
352-
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
353-
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
383+
}
354384

355-
var argumentsDocument = ConverseApiUtils
356-
.convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));
385+
// Apply cache point if this is the last user message
386+
if (shouldApplyCachePoint) {
387+
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
388+
contents.add(ContentBlock.fromCachePoint(cachePoint));
389+
logger.debug("Applied cache point on last user message (conversation history caching)");
390+
}
391+
392+
instructionMessages.add(Message.builder().content(contents).role(ConversationRole.USER).build());
393+
}
394+
else if (message.getMessageType() == MessageType.ASSISTANT) {
395+
AssistantMessage assistantMessage = (AssistantMessage) message;
396+
List<ContentBlock> contentBlocks = new ArrayList<>();
397+
if (StringUtils.hasText(message.getText())) {
398+
contentBlocks.add(ContentBlock.fromText(message.getText()));
399+
}
400+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
401+
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
357402

358-
contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
359-
.toolUseId(toolCall.id())
360-
.name(toolCall.name())
361-
.input(argumentsDocument)
362-
.build()));
403+
var argumentsDocument = ConverseApiUtils
404+
.convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));
405+
406+
contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
407+
.toolUseId(toolCall.id())
408+
.name(toolCall.name())
409+
.input(argumentsDocument)
410+
.build()));
363411

364-
}
365412
}
366-
return Message.builder().content(contentBlocks).role(ConversationRole.ASSISTANT).build();
367413
}
368-
else if (message.getMessageType() == MessageType.TOOL) {
369-
List<ContentBlock> contentBlocks = ((ToolResponseMessage) message).getResponses()
370-
.stream()
371-
.map(toolResponse -> {
414+
415+
instructionMessages
416+
.add(Message.builder().content(contentBlocks).role(ConversationRole.ASSISTANT).build());
417+
}
418+
else if (message.getMessageType() == MessageType.TOOL) {
419+
List<ContentBlock> contentBlocks = new ArrayList<>(
420+
((ToolResponseMessage) message).getResponses().stream().map(toolResponse -> {
372421
ToolResultBlock toolResultBlock = ToolResultBlock.builder()
373422
.toolUseId(toolResponse.id())
374423
.content(ToolResultContentBlock.builder().text(toolResponse.responseData()).build())
375424
.build();
376425
return ContentBlock.fromToolResult(toolResultBlock);
377-
})
378-
.toList();
379-
return Message.builder().content(contentBlocks).role(ConversationRole.USER).build();
380-
}
381-
else {
382-
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
383-
}
384-
})
385-
.toList();
426+
}).toList());
427+
428+
instructionMessages.add(Message.builder().content(contentBlocks).role(ConversationRole.USER).build());
429+
}
430+
else {
431+
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
432+
}
433+
}
386434

387-
List<SystemContentBlock> systemMessages = prompt.getInstructions()
435+
// Determine if system message caching should be applied
436+
boolean shouldCacheSystem = cacheOptions != null
437+
&& (cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_ONLY
438+
|| cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_AND_TOOLS);
439+
440+
if (logger.isDebugEnabled() && cacheOptions != null) {
441+
logger.debug("Cache strategy: {}, shouldCacheSystem: {}", cacheOptions.getStrategy(), shouldCacheSystem);
442+
}
443+
444+
// Build system messages with optional caching on last message
445+
List<org.springframework.ai.chat.messages.Message> systemMessageList = prompt.getInstructions()
388446
.stream()
389447
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
390-
.map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build())
391448
.toList();
392449

393-
BedrockChatOptions updatedRuntimeOptions = prompt.getOptions().copy();
450+
List<SystemContentBlock> systemMessages = new ArrayList<>();
451+
for (int i = 0; i < systemMessageList.size(); i++) {
452+
org.springframework.ai.chat.messages.Message sysMessage = systemMessageList.get(i);
453+
454+
// Add the text content block
455+
SystemContentBlock textBlock = SystemContentBlock.builder().text(sysMessage.getText()).build();
456+
systemMessages.add(textBlock);
457+
458+
// Apply cache point marker after last system message if caching is enabled
459+
// SystemContentBlock is a UNION type - text and cachePoint must be separate
460+
// blocks
461+
boolean isLastSystem = (i == systemMessageList.size() - 1);
462+
if (isLastSystem && shouldCacheSystem) {
463+
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
464+
SystemContentBlock cachePointBlock = SystemContentBlock.builder().cachePoint(cachePoint).build();
465+
systemMessages.add(cachePointBlock);
466+
logger.debug("Applied cache point after system message");
467+
}
468+
}
394469

395470
ToolConfiguration toolConfiguration = null;
396471

397472
// Add the tool definitions to the request's tools parameter.
398473
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(updatedRuntimeOptions);
399474

475+
// Determine if tool caching should be applied
476+
boolean shouldCacheTools = cacheOptions != null
477+
&& (cacheOptions.getStrategy() == BedrockCacheStrategy.TOOLS_ONLY
478+
|| cacheOptions.getStrategy() == BedrockCacheStrategy.SYSTEM_AND_TOOLS);
479+
400480
if (!CollectionUtils.isEmpty(toolDefinitions)) {
401-
List<Tool> bedrockTools = toolDefinitions.stream().map(toolDefinition -> {
481+
List<Tool> bedrockTools = new ArrayList<>();
482+
483+
for (int i = 0; i < toolDefinitions.size(); i++) {
484+
ToolDefinition toolDefinition = toolDefinitions.get(i);
402485
var description = toolDefinition.description();
403486
var name = toolDefinition.name();
404487
String inputSchema = toolDefinition.inputSchema();
405-
return Tool.builder()
488+
489+
// Create tool specification
490+
Tool tool = Tool.builder()
406491
.toolSpec(ToolSpecification.builder()
407492
.name(name)
408493
.description(description)
409494
.inputSchema(ToolInputSchema.fromJson(
410495
ConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema))))
411496
.build())
412497
.build();
413-
}).toList();
498+
bedrockTools.add(tool);
499+
500+
// Apply cache point marker after last tool if caching is enabled
501+
// Tool is a UNION type - toolSpec and cachePoint must be separate objects
502+
boolean isLastTool = (i == toolDefinitions.size() - 1);
503+
if (isLastTool && shouldCacheTools) {
504+
CachePointBlock cachePoint = CachePointBlock.builder().type("default").build();
505+
Tool cachePointTool = Tool.builder().cachePoint(cachePoint).build();
506+
bedrockTools.add(cachePointTool);
507+
logger.debug("Applied cache point after tool definitions");
508+
}
509+
}
414510

415511
toolConfiguration = ToolConfiguration.builder().tools(bedrockTools).build();
416512
}
@@ -633,12 +729,23 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
633729

634730
ConverseMetrics metrics = response.metrics();
635731

636-
var chatResponseMetaData = ChatResponseMetadata.builder()
732+
var metadataBuilder = ChatResponseMetadata.builder()
637733
.id(response.responseMetadata() != null ? response.responseMetadata().requestId() : "Unknown")
638-
.usage(usage)
639-
.build();
734+
.usage(usage);
735+
736+
// Add cache metrics if available
737+
Map<String, Object> additionalMetadata = new HashMap<>();
738+
if (response.usage().cacheReadInputTokens() != null) {
739+
additionalMetadata.put("cacheReadInputTokens", response.usage().cacheReadInputTokens());
740+
}
741+
if (response.usage().cacheWriteInputTokens() != null) {
742+
additionalMetadata.put("cacheWriteInputTokens", response.usage().cacheWriteInputTokens());
743+
}
744+
if (!additionalMetadata.isEmpty()) {
745+
metadataBuilder.metadata(additionalMetadata);
746+
}
640747

641-
return new ChatResponse(allGenerations, chatResponseMetaData);
748+
return new ChatResponse(allGenerations, metadataBuilder.build());
642749
}
643750

644751
/**

0 commit comments

Comments
 (0)