2424import java .time .Duration ;
2525import java .util .ArrayList ;
2626import java .util .Base64 ;
27+ import java .util .HashMap ;
2728import java .util .List ;
2829import java .util .Map ;
2930import java .util .Set ;
4445import software .amazon .awssdk .regions .providers .DefaultAwsRegionProviderChain ;
4546import software .amazon .awssdk .services .bedrockruntime .BedrockRuntimeAsyncClient ;
4647import software .amazon .awssdk .services .bedrockruntime .BedrockRuntimeClient ;
48+ import software .amazon .awssdk .services .bedrockruntime .model .CachePointBlock ;
4749import software .amazon .awssdk .services .bedrockruntime .model .ContentBlock ;
4850import software .amazon .awssdk .services .bedrockruntime .model .ConversationRole ;
4951import software .amazon .awssdk .services .bedrockruntime .model .ConverseMetrics ;
7072import software .amazon .awssdk .services .bedrockruntime .model .VideoFormat ;
7173import software .amazon .awssdk .services .bedrockruntime .model .VideoSource ;
7274
75+ import org .springframework .ai .bedrock .converse .api .BedrockCacheOptions ;
76+ import org .springframework .ai .bedrock .converse .api .BedrockCacheStrategy ;
7377import org .springframework .ai .bedrock .converse .api .BedrockMediaFormat ;
7478import org .springframework .ai .bedrock .converse .api .ConverseApiUtils ;
7579import org .springframework .ai .bedrock .converse .api .ConverseChatResponseStream ;
@@ -314,6 +318,8 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp
314318 .internalToolExecutionEnabled (runtimeOptions .getInternalToolExecutionEnabled () != null
315319 ? runtimeOptions .getInternalToolExecutionEnabled ()
316320 : this .defaultOptions .getInternalToolExecutionEnabled ())
321+ .cacheOptions (runtimeOptions .getCacheOptions () != null ? runtimeOptions .getCacheOptions ()
322+ : this .defaultOptions .getCacheOptions ())
317323 .build ();
318324 }
319325
@@ -324,93 +330,183 @@ else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOp
324330
325331 ConverseRequest createRequest (Prompt prompt ) {
326332
327- List <Message > instructionMessages = prompt .getInstructions ()
333+ BedrockChatOptions updatedRuntimeOptions = prompt .getOptions ().copy ();
334+
335+ // Get cache options to determine strategy
336+ BedrockCacheOptions cacheOptions = updatedRuntimeOptions .getCacheOptions ();
337+ boolean shouldCacheConversationHistory = cacheOptions != null
338+ && cacheOptions .getStrategy () == BedrockCacheStrategy .CONVERSATION_HISTORY ;
339+
340+ // Get all non-system messages
341+ List <org .springframework .ai .chat .messages .Message > allNonSystemMessages = prompt .getInstructions ()
328342 .stream ()
329343 .filter (message -> message .getMessageType () != MessageType .SYSTEM )
330- .map (message -> {
331- if (message .getMessageType () == MessageType .USER ) {
332- List <ContentBlock > contents = new ArrayList <>();
333- if (message instanceof UserMessage userMessage ) {
334- contents .add (ContentBlock .fromText (userMessage .getText ()));
335-
336- if (!CollectionUtils .isEmpty (userMessage .getMedia ())) {
337- List <ContentBlock > mediaContent = userMessage .getMedia ()
338- .stream ()
339- .map (this ::mapMediaToContentBlock )
340- .toList ();
341- contents .addAll (mediaContent );
342- }
343- }
344- return Message .builder ().content (contents ).role (ConversationRole .USER ).build ();
344+ .toList ();
345+
346+ // Find the last user message index for CONVERSATION_HISTORY caching
347+ int lastUserMessageIndex = -1 ;
348+ if (shouldCacheConversationHistory ) {
349+ for (int i = allNonSystemMessages .size () - 1 ; i >= 0 ; i --) {
350+ if (allNonSystemMessages .get (i ).getMessageType () == MessageType .USER ) {
351+ lastUserMessageIndex = i ;
352+ break ;
345353 }
346- else if (message .getMessageType () == MessageType .ASSISTANT ) {
347- AssistantMessage assistantMessage = (AssistantMessage ) message ;
348- List <ContentBlock > contentBlocks = new ArrayList <>();
349- if (StringUtils .hasText (message .getText ())) {
350- contentBlocks .add (ContentBlock .fromText (message .getText ()));
354+ }
355+ if (logger .isDebugEnabled ()) {
356+ logger .debug ("CONVERSATION_HISTORY caching: lastUserMessageIndex={}, totalMessages={}" ,
357+ lastUserMessageIndex , allNonSystemMessages .size ());
358+ }
359+ }
360+
361+ // Build instruction messages with potential caching
362+ List <Message > instructionMessages = new ArrayList <>();
363+ for (int i = 0 ; i < allNonSystemMessages .size (); i ++) {
364+ org .springframework .ai .chat .messages .Message message = allNonSystemMessages .get (i );
365+
366+ // Determine if this message should have a cache point
367+ // For CONVERSATION_HISTORY: cache point goes on the last user message
368+ boolean shouldApplyCachePoint = shouldCacheConversationHistory && i == lastUserMessageIndex ;
369+
370+ if (message .getMessageType () == MessageType .USER ) {
371+ List <ContentBlock > contents = new ArrayList <>();
372+ if (message instanceof UserMessage ) {
373+ var userMessage = (UserMessage ) message ;
374+ contents .add (ContentBlock .fromText (userMessage .getText ()));
375+
376+ if (!CollectionUtils .isEmpty (userMessage .getMedia ())) {
377+ List <ContentBlock > mediaContent = userMessage .getMedia ()
378+ .stream ()
379+ .map (this ::mapMediaToContentBlock )
380+ .toList ();
381+ contents .addAll (mediaContent );
351382 }
352- if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
353- for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
383+ }
354384
355- var argumentsDocument = ConverseApiUtils
356- .convertObjectToDocument (ModelOptionsUtils .jsonToMap (toolCall .arguments ()));
385+ // Apply cache point if this is the last user message
386+ if (shouldApplyCachePoint ) {
387+ CachePointBlock cachePoint = CachePointBlock .builder ().type ("default" ).build ();
388+ contents .add (ContentBlock .fromCachePoint (cachePoint ));
389+ logger .debug ("Applied cache point on last user message (conversation history caching)" );
390+ }
391+
392+ instructionMessages .add (Message .builder ().content (contents ).role (ConversationRole .USER ).build ());
393+ }
394+ else if (message .getMessageType () == MessageType .ASSISTANT ) {
395+ AssistantMessage assistantMessage = (AssistantMessage ) message ;
396+ List <ContentBlock > contentBlocks = new ArrayList <>();
397+ if (StringUtils .hasText (message .getText ())) {
398+ contentBlocks .add (ContentBlock .fromText (message .getText ()));
399+ }
400+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
401+ for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
357402
358- contentBlocks .add (ContentBlock .fromToolUse (ToolUseBlock .builder ()
359- .toolUseId (toolCall .id ())
360- .name (toolCall .name ())
361- .input (argumentsDocument )
362- .build ()));
403+ var argumentsDocument = ConverseApiUtils
404+ .convertObjectToDocument (ModelOptionsUtils .jsonToMap (toolCall .arguments ()));
405+
406+ contentBlocks .add (ContentBlock .fromToolUse (ToolUseBlock .builder ()
407+ .toolUseId (toolCall .id ())
408+ .name (toolCall .name ())
409+ .input (argumentsDocument )
410+ .build ()));
363411
364- }
365412 }
366- return Message .builder ().content (contentBlocks ).role (ConversationRole .ASSISTANT ).build ();
367413 }
368- else if (message .getMessageType () == MessageType .TOOL ) {
369- List <ContentBlock > contentBlocks = ((ToolResponseMessage ) message ).getResponses ()
370- .stream ()
371- .map (toolResponse -> {
414+
415+ instructionMessages
416+ .add (Message .builder ().content (contentBlocks ).role (ConversationRole .ASSISTANT ).build ());
417+ }
418+ else if (message .getMessageType () == MessageType .TOOL ) {
419+ List <ContentBlock > contentBlocks = new ArrayList <>(
420+ ((ToolResponseMessage ) message ).getResponses ().stream ().map (toolResponse -> {
372421 ToolResultBlock toolResultBlock = ToolResultBlock .builder ()
373422 .toolUseId (toolResponse .id ())
374423 .content (ToolResultContentBlock .builder ().text (toolResponse .responseData ()).build ())
375424 .build ();
376425 return ContentBlock .fromToolResult (toolResultBlock );
377- })
378- .toList ();
379- return Message .builder ().content (contentBlocks ).role (ConversationRole .USER ).build ();
380- }
381- else {
382- throw new IllegalArgumentException ("Unsupported message type: " + message .getMessageType ());
383- }
384- })
385- .toList ();
426+ }).toList ());
427+
428+ instructionMessages .add (Message .builder ().content (contentBlocks ).role (ConversationRole .USER ).build ());
429+ }
430+ else {
431+ throw new IllegalArgumentException ("Unsupported message type: " + message .getMessageType ());
432+ }
433+ }
386434
387- List <SystemContentBlock > systemMessages = prompt .getInstructions ()
435+ // Determine if system message caching should be applied
436+ boolean shouldCacheSystem = cacheOptions != null
437+ && (cacheOptions .getStrategy () == BedrockCacheStrategy .SYSTEM_ONLY
438+ || cacheOptions .getStrategy () == BedrockCacheStrategy .SYSTEM_AND_TOOLS );
439+
440+ if (logger .isDebugEnabled () && cacheOptions != null ) {
441+ logger .debug ("Cache strategy: {}, shouldCacheSystem: {}" , cacheOptions .getStrategy (), shouldCacheSystem );
442+ }
443+
444+ // Build system messages with optional caching on last message
445+ List <org .springframework .ai .chat .messages .Message > systemMessageList = prompt .getInstructions ()
388446 .stream ()
389447 .filter (m -> m .getMessageType () == MessageType .SYSTEM )
390- .map (sysMessage -> SystemContentBlock .builder ().text (sysMessage .getText ()).build ())
391448 .toList ();
392449
393- BedrockChatOptions updatedRuntimeOptions = prompt .getOptions ().copy ();
450+ List <SystemContentBlock > systemMessages = new ArrayList <>();
451+ for (int i = 0 ; i < systemMessageList .size (); i ++) {
452+ org .springframework .ai .chat .messages .Message sysMessage = systemMessageList .get (i );
453+
454+ // Add the text content block
455+ SystemContentBlock textBlock = SystemContentBlock .builder ().text (sysMessage .getText ()).build ();
456+ systemMessages .add (textBlock );
457+
458+ // Apply cache point marker after last system message if caching is enabled
459+ // SystemContentBlock is a UNION type - text and cachePoint must be separate
460+ // blocks
461+ boolean isLastSystem = (i == systemMessageList .size () - 1 );
462+ if (isLastSystem && shouldCacheSystem ) {
463+ CachePointBlock cachePoint = CachePointBlock .builder ().type ("default" ).build ();
464+ SystemContentBlock cachePointBlock = SystemContentBlock .builder ().cachePoint (cachePoint ).build ();
465+ systemMessages .add (cachePointBlock );
466+ logger .debug ("Applied cache point after system message" );
467+ }
468+ }
394469
395470 ToolConfiguration toolConfiguration = null ;
396471
397472 // Add the tool definitions to the request's tools parameter.
398473 List <ToolDefinition > toolDefinitions = this .toolCallingManager .resolveToolDefinitions (updatedRuntimeOptions );
399474
475+ // Determine if tool caching should be applied
476+ boolean shouldCacheTools = cacheOptions != null
477+ && (cacheOptions .getStrategy () == BedrockCacheStrategy .TOOLS_ONLY
478+ || cacheOptions .getStrategy () == BedrockCacheStrategy .SYSTEM_AND_TOOLS );
479+
400480 if (!CollectionUtils .isEmpty (toolDefinitions )) {
401- List <Tool > bedrockTools = toolDefinitions .stream ().map (toolDefinition -> {
481+ List <Tool > bedrockTools = new ArrayList <>();
482+
483+ for (int i = 0 ; i < toolDefinitions .size (); i ++) {
484+ ToolDefinition toolDefinition = toolDefinitions .get (i );
402485 var description = toolDefinition .description ();
403486 var name = toolDefinition .name ();
404487 String inputSchema = toolDefinition .inputSchema ();
405- return Tool .builder ()
488+
489+ // Create tool specification
490+ Tool tool = Tool .builder ()
406491 .toolSpec (ToolSpecification .builder ()
407492 .name (name )
408493 .description (description )
409494 .inputSchema (ToolInputSchema .fromJson (
410495 ConverseApiUtils .convertObjectToDocument (ModelOptionsUtils .jsonToMap (inputSchema ))))
411496 .build ())
412497 .build ();
413- }).toList ();
498+ bedrockTools .add (tool );
499+
500+ // Apply cache point marker after last tool if caching is enabled
501+ // Tool is a UNION type - toolSpec and cachePoint must be separate objects
502+ boolean isLastTool = (i == toolDefinitions .size () - 1 );
503+ if (isLastTool && shouldCacheTools ) {
504+ CachePointBlock cachePoint = CachePointBlock .builder ().type ("default" ).build ();
505+ Tool cachePointTool = Tool .builder ().cachePoint (cachePoint ).build ();
506+ bedrockTools .add (cachePointTool );
507+ logger .debug ("Applied cache point after tool definitions" );
508+ }
509+ }
414510
415511 toolConfiguration = ToolConfiguration .builder ().tools (bedrockTools ).build ();
416512 }
@@ -633,12 +729,23 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
633729
634730 ConverseMetrics metrics = response .metrics ();
635731
636- var chatResponseMetaData = ChatResponseMetadata .builder ()
732+ var metadataBuilder = ChatResponseMetadata .builder ()
637733 .id (response .responseMetadata () != null ? response .responseMetadata ().requestId () : "Unknown" )
638- .usage (usage )
639- .build ();
734+ .usage (usage );
735+
736+ // Add cache metrics if available
737+ Map <String , Object > additionalMetadata = new HashMap <>();
738+ if (response .usage ().cacheReadInputTokens () != null ) {
739+ additionalMetadata .put ("cacheReadInputTokens" , response .usage ().cacheReadInputTokens ());
740+ }
741+ if (response .usage ().cacheWriteInputTokens () != null ) {
742+ additionalMetadata .put ("cacheWriteInputTokens" , response .usage ().cacheWriteInputTokens ());
743+ }
744+ if (!additionalMetadata .isEmpty ()) {
745+ metadataBuilder .metadata (additionalMetadata );
746+ }
640747
641- return new ChatResponse (allGenerations , chatResponseMetaData );
748+ return new ChatResponse (allGenerations , metadataBuilder . build () );
642749 }
643750
644751 /**
0 commit comments