diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 2d36014a719..dfaa1f66bc6 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -16,12 +16,7 @@ package org.springframework.ai.anthropic; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; @@ -30,6 +25,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -42,10 +38,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.anthropic.api.AnthropicCacheType; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -482,12 +475,35 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + List userMessagesList = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.USER) + .toList(); + Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1); + + List assistantMessageList = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.ASSISTANT) + .toList(); + Message lastAssistantMessage = assistantMessageList.isEmpty() ? null + : assistantMessageList.get(assistantMessageList.size() - 1); + List userMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { + AbstractMessage abstractMessage = (AbstractMessage) message; if (message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + List contents; + boolean isLastItem = message.equals(lastUserMessage); + if (isLastItem && abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contents = new ArrayList<>( + List.of(new ContentBlock(message.getText(), cacheType.cacheControl()))); + } + else { + contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + } if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { @@ -503,8 +519,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { else if (message.getMessageType() == MessageType.ASSISTANT) { AssistantMessage assistantMessage = (AssistantMessage) message; List contentBlocks = new ArrayList<>(); + boolean isLastItem = message.equals(lastAssistantMessage); if (StringUtils.hasText(message.getText())) { - contentBlocks.add(new ContentBlock(message.getText())); + if (isLastItem && abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl())); + } + else { + contentBlocks.add(new ContentBlock(message.getText())); + } } if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { @@ -543,6 +566,7 @@ else if (message.getMessageType() == MessageType.TOOL) { // Add the tool definitions to the request's tools parameter. List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); if (!CollectionUtils.isEmpty(toolDefinitions)) { + var tool = getFunctionTools(toolDefinitions); request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index cf410690216..1e86b241b5e 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -24,6 +24,15 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; @@ -46,14 +55,6 @@ import org.springframework.web.reactive.function.client.WebClient; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** * The Anthropic API client. @@ -94,6 +95,8 @@ public static Builder builder() { private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final String completionsPath; @@ -538,17 +541,7 @@ public ChatCompletionRequest(String model, List messages, Stri this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); - } - - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } - /** - * Metadata about the request. - * * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. * Anthropic may use this id to help detect abuse. Do not include any identifying @@ -556,7 +549,22 @@ public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { + } + + /** + * @param type is the cache type supported by anthropic. Doc + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type) { + } + + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); } /** @@ -760,7 +768,10 @@ public record ContentBlock( @JsonProperty("tool_use_id") String toolUseId, @JsonProperty("content") String content, - // Thinking only + // cache object + @JsonProperty("cache_control") CacheControl cacheControl, + + // Thinking only @JsonProperty("signature") String signature, @JsonProperty("thinking") String thinking, @@ -784,7 +795,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -792,7 +803,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -800,7 +811,11 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, cache, null, null, null); } // Tool result @@ -811,7 +826,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content, null, null, null); + this(type, null, null, null, null, null, null, toolUseId, content, null, null, null, null); } /** @@ -822,7 +837,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -834,7 +849,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. */ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null, null, null, null); + this(type, null, null, null, id, name, input, null, null, null, null, null, null); } /** @@ -1028,7 +1043,9 @@ public record ChatCompletionResponse( public record Usage( // @formatter:off @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 00000000000..06a756be42f --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,21 @@ +package org.springframework.ai.anthropic.api; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + + EPHEMERAL(() -> new CacheControl("ephemeral")); + + private Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public CacheControl cacheControl() { + return this.value.get(); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index f636f29a158..8366ca0b712 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -159,7 +159,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) { } else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) { ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null, - null, null, null, null, null, thinkingBlock.thinking(), null); + null, null, null, null, null, null, thinkingBlock.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -176,12 +176,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) { } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) { ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, null, thinking.thinking(), null); + null, null, null, null, null, null, null, thinking.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) { ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, sig.signature(), null, null); + null, null, null, null, null, null, sig.signature(), null, null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -204,8 +204,10 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { } if (messageDeltaEvent.usage() != null) { - Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), + messageDeltaEvent.usage().outputTokens(), + contentBlockReference.get().usage.cacheCreationInputTokens(), + contentBlockReference.get().usage.cacheReadInputTokens()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index c78386fb7ce..33301cd7573 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Random; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -37,6 +38,9 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -70,6 +74,67 @@ public class AnthropicApiIT { } """))); + @Test + void chatWithPromptCacheInAssistantMessage() { + String assistantMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(assistantMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.ASSISTANT); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) + .messages(List.of(chatCompletionMessage)) + .maxTokens(1500) + .temperature(0.8) + .stream(false) + .build(); + + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + + @Test + void chatWithPromptCache() { + String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, + false); + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 6e37fd7548b..ec6dd245709 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -53,11 +53,25 @@ public abstract class AbstractMessage implements Message { @Nullable protected final String textContent; + protected String cache; + /** * Additional options for the message to influence the response, not a generative map. */ protected final Map metadata; + protected AbstractMessage(MessageType messageType, String textContent, Map metadata, String cache) { + Assert.notNull(messageType, "Message type must not be null"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + } + this.messageType = messageType; + this.textContent = textContent; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + /** * Create a new AbstractMessage with the given message type, text content, and * metadata. @@ -98,6 +112,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata, String cache) { + Assert.notNull(resource, "Resource must not be null"); + try (InputStream inputStream = resource.getInputStream()) { + this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + this.messageType = messageType; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + /** * Get the content of the message. * @return the content of the message @@ -126,6 +154,10 @@ public MessageType getMessageType() { return this.messageType; } + public String getCache() { + return cache; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index fc005392c34..f4dbf28361b 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -40,6 +40,10 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; + public UserMessage(String textContent, String cache) { + this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache); + } + public UserMessage(String textContent) { this(textContent, new ArrayList<>(), Map.of()); } @@ -55,6 +59,15 @@ public UserMessage(Resource resource) { this(MessageUtils.readResource(resource)); } + public UserMessage(Resource resource, String cache) { + super(MessageType.USER, resource, Map.of(), cache); + this.media = new ArrayList<>(); + } + + public UserMessage(String textContent, List media) { + this(MessageType.USER, textContent, media, Map.of(), null); + } + @Override public String toString() { return "UserMessage{" + "content='" + getText() + '\'' + ", metadata=" + this.metadata + ", messageType=" @@ -76,6 +89,17 @@ public UserMessage copy() { return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata())).build(); } + public UserMessage(MessageType messageType, String textContent, Collection media, + Map metadata, String cache) { + super(messageType, textContent, metadata, cache); + Assert.notNull(media, "media data must not be null"); + this.media = new ArrayList<>(media); + } + + public List getMedia(String... dummy) { + return this.media; + } + public Builder mutate() { return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata())); } @@ -135,4 +159,9 @@ else if (this.resource != null) { } + @Override + public String getCache() { + return super.getCache(); + } + }