From e3bb1fea77e96e0de3a8f1795756fbd42b94ce18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Wed, 25 Sep 2024 22:02:34 -0300 Subject: [PATCH 1/5] feat(anthropic): add support for prompt caching Implements Anthropic's prompt caching feature to improve token efficiency. - Adds cache control support in AnthropicApi and AnthropicChatModel - Creates AnthropicCacheType enum with EPHEMERAL cache type - Extends AbstractMessage and UserMessage to support cache parameters - Updates Usage tracking to include cache-related token metrics - Adds integration test to verify prompt caching functionality This implementation follows Anthropic's prompt caching API (beta-2024-07-31) which allows for more efficient token usage by caching frequently used prompts. --- .../ai/anthropic/AnthropicChatModel.java | 13 ++++- .../ai/anthropic/api/AnthropicApi.java | 54 ++++++++++++------- .../ai/anthropic/api/AnthropicCacheType.java | 21 ++++++++ .../ai/anthropic/api/StreamHelper.java | 4 +- .../ai/anthropic/api/AnthropicApiIT.java | 31 +++++++++++ .../ai/chat/messages/AbstractMessage.java | 39 +++++++++++--- .../ai/chat/messages/UserMessage.java | 25 +++++++++ 7 files changed, 159 insertions(+), 28 deletions(-) create mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index a513c25abab..e7a8de9d1ab 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -42,6 +42,8 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -432,7 +434,16 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { if (message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + AbstractMessage abstractMessage = (AbstractMessage) message; + List contents; + if (abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contents = new ArrayList<>( + List.of(new ContentBlock(message.getText(), cacheType.cacheControl()))); + } + else { + contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + } if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index ab95f3c8cb7..081d06f31f5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -32,7 +32,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -77,6 +77,8 @@ public class AnthropicApi { private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; @@ -495,17 +497,7 @@ public ChatCompletionRequest(String model, List messages, Stri this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null); } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); - } - - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } - /** - * Metadata about the request. - * * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. * Anthropic may use this id to help detect abuse. Do not include any identifying @@ -513,7 +505,22 @@ public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { + } + /** + * @param type is the cache type supported by anthropic. Doc + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type) { + } + + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); } } @@ -689,7 +696,10 @@ public record ContentBlock( // tool_result response only @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content + @JsonProperty("content") String content, + + // cache object + @JsonProperty("cache_control") CacheControl cacheControl ) { // @formatter:on @@ -708,7 +718,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null); } /** @@ -716,7 +726,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null); } /** @@ -724,7 +734,11 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, cache); } // Tool result @@ -735,7 +749,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content); + this(type, null, null, null, null, null, null, toolUseId, content, null); } /** @@ -746,7 +760,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -758,7 +772,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. */ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null); + this(type, null, null, null, id, name, input, null, null, null); } /** @@ -917,7 +931,9 @@ public record ChatCompletionResponse( public record Usage( // @formatter:off @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 00000000000..06a756be42f --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,21 @@ +package org.springframework.ai.anthropic.api; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + + EPHEMERAL(() -> new CacheControl("ephemeral")); + + private Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public CacheControl cacheControl() { + return this.value.get(); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index ae62eb0748c..f3a515e324d 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -174,7 +174,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { if (messageDeltaEvent.usage() != null) { var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + messageDeltaEvent.usage().outputTokens(), + contentBlockReference.get().usage.cacheCreationInputTokens(), + contentBlockReference.get().usage.cacheReadInputTokens()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 752e9247fae..99e681a6642 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -29,6 +29,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -41,6 +44,34 @@ public class AnthropicApiIT { AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); + @Test + void chatWithPromptCache() { + String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, + false); + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 8f3dd228510..e345579067a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -51,18 +51,25 @@ public abstract class AbstractMessage implements Message { */ protected final String textContent; + protected String cache; + /** * Additional options for the message to influence the response, not a generative map. */ protected final Map metadata; - /** - * Create a new AbstractMessage with the given message type, text content, and - * metadata. - * @param messageType the message type - * @param textContent the text content - * @param metadata the metadata - */ + protected AbstractMessage(MessageType messageType, String textContent, Map metadata, String cache) { + Assert.notNull(messageType, "Message type must not be null"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + } + this.messageType = messageType; + this.textContent = textContent; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { @@ -93,6 +100,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata, String cache) { + Assert.notNull(resource, "Resource must not be null"); + try (InputStream inputStream = resource.getInputStream()) { + this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + this.messageType = messageType; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + /** * Get the content of the message. * @return the content of the message @@ -120,6 +141,10 @@ public MessageType getMessageType() { return this.messageType; } + public String getCache() { + return cache; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 7bae70b64ee..ab4d6ffc0b2 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -36,6 +36,10 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; + public UserMessage(String textContent, String cache) { + this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache); + } + public UserMessage(String textContent) { this(MessageType.USER, textContent, new ArrayList<>(), Map.of()); } @@ -45,6 +49,11 @@ public UserMessage(Resource resource) { this.media = new ArrayList<>(); } + public UserMessage(Resource resource, String cache) { + super(MessageType.USER, resource, Map.of(), cache); + this.media = new ArrayList<>(); + } + public UserMessage(String textContent, List media) { this(MessageType.USER, textContent, media, Map.of()); } @@ -64,6 +73,17 @@ public UserMessage(MessageType messageType, String textContent, Collection(media); } + public UserMessage(MessageType messageType, String textContent, Collection media, + Map metadata, String cache) { + super(messageType, textContent, metadata, cache); + Assert.notNull(media, "media data must not be null"); + this.media = new ArrayList<>(media); + } + + public List getMedia(String... dummy) { + return this.media; + } + @Override public String toString() { return "UserMessage{" + "content='" + getText() + '\'' + ", properties=" + this.metadata + ", messageType=" @@ -80,4 +100,9 @@ public String getText() { return this.textContent; } + @Override + public String getCache() { + return super.getCache(); + } + } From 6253677a16b0f5a7ec826baca3396a9fe012e8fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Sat, 9 Aug 2025 21:16:34 -0300 Subject: [PATCH 2/5] fixed bug in anthropic cache --- .../ai/anthropic/AnthropicChatModel.java | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index e7a8de9d1ab..bc6a6f573ac 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -16,12 +16,7 @@ package org.springframework.ai.anthropic; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; @@ -30,6 +25,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -43,11 +39,6 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.ai.anthropic.api.AnthropicCacheType; -import org.springframework.ai.chat.messages.AbstractMessage; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -429,6 +420,11 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + Optional lastMessage = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() != MessageType.SYSTEM) + .findFirst(); + List userMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) @@ -436,7 +432,8 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { if (message.getMessageType() == MessageType.USER) { AbstractMessage abstractMessage = (AbstractMessage) message; List contents; - if (abstractMessage.getCache() != null) { + boolean isLastItem = lastMessage.filter(message::equals).isPresent(); + if (isLastItem && abstractMessage.getCache() != null) { AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); contents = new ArrayList<>( List.of(new ContentBlock(message.getText(), cacheType.cacheControl()))); From f17ee20ee95bd44a25a3127818d4babf9fd49712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Sat, 9 Aug 2025 21:31:34 -0300 Subject: [PATCH 3/5] fixed merge with main --- .../ai/anthropic/api/AnthropicApi.java | 16 ++++------------ .../ai/anthropic/api/StreamHelper.java | 6 +++--- .../ai/chat/messages/UserMessage.java | 8 +------- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 0190925d116..1e86b241b5e 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -55,14 +55,6 @@ import org.springframework.web.reactive.function.client.WebClient; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** * The Anthropic API client. @@ -777,9 +769,9 @@ public record ContentBlock( @JsonProperty("content") String content, // cache object - @JsonProperty("cache_control") CacheControl cacheControl + @JsonProperty("cache_control") CacheControl cacheControl, - // Thinking only + // Thinking only @JsonProperty("signature") String signature, @JsonProperty("thinking") String thinking, @@ -819,11 +811,11 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null); } public ContentBlock(String text, CacheControl cache) { - this(Type.TEXT, null, text, null, null, null, null, null, null, cache null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, cache, null, null, null); } // Tool result diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index c5be3529d2a..c34241d39a1 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -159,7 +159,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) { } else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) { ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null, - null, null, null, null, null, thinkingBlock.thinking(), null); + null, null, null, null, null, null, thinkingBlock.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -176,12 +176,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) { } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) { ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, null, thinking.thinking(), null); + null, null, null, null, null, null,null, thinking.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) { ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, sig.signature(), null, null); + null, null, null, null, null, null, sig.signature(), null, null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index a8dcb3072c9..81d4d772125 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -65,7 +65,7 @@ public UserMessage(Resource resource, String cache) { } public UserMessage(String textContent, List media) { - this(MessageType.USER, textContent, media, Map.of()); + this(MessageType.USER, textContent, media, Map.of(), null); } @Override @@ -100,12 +100,6 @@ public List getMedia(String... dummy) { return this.media; } - @Override - public String toString() { - return "UserMessage{" + "content='" + getText() + '\'' + ", properties=" + this.metadata + ", messageType=" - + this.messageType + '}'; - } - public Builder mutate() { return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata())); } From d7326178b5e963cb8c50d70876225465e8f3bf97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Sat, 9 Aug 2025 23:55:28 -0300 Subject: [PATCH 4/5] added cache suport to assistant message --- .../ai/anthropic/AnthropicChatModel.java | 26 ++++++++++---- .../ai/anthropic/api/AnthropicApiIT.java | 34 +++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index ec277cc336b..59b1906e2c7 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -475,19 +475,26 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - Optional lastMessage = prompt.getInstructions() + List userMessagesList = prompt.getInstructions() .stream() - .filter(message -> message.getMessageType() != MessageType.SYSTEM) - .findFirst(); + .filter(message -> message.getMessageType() == MessageType.USER) + .toList(); + Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1); + + List assistantMessageList = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.ASSISTANT) + .toList(); + Message lastAssistantMessage = assistantMessageList.isEmpty() ? null : assistantMessageList.get(assistantMessageList.size() - 1); List userMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { + AbstractMessage abstractMessage = (AbstractMessage) message; if (message.getMessageType() == MessageType.USER) { - AbstractMessage abstractMessage = (AbstractMessage) message; List contents; - boolean isLastItem = lastMessage.filter(message::equals).isPresent(); + boolean isLastItem = message.equals(lastUserMessage); if (isLastItem && abstractMessage.getCache() != null) { AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); contents = new ArrayList<>( @@ -511,8 +518,14 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { else if (message.getMessageType() == MessageType.ASSISTANT) { AssistantMessage assistantMessage = (AssistantMessage) message; List contentBlocks = new ArrayList<>(); + boolean isLastItem = message.equals(lastAssistantMessage); if (StringUtils.hasText(message.getText())) { - contentBlocks.add(new ContentBlock(message.getText())); + if (isLastItem && abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl())); + } else { + contentBlocks.add(new ContentBlock(message.getText())); + } } if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { @@ -551,6 +564,7 @@ else if (message.getMessageType() == MessageType.TOOL) { // Add the tool definitions to the request's tools parameter. List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); if (!CollectionUtils.isEmpty(toolDefinitions)) { + var tool = getFunctionTools(toolDefinitions); request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index b647c7744f6..2fe769a06ba 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Random; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -73,6 +74,39 @@ public class AnthropicApiIT { } """))); + @Test + void chatWithPromptCacheInAssistantMessage() { + String assistantMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(assistantMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.ASSISTANT); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) + .messages(List.of(chatCompletionMessage)) + .maxTokens(1500) + .temperature(0.8) + .stream(false) + .build(); + + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatWithPromptCache() { String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " From 077c0673546a7dade8c0cdcf0ed1027b497f31b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cclaudio-code=E2=80=9D?= Date: Sat, 9 Aug 2025 23:58:34 -0300 Subject: [PATCH 5/5] apply java format --- .../ai/anthropic/AnthropicChatModel.java | 18 ++++++++++-------- .../ai/anthropic/api/StreamHelper.java | 2 +- .../ai/anthropic/api/AnthropicApiIT.java | 16 ++++++++-------- .../ai/chat/messages/UserMessage.java | 6 +++--- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 59b1906e2c7..dfaa1f66bc6 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -476,16 +476,17 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List userMessagesList = prompt.getInstructions() - .stream() - .filter(message -> message.getMessageType() == MessageType.USER) - .toList(); + .stream() + .filter(message -> message.getMessageType() == MessageType.USER) + .toList(); Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1); List assistantMessageList = prompt.getInstructions() - .stream() - .filter(message -> message.getMessageType() == MessageType.ASSISTANT) - .toList(); - Message lastAssistantMessage = assistantMessageList.isEmpty() ? null : assistantMessageList.get(assistantMessageList.size() - 1); + .stream() + .filter(message -> message.getMessageType() == MessageType.ASSISTANT) + .toList(); + Message lastAssistantMessage = assistantMessageList.isEmpty() ? null + : assistantMessageList.get(assistantMessageList.size() - 1); List userMessages = prompt.getInstructions() .stream() @@ -523,7 +524,8 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { if (isLastItem && abstractMessage.getCache() != null) { AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl())); - } else { + } + else { contentBlocks.add(new ContentBlock(message.getText())); } } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index c34241d39a1..8366ca0b712 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -176,7 +176,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) { } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) { ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, null,null, thinking.thinking(), null); + null, null, null, null, null, null, null, thinking.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) { diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 2fe769a06ba..33301cd7573 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -87,16 +87,16 @@ void chatWithPromptCacheInAssistantMessage() { Role.ASSISTANT); ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) - .messages(List.of(chatCompletionMessage)) - .maxTokens(1500) - .temperature(0.8) - .stream(false) - .build(); + .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) + .messages(List.of(chatCompletionMessage)) + .maxTokens(1500) + .temperature(0.8) + .stream(false) + .build(); AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) - .getBody() - .usage(); + .getBody() + .usage(); assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 81d4d772125..f4dbf28361b 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -66,8 +66,8 @@ public UserMessage(Resource resource, String cache) { public UserMessage(String textContent, List media) { this(MessageType.USER, textContent, media, Map.of(), null); - } - + } + @Override public String toString() { return "UserMessage{" + "content='" + getText() + '\'' + ", metadata=" + this.metadata + ", messageType=" @@ -100,7 +100,7 @@ public List getMedia(String... dummy) { return this.media; } - public Builder mutate() { + public Builder mutate() { return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata())); }