diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index a513c25abab..e7a8de9d1ab 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -42,6 +42,8 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -432,7 +434,16 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { if (message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + AbstractMessage abstractMessage = (AbstractMessage) message; + List contents; + if (abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contents = new ArrayList<>( + List.of(new ContentBlock(message.getText(), cacheType.cacheControl()))); + } + else { + contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + } if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index ab95f3c8cb7..081d06f31f5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -32,7 +32,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -77,6 +77,8 @@ public class AnthropicApi { private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; @@ -495,17 +497,7 @@ public ChatCompletionRequest(String model, List messages, Stri this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null); } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); - } - - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } - /** - * Metadata about the request. - * * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. * Anthropic may use this id to help detect abuse. Do not include any identifying @@ -513,7 +505,22 @@ public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { + } + /** + * @param type is the cache type supported by anthropic. Doc + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type) { + } + + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); } } @@ -689,7 +696,10 @@ public record ContentBlock( // tool_result response only @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content + @JsonProperty("content") String content, + + // cache object + @JsonProperty("cache_control") CacheControl cacheControl ) { // @formatter:on @@ -708,7 +718,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null); } /** @@ -716,7 +726,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null); } /** @@ -724,7 +734,11 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, cache); } // Tool result @@ -735,7 +749,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content); + this(type, null, null, null, null, null, null, toolUseId, content, null); } /** @@ -746,7 +760,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -758,7 +772,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. */ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null); + this(type, null, null, null, id, name, input, null, null, null); } /** @@ -917,7 +931,9 @@ public record ChatCompletionResponse( public record Usage( // @formatter:off @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 00000000000..06a756be42f --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,21 @@ +package org.springframework.ai.anthropic.api; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + + EPHEMERAL(() -> new CacheControl("ephemeral")); + + private Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public CacheControl cacheControl() { + return this.value.get(); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index ae62eb0748c..f3a515e324d 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -174,7 +174,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { if (messageDeltaEvent.usage() != null) { var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + messageDeltaEvent.usage().outputTokens(), + contentBlockReference.get().usage.cacheCreationInputTokens(), + contentBlockReference.get().usage.cacheReadInputTokens()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 752e9247fae..99e681a6642 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -29,6 +29,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role; import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -41,6 +44,34 @@ public class AnthropicApiIT { AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY")); + @Test + void chatWithPromptCache() { + String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, + false); + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 8f3dd228510..e345579067a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -51,18 +51,25 @@ public abstract class AbstractMessage implements Message { */ protected final String textContent; + protected String cache; + /** * Additional options for the message to influence the response, not a generative map. */ protected final Map metadata; - /** - * Create a new AbstractMessage with the given message type, text content, and - * metadata. - * @param messageType the message type - * @param textContent the text content - * @param metadata the metadata - */ + protected AbstractMessage(MessageType messageType, String textContent, Map metadata, String cache) { + Assert.notNull(messageType, "Message type must not be null"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + } + this.messageType = messageType; + this.textContent = textContent; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { @@ -93,6 +100,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata, String cache) { + Assert.notNull(resource, "Resource must not be null"); + try (InputStream inputStream = resource.getInputStream()) { + this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + this.messageType = messageType; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + /** * Get the content of the message. * @return the content of the message @@ -120,6 +141,10 @@ public MessageType getMessageType() { return this.messageType; } + public String getCache() { + return cache; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index 7bae70b64ee..ab4d6ffc0b2 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -36,6 +36,10 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; + public UserMessage(String textContent, String cache) { + this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache); + } + public UserMessage(String textContent) { this(MessageType.USER, textContent, new ArrayList<>(), Map.of()); } @@ -45,6 +49,11 @@ public UserMessage(Resource resource) { this.media = new ArrayList<>(); } + public UserMessage(Resource resource, String cache) { + super(MessageType.USER, resource, Map.of(), cache); + this.media = new ArrayList<>(); + } + public UserMessage(String textContent, List media) { this(MessageType.USER, textContent, media, Map.of()); } @@ -64,6 +73,17 @@ public UserMessage(MessageType messageType, String textContent, Collection(media); } + public UserMessage(MessageType messageType, String textContent, Collection media, + Map metadata, String cache) { + super(messageType, textContent, metadata, cache); + Assert.notNull(media, "media data must not be null"); + this.media = new ArrayList<>(media); + } + + public List getMedia(String... dummy) { + return this.media; + } + @Override public String toString() { return "UserMessage{" + "content='" + getText() + '\'' + ", properties=" + this.metadata + ", messageType=" @@ -80,4 +100,9 @@ public String getText() { return this.textContent; } + @Override + public String getCache() { + return super.getCache(); + } + }