From 1dd686b3e9e767d09da350505af86578ba70de88 Mon Sep 17 00:00:00 2001 From: Gareth Evans Date: Thu, 14 Aug 2025 10:45:53 +0100 Subject: [PATCH 1/2] feat(google genai): support sending labels with chat request Signed-off-by: Gareth Evans Fix checkstyle Signed-off-by: Soby Chacko --- .../ai/google/genai/GoogleGenAiChatModel.java | 6 +++ .../google/genai/GoogleGenAiChatOptions.java | 26 +++++++++++-- .../genai/GoogleGenAiChatOptionsTest.java | 37 +++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 256fe679427..668c1e5a0d7 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -479,6 +479,8 @@ Prompt buildRequestPrompt(Prompt prompt) { runtimeOptions.getGoogleSearchRetrieval(), this.defaultOptions.getGoogleSearchRetrieval())); requestOptions.setSafetySettings(ModelOptionsUtils.mergeOption(runtimeOptions.getSafetySettings(), this.defaultOptions.getSafetySettings())); + requestOptions + .setLabels(ModelOptionsUtils.mergeOption(runtimeOptions.getLabels(), this.defaultOptions.getLabels())); } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); @@ -488,6 +490,7 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval()); requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings()); + requestOptions.setLabels(this.defaultOptions.getLabels()); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); @@ -680,6 +683,9 @@ GeminiRequest createGeminiRequest(Prompt prompt) { configBuilder 
.thinkingConfig(ThinkingConfig.builder().thinkingBudget(requestOptions.getThinkingBudget()).build()); } + if (requestOptions.getLabels() != null && !requestOptions.getLabels().isEmpty()) { + configBuilder.labels(requestOptions.getLabels()); + } // Add safety settings if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) { diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 4d5eb076166..7e05e5fc921 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -144,6 +144,9 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Map labels = new HashMap<>(); // @formatter:on public static Builder builder() { @@ -170,6 +173,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); options.setThinkingBudget(fromOptions.getThinkingBudget()); + options.setLabels(fromOptions.getLabels()); return options; } @@ -332,6 +336,15 @@ public void setSafetySettings(List safetySettings) { this.safetySettings = safetySettings; } + public Map getLabels() { + return this.labels; + } + + public void setLabels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.labels = labels; + } + @Override public Map getToolContext() { return this.toolContext; @@ -363,7 +376,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && 
Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) - && Objects.equals(this.toolContext, that.toolContext); + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels); } @Override @@ -371,7 +384,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, - this.safetySettings, this.internalToolExecutionEnabled, this.toolContext); + this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels); } @Override @@ -382,7 +395,8 @@ public String toString() { + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}'; + + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels + + '}'; } @Override @@ -510,6 +524,12 @@ public Builder thinkingBudget(Integer thinkingBudget) { return this; } + public Builder labels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.options.labels = labels; + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java index 4d8d45cbd11..3521213bfb5 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java 
+++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java @@ -16,6 +16,8 @@ package org.springframework.ai.google.genai; +import java.util.Map; + import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -104,6 +106,29 @@ public void testEqualsAndHashCodeWithThinkingBudget() { assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); } + @Test + public void testEqualsAndHashCodeWithLabels() { + GoogleGenAiChatOptions options1 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options2 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options3 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "other-org")) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + assertThat(options1).isNotEqualTo(options3); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + @Test public void testToStringWithThinkingBudget() { GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() @@ -116,4 +141,16 @@ public void testToStringWithThinkingBudget() { assertThat(toString).contains("test-model"); } + @Test + public void testToStringWithLabels() { + GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + String toString = options.toString(); + assertThat(toString).contains("labels={org=my-org}"); + assertThat(toString).contains("test-model"); + } + } From 4e57b4ba363d8f2cd9e4ab2389a63fcd2b96041e Mon Sep 17 00:00:00 2001 From: Austin Dase Date: Mon, 1 Sep 2025 08:08:10 -0400 Subject: [PATCH 2/2] feat(anthropic): enhance caching options with new 1-hour ephemeral cache type, allow for more fine-grained
configuration of what's cached, account for Anthropic's max of 4 cache blocks per request port over anthropic prompt caching support --- .../ai/anthropic/AnthropicChatModel.java | 113 +++- .../ai/anthropic/AnthropicChatOptions.java | 235 +++++++- .../ai/anthropic/api/AnthropicApi.java | 510 ++++++++++++------ .../ai/anthropic/api/AnthropicCacheType.java | 63 +++ .../ai/anthropic/api/StreamHelper.java | 12 +- .../ai/anthropic/AnthropicChatModelIT.java | 57 +- .../anthropic/AnthropicChatOptionsTests.java | 228 +++++++- .../ai/anthropic/api/AnthropicApiIT.java | 35 ++ .../ROOT/pages/api/chat/anthropic-chat.adoc | 185 ++++++- 9 files changed, 1234 insertions(+), 204 deletions(-) create mode 100644 models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 5ea1195c3a7..80db41d766c 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; @@ -91,6 +92,7 @@ * @author Alexandros Pappas * @author Jonghoon Park * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ public class AnthropicChatModel implements ChatModel { @@ -481,21 +483,100 @@ private Map mergeHttpHeaders(Map runtimeHttpHead return mergedHttpHeaders; } + private static ContentBlock cacheAwareContentBlock(String text, AtomicInteger usedCacheBlocks, + AnthropicChatOptions.CacheControlConfiguration cfg, MessageType type) { + return cacheAwareContentBlock(new 
ContentBlock(text), usedCacheBlocks, cfg, type); + } + + private static ContentBlock cacheAwareContentBlock(ContentBlock contentBlock, AtomicInteger usedCacheBlocks, + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) { + if (cacheControlConfiguration == null) { + return contentBlock; + } + + // Only proceed if this message is eligible for caching AND we can reserve a cache + // slot + if (isCacheEligible(contentBlock, cacheControlConfiguration, messageType) + && tryReserveCacheBlock(usedCacheBlocks, cacheControlConfiguration.getMaxCacheBlocks())) { + return ContentBlock.from(contentBlock) + .cacheControl(cacheControlConfiguration.getCacheTypeForMessageType(messageType).cacheControl()) + .build(); + } + + if (logger.isDebugEnabled()) { + final Integer minCacheBlockLength = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType); + logger.debug( + "Skipping cache for messageType={}, used={}/{}; textLength={}, contentLength={}, minLength={}, cachableTypes={}", + messageType, usedCacheBlocks.get(), cacheControlConfiguration.getMaxCacheBlocks(), + safeLength(contentBlock.text()), safeLength(contentBlock.content()), minCacheBlockLength, + cacheControlConfiguration.getCachableMessageTypes()); + } + + return contentBlock; + } + + private static int safeLength(String s) { + return (s == null) ? 
0 : s.length(); + } + + private static boolean isCacheEligible(ContentBlock block, + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) { + if (!cacheControlConfiguration.getCachableMessageTypes().contains(messageType)) { + return false; + } + + final int minCacheBlockLength = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType); + + return isNullOrGreaterThanLength(block.text(), minCacheBlockLength) + && isNullOrGreaterThanLength(block.content(), minCacheBlockLength); + } + + private static boolean isNullOrGreaterThanLength(String s, int min) { + return s == null || s.length() >= min; + } + + /** + * Attempts to increment the counter only if we're still under the max. Returns true + * if we successfully reserved a slot. + */ + private static boolean tryReserveCacheBlock(AtomicInteger used, int max) { + int prev = used.getAndUpdate(v -> (v < max) ? (v + 1) : v); + return prev < max; + } + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions(); + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration = (requestOptions != null) + ? 
requestOptions.getCacheControlConfiguration() : null; + + AtomicInteger usedCacheBlocks = new AtomicInteger(); + + List systemPrompt = prompt.getInstructions() + .stream() + .filter(m -> m.getMessageType() == MessageType.SYSTEM) + .map(m -> cacheAwareContentBlock(m.getText(), usedCacheBlocks, cacheControlConfiguration, + MessageType.SYSTEM)) + .toList(); + List userMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { if (message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + List contents = new ArrayList<>(); + contents.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks, cacheControlConfiguration, + MessageType.USER)); if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { Type contentBlockType = getContentBlockTypeByMedia(media); var source = getSourceByMedia(media); return new ContentBlock(contentBlockType, source); - }).toList(); + }) + .map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.USER)) + .toList(); contents.addAll(mediaContent); } } @@ -505,12 +586,15 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { AssistantMessage assistantMessage = (AssistantMessage) message; List contentBlocks = new ArrayList<>(); if (StringUtils.hasText(message.getText())) { - contentBlocks.add(new ContentBlock(message.getText())); + contentBlocks.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks, + cacheControlConfiguration, MessageType.ASSISTANT)); } if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(), - 
ModelOptionsUtils.jsonToMap(toolCall.arguments()))); + ContentBlock contentBlock = new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(), + ModelOptionsUtils.jsonToMap(toolCall.arguments())); + contentBlocks.add(cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.ASSISTANT)); } } return new AnthropicMessage(contentBlocks, Role.ASSISTANT); @@ -520,6 +604,8 @@ else if (message.getMessageType() == MessageType.TOOL) { .stream() .map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(), toolResponse.responseData())) + .map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.TOOL)) .toList(); return new AnthropicMessage(toolResponses, Role.USER); } @@ -529,16 +615,15 @@ else if (message.getMessageType() == MessageType.TOOL) { }) .toList(); - String systemPrompt = prompt.getInstructions() - .stream() - .filter(m -> m.getMessageType() == MessageType.SYSTEM) - .map(m -> m.getText()) - .collect(Collectors.joining(System.lineSeparator())); - - ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages, - systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream); + ChatCompletionRequest request = ChatCompletionRequest.builder() + .model(this.defaultOptions.getModel()) + .messages(userMessages) + .system(systemPrompt) + .maxTokens(this.defaultOptions.getMaxTokens()) + .temperature(this.defaultOptions.getTemperature()) + .stream(stream) + .build(); - AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions(); request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); // Add the tool definitions to the request's tools parameter. 
diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..f328050784d 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -32,6 +32,8 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -44,6 +46,8 @@ * @author Thomas Vitale * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -58,6 +62,10 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { private @JsonProperty("top_p") Double topP; private @JsonProperty("top_k") Integer topK; private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking; + /** + * Cache control configuration options for the chat completion request. + */ + private @JsonProperty("cache_control") CacheControlConfiguration cacheControlConfiguration; /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat @@ -111,6 +119,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? 
new HashMap<>(fromOptions.getHttpHeaders()) : null) + .cacheControlConfiguration(fromOptions.getCacheControlConfiguration()) .build(); } @@ -259,6 +268,15 @@ public void setHttpHeaders(Map httpHeaders) { this.httpHeaders = httpHeaders; } + @JsonIgnore + public CacheControlConfiguration getCacheControlConfiguration() { + return this.cacheControlConfiguration; + } + + public void setCacheControlConfiguration(CacheControlConfiguration cacheControlConfiguration) { + this.cacheControlConfiguration = cacheControlConfiguration; + } + @Override @SuppressWarnings("unchecked") public AnthropicChatOptions copy() { @@ -282,14 +300,222 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) && Objects.equals(this.toolContext, that.toolContext) - && Objects.equals(this.httpHeaders, that.httpHeaders); + && Objects.equals(this.httpHeaders, that.httpHeaders) + && Objects.equals(this.cacheControlConfiguration, that.cacheControlConfiguration); } @Override public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders); + this.toolContext, this.httpHeaders, this.cacheControlConfiguration); + } + + public static class CacheControlConfiguration { + + /** + * The Anthropic API allows a maximum of 4 cache blocks. By default, we will + * attempt to cache up to 4 blocks. + */ + private static final int DEFAULT_MAX_CACHE_BLOCKS = 4; + + /** + * The minimum text or content length for a message to be considered for caching. + * By default, we will only cache messages with at least 2000 characters - + * counting characters as a lightweight way to roughly estimate tokens. This helps + * to avoid caching very short messages that are unlikely to benefit from caching. 
+ * Note: The Anthropic API has a minimum cacheable message length of 1024 tokens. + * See + * here + */ + private static final int DEFAULT_MIN_CACHE_BLOCK_LENGTH = 2000; + + /** + * The default set of message types that are considered for caching. By default, + * we will cache system, user, assistant, and tool messages. + */ + private static final Set DEFAULT_CACHABLE_MESSAGE_TYPES = Set.of(MessageType.SYSTEM, + MessageType.USER, MessageType.ASSISTANT, MessageType.TOOL); + + /** + * The default cache types to use for each message type. By default, we will use + * EPHEMERAL_1H for system messages and EPHEMERAL for user, assistant, and tool + * messages. See here + */ + private static final Map DEFAULT_MESSAGE_TYPE_CACHE_TYPES = Map.of( + MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H, MessageType.USER, AnthropicCacheType.EPHEMERAL, + MessageType.ASSISTANT, AnthropicCacheType.EPHEMERAL, MessageType.TOOL, AnthropicCacheType.EPHEMERAL); + + private int maxCacheBlocks = DEFAULT_MAX_CACHE_BLOCKS; + + private int minCacheBlockLength = DEFAULT_MIN_CACHE_BLOCK_LENGTH; + + private Set cachableMessageTypes = new HashSet<>(DEFAULT_CACHABLE_MESSAGE_TYPES); + + private Map messageTypeCacheTypes = new HashMap<>( + DEFAULT_MESSAGE_TYPE_CACHE_TYPES); + + /** + * To enable specific minimum block lengths per message type, use this map to + * override the default {@link #minCacheBlockLength} for specific message types. + * For example, you might want to set a higher minimum length for system messages + * and a lower minimum length for user messages. 
+ */ + private Map messageTypeMinBlockLength = new HashMap<>(); + + public static CacheControlConfiguration DEFAULT = new CacheControlConfiguration(); + + public static Builder builder() { + return new Builder(); + } + + public int getMaxCacheBlocks() { + return this.maxCacheBlocks; + } + + public void setMaxCacheBlocks(int maxCacheBlocks) { + this.maxCacheBlocks = maxCacheBlocks; + } + + public int getMinCacheBlockLength() { + return this.minCacheBlockLength; + } + + public void setMinCacheBlockLength(int minCacheBlockLength) { + this.minCacheBlockLength = minCacheBlockLength; + } + + public Set getCachableMessageTypes() { + return this.cachableMessageTypes; + } + + public void setCachableMessageTypes(Set cachableMessageTypes) { + this.cachableMessageTypes = cachableMessageTypes; + } + + public Map getMessageTypeCacheTypes() { + return this.messageTypeCacheTypes; + } + + public void setMessageTypeCacheTypes(Map messageTypeCacheTypes) { + this.messageTypeCacheTypes = messageTypeCacheTypes; + } + + public Map getMessageTypeMinBlockLength() { + return this.messageTypeMinBlockLength; + } + + public void setMessageTypeMinBlockLength(Map messageTypeMinBlockLength) { + this.messageTypeMinBlockLength = messageTypeMinBlockLength; + } + + /** + * Get the cache type for a given message type. If the message type is not + * configured, return EPHEMERAL as the default. + * @param messageType the message type + * @return the cache type for the message type + */ + public AnthropicCacheType getCacheTypeForMessageType(MessageType messageType) { + return this.messageTypeCacheTypes.getOrDefault(messageType, AnthropicCacheType.EPHEMERAL); + } + + /** + * Get the minimum block length for a given message type. If the message type is + * not configured, return the default minimum block length. 
+ * @param messageType + * @return the minimum block length for the message type + */ + public Integer getMinBlockLengthForMessageType(MessageType messageType) { + return this.messageTypeMinBlockLength.getOrDefault(messageType, this.minCacheBlockLength); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CacheControlConfiguration that)) { + return false; + } + return this.maxCacheBlocks == that.maxCacheBlocks && this.minCacheBlockLength == that.minCacheBlockLength + && Objects.equals(this.cachableMessageTypes, that.cachableMessageTypes) + && Objects.equals(this.messageTypeCacheTypes, that.messageTypeCacheTypes) + && Objects.equals(this.messageTypeMinBlockLength, that.messageTypeMinBlockLength); + } + + @Override + public int hashCode() { + return Objects.hash(this.maxCacheBlocks, this.minCacheBlockLength, this.cachableMessageTypes, + this.messageTypeCacheTypes, this.messageTypeMinBlockLength); + } + + @Override + public String toString() { + return "CacheControlConfiguration{" + "maxCacheBlocks=" + this.maxCacheBlocks + ", minCacheBlockLength=" + + this.minCacheBlockLength + ", cachableMessageTypes=" + this.cachableMessageTypes + + ", messageTypeCacheTypes=" + this.messageTypeCacheTypes + ", messageTypeMinBlockLength=" + + this.messageTypeMinBlockLength + '}'; + } + + public static class Builder { + + private final CacheControlConfiguration configuration = new CacheControlConfiguration(); + + public Builder() { + } + + public Builder maxCacheBlocks(int maxCacheBlocks) { + this.configuration.setMaxCacheBlocks(maxCacheBlocks); + return this; + } + + public Builder minCacheBlockLength(int minCacheBlockLength) { + this.configuration.setMinCacheBlockLength(minCacheBlockLength); + return this; + } + + public Builder cachableMessageTypes(Set cachableMessageTypes) { + this.configuration.setCachableMessageTypes(cachableMessageTypes); + return this; + } + + public Builder messageTypeCacheTypes(Map 
messageTypeCacheTypes) { + this.configuration.setMessageTypeCacheTypes(messageTypeCacheTypes); + return this; + } + + public Builder addCachableMessageType(MessageType messageType) { + if (this.configuration.getCachableMessageTypes() == null) { + this.configuration.setCachableMessageTypes(new HashSet<>()); + } + this.configuration.getCachableMessageTypes().add(messageType); + return this; + } + + public Builder addMessageTypeCacheType(MessageType messageType, AnthropicCacheType cacheType) { + if (this.configuration.getMessageTypeCacheTypes() == null) { + this.configuration.setMessageTypeCacheTypes(new HashMap<>()); + } + this.configuration.getMessageTypeCacheTypes().put(messageType, cacheType); + return this; + } + + public Builder minBlockLengthForMessageType(MessageType messageType, Integer minBlockLength) { + if (this.configuration.messageTypeMinBlockLength == null) { + this.configuration.messageTypeMinBlockLength = new HashMap<>(); + } + this.configuration.messageTypeMinBlockLength.put(messageType, minBlockLength); + return this; + } + + public CacheControlConfiguration build() { + return this.configuration; + } + + } + } public static class Builder { @@ -389,6 +615,11 @@ public Builder httpHeaders(Map httpHeaders) { return this; } + public Builder cacheControlConfiguration(CacheControlConfiguration cacheControlConfiguration) { + this.options.cacheControlConfiguration = cacheControlConfiguration; + return this; + } + public AnthropicChatOptions build() { return this.options; } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index b573ff8a139..1f852ee293f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -35,6 +35,7 @@ import 
reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; @@ -65,6 +66,8 @@ * @author Jonghoon Park * @author Claudio Silva Junior * @author Filip Hrisafov + * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ public final class AnthropicApi { @@ -171,14 +174,14 @@ public ResponseEntity chatCompletionEntity(ChatCompletio // @formatter:off return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletionResponse.class); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletionResponse.class); // @formatter:on } @@ -212,44 +215,44 @@ public Flux chatCompletionStream(ChatCompletionRequest c // @formatter:off return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) // @formatter:off - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) - .filter(event -> event.type() != EventType.PING) - // Detect if the chunk is part of a streaming function call. - .map(event -> { - logger.debug("Received event: {}", event); - - if (this.streamHelper.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }) - // Group all chunks belonging to the same function call. 
- .windowUntil(event -> { - if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - .concatMapIterable(window -> { - Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - this.streamHelper::mergeToolUseEvents); - return List.of(monoChunk); - }) - .flatMap(mono -> mono) - .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) - .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:off + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) + .filter(event -> event.type() != EventType.PING) + // Detect if the chunk is part of a streaming function call. + .map(event -> { + logger.debug("Received event: {}", event); + + if (this.streamHelper.isToolUseStart(event)) { + isInsideTool.set(true); + } + return event; + }) + // Group all chunks belonging to the same function call. + .windowUntil(event -> { + if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. 
+ .concatMapIterable(window -> { + Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), + this.streamHelper::mergeToolUseEvents); + return List.of(monoChunk); + }) + .flatMap(mono -> mono) + .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) + .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); } private void addDefaultHeadersIfMissing(HttpHeaders headers) { @@ -356,7 +359,7 @@ public enum Role { // @formatter:off /** * The user role. - */ + */ @JsonProperty("user") USER, @@ -512,28 +515,30 @@ public interface StreamEvent { @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( // @formatter:off - @JsonProperty("model") String model, - @JsonProperty("messages") List messages, - @JsonProperty("system") String system, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("metadata") Metadata metadata, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("stream") Boolean stream, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("top_k") Integer topK, - @JsonProperty("tools") List tools, - @JsonProperty("thinking") ThinkingConfig thinking) { + @JsonProperty("model") String model, + @JsonProperty("messages") List messages, + @JsonProperty("system") List system, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("metadata") Metadata metadata, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("top_k") Integer topK, + @JsonProperty("tools") List tools, + @JsonProperty("thinking") ThinkingConfig thinking) { // @formatter:on public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, 
null); + this(model, messages, List.of(new ContentBlock(system)), maxTokens, null, null, stream, temperature, null, + null, null, null); } public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, List stopSequences, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); + this(model, messages, List.of(new ContentBlock(system)), maxTokens, null, stopSequences, stream, + temperature, null, null, null, null); } public static ChatCompletionRequestBuilder builder() { @@ -557,6 +562,22 @@ public record Metadata(@JsonProperty("user_id") String userId) { } + @JsonInclude(Include.NON_NULL) + public record System(@JsonProperty("user_id") String userId) { + + } + + /** + * @param type the cache type supported by Anthropic prompt caching, e.g. {@code "ephemeral"}; see the Anthropic prompt caching documentation. + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type, String ttl) { + public CacheControl(String type) { + this(type, null); + } + } + /** * Configuration for the model's thinking mode. 
* @@ -577,7 +598,7 @@ public static final class ChatCompletionRequestBuilder { private List messages; - private String system; + private List system; private Integer maxTokens; @@ -631,6 +652,10 @@ public ChatCompletionRequestBuilder messages(List messages) { } public ChatCompletionRequestBuilder system(String system) { + return this.system(List.of(new ContentBlock(system))); + } + + public ChatCompletionRequestBuilder system(List system) { this.system = system; return this; } @@ -717,8 +742,8 @@ public ChatCompletionRequest build() { @JsonInclude(Include.NON_NULL) public record AnthropicMessage( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("role") Role role) { + @JsonProperty("content") List content, + @JsonProperty("role") Role role) { // @formatter:on } @@ -742,29 +767,32 @@ public record AnthropicMessage( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlock( // @formatter:off - @JsonProperty("type") Type type, - @JsonProperty("source") Source source, - @JsonProperty("text") String text, + @JsonProperty("type") Type type, + @JsonProperty("source") Source source, + @JsonProperty("text") String text, - // applicable only for streaming responses. - @JsonProperty("index") Integer index, + // applicable only for streaming responses. 
+ @JsonProperty("index") Integer index, - // tool_use response only - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input, + // tool_use response only + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input, - // tool_result response only - @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content, + // tool_result response only + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("content") String content, - // Thinking only - @JsonProperty("signature") String signature, - @JsonProperty("thinking") String thinking, + // Thinking only + @JsonProperty("signature") String signature, + @JsonProperty("thinking") String thinking, - // Redacted Thinking only - @JsonProperty("data") String data - ) { + // Redacted Thinking only + @JsonProperty("data") String data, + + // cache object + @JsonProperty("cache_control") CacheControl cacheControl + ) { // @formatter:on /** @@ -782,7 +810,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -790,7 +818,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -798,7 +826,11 @@ public ContentBlock(Source source) { * @param text The text of the content. 
*/ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, cache); } // Tool result @@ -809,7 +841,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content, null, null, null); + this(type, null, null, null, null, null, null, toolUseId, content, null, null, null, null); } /** @@ -820,7 +852,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -832,7 +864,11 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. 
*/ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null, null, null, null); + this(type, null, null, null, id, name, input, null, null, null, null, null, null); + } + + public static ContentBlockBuilder from(ContentBlock contentBlock) { + return new ContentBlockBuilder(contentBlock); } /** @@ -938,10 +974,10 @@ public String getValue() { @JsonInclude(Include.NON_NULL) public record Source( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("media_type") String mediaType, - @JsonProperty("data") String data, - @JsonProperty("url") String url) { + @JsonProperty("type") String type, + @JsonProperty("media_type") String mediaType, + @JsonProperty("data") String data, + @JsonProperty("url") String url) { // @formatter:on /** @@ -959,6 +995,122 @@ public Source(String url) { } + public static class ContentBlockBuilder { + + private Type type; + + private Source source; + + private String text; + + private Integer index; + + private String id; + + private String name; + + private Map input; + + private String toolUseId; + + private String content; + + private String signature; + + private String thinking; + + private String data; + + private CacheControl cacheControl; + + public ContentBlockBuilder(ContentBlock contentBlock) { + this.type = contentBlock.type; + this.source = contentBlock.source; + this.text = contentBlock.text; + this.index = contentBlock.index; + this.id = contentBlock.id; + this.name = contentBlock.name; + this.input = contentBlock.input; + this.toolUseId = contentBlock.toolUseId; + this.content = contentBlock.content; + this.signature = contentBlock.signature; + this.thinking = contentBlock.thinking; + this.data = contentBlock.data; + this.cacheControl = contentBlock.cacheControl; + } + + public ContentBlockBuilder type(Type type) { + this.type = type; + return this; + } + + public ContentBlockBuilder source(Source source) { + this.source = source; + return 
this; + } + + public ContentBlockBuilder text(String text) { + this.text = text; + return this; + } + + public ContentBlockBuilder index(Integer index) { + this.index = index; + return this; + } + + public ContentBlockBuilder id(String id) { + this.id = id; + return this; + } + + public ContentBlockBuilder name(String name) { + this.name = name; + return this; + } + + public ContentBlockBuilder input(Map input) { + this.input = input; + return this; + } + + public ContentBlockBuilder toolUseId(String toolUseId) { + this.toolUseId = toolUseId; + return this; + } + + public ContentBlockBuilder content(String content) { + this.content = content; + return this; + } + + public ContentBlockBuilder signature(String signature) { + this.signature = signature; + return this; + } + + public ContentBlockBuilder thinking(String thinking) { + this.thinking = thinking; + return this; + } + + public ContentBlockBuilder data(String data) { + this.data = data; + return this; + } + + public ContentBlockBuilder cacheControl(CacheControl cacheControl) { + this.cacheControl = cacheControl; + return this; + } + + public ContentBlock build() { + return new ContentBlock(this.type, this.source, this.text, this.index, this.id, this.name, this.input, + this.toolUseId, this.content, this.signature, this.thinking, this.data, this.cacheControl); + } + + } + } /////////////////////////////////////// @@ -975,9 +1127,9 @@ public Source(String url) { @JsonInclude(Include.NON_NULL) public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("input_schema") Map inputSchema) { + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("input_schema") Map inputSchema) { // @formatter:on } @@ -1002,14 +1154,14 @@ public record Tool( @JsonIgnoreProperties(ignoreUnknown = true) public record ChatCompletionResponse( // @formatter:off - @JsonProperty("id") String id, - 
@JsonProperty("type") String type, - @JsonProperty("role") Role role, - @JsonProperty("content") List content, - @JsonProperty("model") String model, - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence, - @JsonProperty("usage") Usage usage) { + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("role") Role role, + @JsonProperty("content") List content, + @JsonProperty("model") String model, + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence, + @JsonProperty("usage") Usage usage) { // @formatter:on } @@ -1025,17 +1177,19 @@ public record ChatCompletionResponse( @JsonIgnoreProperties(ignoreUnknown = true) public record Usage( // @formatter:off - @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } - /// ECB STOP + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with * list of aggregated ContentBlockToolUse. - */ + */ public static class ToolUseAggregationEvent implements StreamEvent { private Integer index; @@ -1054,17 +1208,17 @@ public EventType type() { } /** - * Get tool content blocks. - * @return The tool content blocks. - */ + * Get tool content blocks. + * @return The tool content blocks. + */ public List getToolContentBlocks() { return this.toolContentBlocks; } /** - * Check if the event is empty. - * @return True if the event is empty, false otherwise. - */ + * Check if the event is empty. + * @return True if the event is empty, false otherwise. 
+ */ public boolean isEmpty() { return (this.index == null || this.id == null || this.name == null); } @@ -1102,30 +1256,30 @@ void squashIntoContentBlock() { @Override public String toString() { return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson=" - + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; + + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; } } - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// - // MESSAGE START EVENT + // MESSAGE START EVENT /** * Content block start event. * @param type The event type. * @param index The index of the content block. * @param contentBlock The content block body. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStartEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1139,31 +1293,31 @@ public interface ContentBlockBody { } /** - * Tool use content block. - * @param type The content block type. - * @param id The tool use id. - * @param name The tool use name. - * @param input The tool use input. - */ + * Tool use content block. + * @param type The content block type. + * @param id The tool use id. + * @param name The tool use name. + * @param input The tool use input. 
+ */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockToolUse( - @JsonProperty("type") String type, - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input) implements ContentBlockBody { } /** - * Text content block. - * @param type The content block type. - * @param text The text content. - */ + * Text content block. + * @param type The content block type. + * @param text The text content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockBody { } /** @@ -1173,9 +1327,9 @@ public record ContentBlockText( */ @JsonInclude(Include.NON_NULL) public record ContentBlockThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking, - @JsonProperty("signature") String signature) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking, + @JsonProperty("signature") String signature) implements ContentBlockBody { } } // @formatter:on @@ -1193,9 +1347,9 @@ public record ContentBlockThinking( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = 
JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1212,24 +1366,24 @@ public interface ContentBlockDeltaBody { * Text content block delta. * @param type The content block type. * @param text The text content. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockDeltaBody { } /** - * JSON content block delta. - * @param type The content block type. - * @param partialJson The partial JSON content. - */ + * JSON content block delta. + * @param type The content block type. + * @param partialJson The partial JSON content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaJson( - @JsonProperty("type") String type, - @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { } /** @@ -1240,8 +1394,8 @@ public record ContentBlockDeltaJson( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { } /** @@ -1252,8 +1406,8 @@ public record ContentBlockDeltaThinking( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaSignature( - @JsonProperty("type") String type, - @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("signature") String signature) 
implements ContentBlockDeltaBody { } } // @formatter:on @@ -1270,8 +1424,8 @@ public record ContentBlockDeltaSignature( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStopEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index) implements StreamEvent { } // @formatter:on @@ -1284,8 +1438,8 @@ public record ContentBlockStopEvent( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStartEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { } // @formatter:on @@ -1300,29 +1454,29 @@ public record MessageStartEvent(// @formatter:off @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("delta") MessageDelta delta, - @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("delta") MessageDelta delta, + @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { /** - * @param stopReason The stop reason. - * @param stopSequence The stop sequence. - */ + * @param stopReason The stop reason. + * @param stopSequence The stop sequence. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDelta( - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence) { + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence) { } /** * Message delta usage. * @param outputTokens The output tokens. 
- */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaUsage( - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens) { } } // @formatter:on @@ -1336,7 +1490,7 @@ public record MessageDeltaUsage( @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStopEvent( //@formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on @@ -1353,19 +1507,19 @@ public record MessageStopEvent( @JsonIgnoreProperties(ignoreUnknown = true) public record ErrorEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("error") Error error) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("error") Error error) implements StreamEvent { /** * Error body. * @param type The error type. * @param message The error message. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record Error( - @JsonProperty("type") String type, - @JsonProperty("message") String message) { + @JsonProperty("type") String type, + @JsonProperty("message") String message) { } } // @formatter:on @@ -1382,7 +1536,7 @@ public record Error( @JsonIgnoreProperties(ignoreUnknown = true) public record PingEvent( // @formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 00000000000..74ced8490be --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,63 @@ +/* + * Copyright 2025-2025 the 
original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.api; + +import java.util.function.Supplier; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +/** + * Cache types supported by Anthropic's prompt caching feature. + * + *

+ * Prompt caching allows reusing frequently used prompts to reduce costs and improve + * response times for repeated interactions. + * + * @see Anthropic Prompt + * Caching + * @author Claudio Silva Junior + * @author Soby Chacko + * @author Austin Dase + */ +public enum AnthropicCacheType { + + /** + * Ephemeral cache with 5-minute lifetime, refreshed on each use. + */ + EPHEMERAL(() -> new CacheControl("ephemeral")), + + /** + * Ephemeral cache with 1-hour lifetime, refreshed on each use. + */ + EPHEMERAL_1H(() -> new CacheControl("ephemeral", "1h")); + + private final Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + /** + * Returns a new CacheControl instance for this cache type. + * @return a CacheControl instance configured for this cache type + */ + public CacheControl cacheControl() { + return this.value.get(); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index 673685e6d13..ca519a11d0e 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -55,6 +55,8 @@ * @author Christian Tzolov * @author Jihoon Kim * @author Alexandros Pappas + * @author Claudio Silva Junior + * @author Soby Chacko * @since 1.0.0 */ public class StreamHelper { @@ -159,7 +161,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) { } else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) { ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null, - null, null, null, null, null, thinkingBlock.thinking(), null); + null, null, null, null, null, thinkingBlock.thinking(), null, null); 
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -176,12 +178,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) { } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) { ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, null, thinking.thinking(), null); + null, null, null, null, null, null, thinking.thinking(), null, null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) { ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, sig.signature(), null, null); + null, null, null, null, null, sig.signature(), null, null, null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -205,7 +207,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { if (messageDeltaEvent.usage() != null) { Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + messageDeltaEvent.usage().outputTokens(), + contentBlockReference.get().usage.cacheCreationInputTokens(), + contentBlockReference.get().usage.cacheReadInputTokens()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 6570d5ee6a6..f8088270f44 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -406,7 +406,7 @@ void thinkingTest() { .temperature(1.0) // 
temperature should be set to 1 when thinking is enabled .maxTokens(8192) .thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && < - // max_tokens + // max_tokens .build(); ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); @@ -438,7 +438,7 @@ void thinkingWithStreamingTest() { .temperature(1.0) // Temperature should be set to 1 when thinking is enabled .maxTokens(8192) .thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && < - // max_tokens + // max_tokens .build(); Flux responseFlux = this.streamingChatModel @@ -491,6 +491,59 @@ void testToolUseContentBlock() { } } + @Test + void chatWithPromptCacheViaOptions() { + String userMessageText = "It could be eitherr a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". 
" + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + // Repeat content to meet minimum token requirements for caching (1024+ tokens) + String largeContent = userMessageText.repeat(20); + + // First request - should create cache + ChatResponse firstResponse = this.chatModel.call(new Prompt(List.of(new UserMessage(largeContent)), + AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .maxTokens(100) + .temperature(0.8) + .build())); + + // Access native Anthropic usage data + AnthropicApi.Usage firstUsage = (AnthropicApi.Usage) firstResponse.getMetadata().getUsage().getNativeUsage(); + + // Verify first request created cache + assertThat(firstUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(firstUsage.cacheReadInputTokens()).isEqualTo(0); + + // Second request with identical content - should read from cache + ChatResponse secondResponse = this.chatModel.call(new Prompt(List.of(new UserMessage(largeContent)), + AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .maxTokens(100) + .temperature(0.8) + .build())); + + // Access native Anthropic usage data + AnthropicApi.Usage secondUsage = (AnthropicApi.Usage) secondResponse.getMetadata().getUsage().getNativeUsage(); + + // Verify second request used cache + assertThat(secondUsage.cacheCreationInputTokens()).isEqualTo(0); + assertThat(secondUsage.cacheReadInputTokens()).isGreaterThan(0); + + // Both responses should be valid + assertThat(firstResponse.getResult().getOutput().getText()).isNotBlank(); + assertThat(secondResponse.getResult().getOutput().getText()).isNotBlank(); + + logger.info("First request - Cache creation: {}, Cache read: {}", firstUsage.cacheCreationInputTokens(), + firstUsage.cacheReadInputTokens()); + 
logger.info("Second request - Cache creation: {}, Cache read: {}", secondUsage.cacheCreationInputTokens(), + secondUsage.cacheReadInputTokens()); + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java index d9470070e95..cb6b1a7bfa6 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -22,6 +22,9 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.anthropic.AnthropicChatOptions.CacheControlConfiguration; +import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; import static org.assertj.core.api.Assertions.assertThat; @@ -30,6 +33,8 @@ * Tests for {@link AnthropicChatOptions}. 
* * @author Alexandros Pappas + * @author Soby Chacko + * @author Austin Dase */ class AnthropicChatOptionsTests { @@ -43,10 +48,14 @@ void testBuilderWithAllFields() { .topP(0.8) .topK(50) .metadata(new Metadata("userId_123")) + .cacheControlConfiguration(CacheControlConfiguration.DEFAULT) .build(); - assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") - .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + assertThat(options) + .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", + "cacheControlConfiguration") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), + CacheControlConfiguration.DEFAULT); } @Test @@ -60,6 +69,7 @@ void testCopy() { .topK(50) .metadata(new Metadata("userId_123")) .toolContext(Map.of("key1", "value1")) + .cacheControlConfiguration(CacheControlConfiguration.builder().minCacheBlockLength(100).build()) .build(); AnthropicChatOptions copied = original.copy(); @@ -68,6 +78,7 @@ void testCopy() { // Ensure deep copy assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); } @Test @@ -80,6 +91,7 @@ void testSetters() { options.setTopP(0.8); options.setStopSequences(List.of("stop1", "stop2")); options.setMetadata(new Metadata("userId_123")); + options.setCacheControlConfiguration(CacheControlConfiguration.DEFAULT); assertThat(options.getModel()).isEqualTo("test-model"); assertThat(options.getMaxTokens()).isEqualTo(100); @@ -88,6 +100,7 @@ void testSetters() { assertThat(options.getTopP()).isEqualTo(0.8); assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); assertThat(options.getMetadata()).isEqualTo(new 
Metadata("userId_123")); + assertThat(options.getCacheControlConfiguration()).isEqualTo(CacheControlConfiguration.DEFAULT); } @Test @@ -100,6 +113,7 @@ void testDefaultValues() { assertThat(options.getTopP()).isNull(); assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); + assertThat(options.getCacheControlConfiguration()).isNull(); } @Test @@ -133,6 +147,7 @@ void testCopyWithEmptyOptions() { assertThat(copiedOptions.getModel()).isNull(); assertThat(copiedOptions.getMaxTokens()).isNull(); assertThat(copiedOptions.getTemperature()).isNull(); + assertThat(copiedOptions.getCacheControlConfiguration()).isNull(); } @Test @@ -199,6 +214,8 @@ void testChainedBuilderMethods() { .stopSequences(List.of("stop")) .metadata(new Metadata("user_456")) .toolContext(Map.of("context", "value")) + .cacheControlConfiguration( + CacheControlConfiguration.builder().minCacheBlockLength(50).maxCacheBlocks(10).build()) .build(); // Verify all chained methods worked @@ -210,6 +227,9 @@ void testChainedBuilderMethods() { assertThat(options.getStopSequences()).containsExactly("stop"); assertThat(options.getMetadata()).isEqualTo(new Metadata("user_456")); assertThat(options.getToolContext()).containsEntry("context", "value"); + assertThat(options.getCacheControlConfiguration()).isNotNull(); + assertThat(options.getCacheControlConfiguration().getMinCacheBlockLength()).isEqualTo(50); + assertThat(options.getCacheControlConfiguration().getMaxCacheBlocks()).isEqualTo(10); } @Test @@ -224,6 +244,7 @@ void testSettersWithNullValues() { options.setStopSequences(null); options.setMetadata(null); options.setToolContext(null); + options.setCacheControlConfiguration(null); assertThat(options.getModel()).isNull(); assertThat(options.getMaxTokens()).isNull(); @@ -233,6 +254,7 @@ void testSettersWithNullValues() { assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); assertThat(options.getToolContext()).isNull(); + 
assertThat(options.getCacheControlConfiguration()).isNull(); } @Test @@ -299,6 +321,8 @@ void testCopyPreservesAllFields() { .topK(60) .metadata(new Metadata("comprehensive_test")) .toolContext(Map.of("key1", "value1", "key2", "value2")) + .cacheControlConfiguration( + CacheControlConfiguration.builder().minCacheBlockLength(200).maxCacheBlocks(5).build()) .build(); AnthropicChatOptions copied = original.copy(); @@ -312,6 +336,7 @@ void testCopyPreservesAllFields() { assertThat(copied.getTopK()).isEqualTo(original.getTopK()); assertThat(copied.getMetadata()).isEqualTo(original.getMetadata()); assertThat(copied.getToolContext()).isEqualTo(original.getToolContext()); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); // Ensure deep copy for collections assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); @@ -471,4 +496,203 @@ void testSetterOverwriteBehavior() { assertThat(options.getMaxTokens()).isEqualTo(10); } + @Test + void testCacheControlConfigurationBuilder() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); + + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(config) + .build(); + + assertThat(options.getCacheControlConfiguration()).isEqualTo(config); + // Default max cache blocks is 4 per configuration defaults + assertThat(options.getCacheControlConfiguration().getMaxCacheBlocks()).isEqualTo(4); + } + + @Test + void testCacheControlDefaultValue() { + AnthropicChatOptions options = new AnthropicChatOptions(); + assertThat(options.getCacheControlConfiguration()).isNull(); + } + + @Test + void testCacheControlConfigurationEqualsAndHashCode() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); + + AnthropicChatOptions options1 = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(config) + .build(); + + AnthropicChatOptions 
options2 = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(config) + .build(); + + AnthropicChatOptions options3 = AnthropicChatOptions.builder().model("test-model").build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + + assertThat(options1).isNotEqualTo(options3); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + + @Test + void testCacheControlConfigurationCopy() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); + + AnthropicChatOptions original = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(config) + .build(); + + AnthropicChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); + // copy() preserves the same configuration instance + assertThat(copied.getCacheControlConfiguration()).isSameAs(config); + } + + @Test + void testCacheControlConfigurationWithNullValue() { + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(null) + .build(); + + assertThat(options.getCacheControlConfiguration()).isNull(); + } + + @Test + void testBuilderWithAllFieldsIncludingCacheControlConfiguration() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); + + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .metadata(new Metadata("userId_123")) + .cacheControlConfiguration(config) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", + "cacheControlConfiguration") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 
0.7, 0.8, 50, new Metadata("userId_123"), + config); + } + + @Test + void testCacheControlConfigurationMutationDoesNotAffectOriginal() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); + + AnthropicChatOptions original = AnthropicChatOptions.builder() + .model("original-model") + .cacheControlConfiguration(config) + .build(); + + AnthropicChatOptions copy = original.copy(); + copy.setCacheControlConfiguration(null); + + // Original should remain unchanged + assertThat(original.getCacheControlConfiguration()).isEqualTo(config); + // Copy should have null cache control configuration + assertThat(copy.getCacheControlConfiguration()).isNull(); + } + + @Test + void testCacheControlConfigurationDefaults() { + CacheControlConfiguration defaults = new CacheControlConfiguration(); + + assertThat(defaults.getMaxCacheBlocks()).isEqualTo(4); + assertThat(defaults.getMinCacheBlockLength()).isEqualTo(2000); + assertThat(defaults.getCachableMessageTypes()).containsExactlyInAnyOrder(MessageType.SYSTEM, MessageType.USER, + MessageType.ASSISTANT, MessageType.TOOL); + assertThat(defaults.getMessageTypeCacheTypes()) + .containsEntry(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H) + .containsEntry(MessageType.USER, AnthropicCacheType.EPHEMERAL) + .containsEntry(MessageType.ASSISTANT, AnthropicCacheType.EPHEMERAL) + .containsEntry(MessageType.TOOL, AnthropicCacheType.EPHEMERAL); + + // Static DEFAULT matches a fresh instance by value + assertThat(CacheControlConfiguration.DEFAULT).isEqualTo(defaults); + } + + @Test + void testCacheTypeLookupDefaultAndOverride() { + // Start from empty mapping then add specific override + CacheControlConfiguration config = CacheControlConfiguration.builder() + .messageTypeCacheTypes(null) // force builder to initialize map on demand + .addMessageTypeCacheType(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H) + .build(); + + // Unmapped types default to EPHEMERAL + 
assertThat(config.getCacheTypeForMessageType(MessageType.USER)).isEqualTo(AnthropicCacheType.EPHEMERAL); + assertThat(config.getCacheTypeForMessageType(MessageType.ASSISTANT)).isEqualTo(AnthropicCacheType.EPHEMERAL); + assertThat(config.getCacheTypeForMessageType(MessageType.TOOL)).isEqualTo(AnthropicCacheType.EPHEMERAL); + + // Mapped type returns configured value + assertThat(config.getCacheTypeForMessageType(MessageType.SYSTEM)).isEqualTo(AnthropicCacheType.EPHEMERAL_1H); + } + + @Test + void testMinBlockLengthLookupDefaultAndOverride() { + CacheControlConfiguration config = CacheControlConfiguration.builder() + .minCacheBlockLength(3000) + .minBlockLengthForMessageType(MessageType.SYSTEM, 1500) + .build(); + + // Override applies for SYSTEM + assertThat(config.getMinBlockLengthForMessageType(MessageType.SYSTEM)).isEqualTo(1500); + // Others use global default + assertThat(config.getMinBlockLengthForMessageType(MessageType.USER)).isEqualTo(3000); + assertThat(config.getMinBlockLengthForMessageType(MessageType.ASSISTANT)).isEqualTo(3000); + assertThat(config.getMinBlockLengthForMessageType(MessageType.TOOL)).isEqualTo(3000); + } + + @Test + void testBuilderAddersInitializeNullCollections() { + CacheControlConfiguration config = CacheControlConfiguration.builder() + .cachableMessageTypes(null) + .addCachableMessageType(MessageType.USER) + .messageTypeCacheTypes(null) + .addMessageTypeCacheType(MessageType.USER, AnthropicCacheType.EPHEMERAL) + .minBlockLengthForMessageType(MessageType.USER, 1234) + .build(); + + assertThat(config.getCachableMessageTypes()).contains(MessageType.USER); + assertThat(config.getMessageTypeCacheTypes()).containsEntry(MessageType.USER, AnthropicCacheType.EPHEMERAL); + assertThat(config.getMinBlockLengthForMessageType(MessageType.USER)).isEqualTo(1234); + } + + @Test + void testCacheControlConfigurationEqualityAcrossInstances() { + CacheControlConfiguration c1 = CacheControlConfiguration.builder() + .maxCacheBlocks(2) + 
.minCacheBlockLength(1111) + .cachableMessageTypes(new java.util.HashSet<>(java.util.List.of(MessageType.USER, MessageType.SYSTEM))) + .messageTypeCacheTypes(java.util.Map.of(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H)) + .minBlockLengthForMessageType(MessageType.SYSTEM, 999) + .build(); + + CacheControlConfiguration c2 = CacheControlConfiguration.builder() + .maxCacheBlocks(2) + .minCacheBlockLength(1111) + .cachableMessageTypes(new java.util.HashSet<>(java.util.List.of(MessageType.SYSTEM, MessageType.USER))) // different + // order + .messageTypeCacheTypes( + new java.util.HashMap<>(java.util.Map.of(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H))) + .minBlockLengthForMessageType(MessageType.SYSTEM, 999) + .build(); + + assertThat(c1).isEqualTo(c2); + assertThat(c1.hashCode()).isEqualTo(c2.hashCode()); + } + } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index c78386fb7ce..d6800dbdb74 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -44,6 +44,8 @@ * @author Christian Tzolov * @author Jihoon Kim * @author Alexandros Pappas + * @author Claudio Silva Junior + * @author Soby Chacko */ @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") public class AnthropicApiIT { @@ -70,6 +72,39 @@ public class AnthropicApiIT { } """))); + @Test + void chatWithPromptCache() { + String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. 
This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, + false); + + // First request - creates cache + AnthropicApi.Usage createdCacheToken = this.anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + // Second request - reads from cache (same request) + AnthropicApi.Usage readCacheToken = this.anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index 2094ab4ee17..9d1163168fc 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -191,6 +191,187 @@ ChatResponse response = chatModel.call( TIP: In addition to the model specific 
https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java[AnthropicChatOptions] you can use a portable link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java[ChatOptions#builder()]. +== Prompt Caching + +Anthropic's prompt caching feature allows you to cache frequently used prompts to reduce costs and improve response times for repeated interactions. +When you cache a prompt, subsequent identical requests can reuse the cached content, significantly reducing the number of input tokens processed. + +[NOTE] +==== +*Supported Models* + +Prompt caching is currently supported on Claude Opus 4, Claude Sonnet 4, Claude Sonnet 3.7, Claude Sonnet 3.5, Claude Haiku 3.5, Claude Haiku 3, and Claude Opus 3. +==== + +=== Cache Types + +Spring AI supports Anthropic's cache types through the `AnthropicCacheType` enum: + +* `EPHEMERAL`: Standard ephemeral caching with a 5-minute lifetime that is refreshed each time the cached content is read +* `EPHEMERAL_1H`: Extended ephemeral caching with a 1-hour lifetime. Note that this cache type incurs higher costs compared to standard ephemeral caching. 
+ +=== Enabling Prompt Caching + +To enable prompt caching, use the `cacheControlConfiguration()` method in `AnthropicChatOptions.Builder`: + +==== Basic Usage + +[source,java] +---- +// Enable caching with the default cache control configuration +ChatResponse response = chatModel.call( + new Prompt( + List.of(new UserMessage("Large content to be cached...")), + AnthropicChatOptions.builder() + .model("claude-3-5-sonnet-latest").cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .build() + ) +); +---- + +==== Using ChatClient Fluent API + +[source,java] +---- +String response = ChatClient.create(chatModel) + .prompt() + .user("Analyze this large document: " + document) + .options(AnthropicChatOptions.builder() + .model("claude-3-5-sonnet-latest") + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .build()) + .call() + .content(); +---- + +=== Usage Example + +Here's a complete example demonstrating prompt caching with token usage tracking: + +[source,java] +---- +// Create content that will be reused multiple times +String largeContent = "Large document content that meets minimum token requirements..."; + +// First request - creates cache +ChatResponse firstResponse = chatModel.call( + new Prompt( + List.of(new UserMessage(largeContent)), + AnthropicChatOptions.builder() + .model("claude-3-haiku-20240307") + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .maxTokens(100) + .build() + ) +); + +// Access cache-related token usage +AnthropicApi.Usage firstUsage = (AnthropicApi.Usage) firstResponse.getMetadata() + .getUsage().getNativeUsage(); + +System.out.println("Cache creation tokens: " + firstUsage.cacheCreationInputTokens()); +System.out.println("Cache read tokens: " + firstUsage.cacheReadInputTokens()); + +// Second request with identical content - reads from cache +ChatResponse secondResponse = chatModel.call( + new Prompt( + List.of(new UserMessage(largeContent)), + 
AnthropicChatOptions.builder() + .model("claude-3-haiku-20240307") + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) + .maxTokens(100) + .build() + ) +); + +AnthropicApi.Usage secondUsage = (AnthropicApi.Usage) secondResponse.getMetadata() + .getUsage().getNativeUsage(); + +System.out.println("Cache creation tokens: " + secondUsage.cacheCreationInputTokens()); +System.out.println("Cache read tokens: " + secondUsage.cacheReadInputTokens()); +---- + +=== Token Usage Tracking + +The `Usage` record provides detailed information about cache-related token consumption. +To access Anthropic-specific cache metrics, use the `getNativeUsage()` method: + +[source,java] +---- +AnthropicApi.Usage usage = (AnthropicApi.Usage) response.getMetadata() + .getUsage().getNativeUsage(); +---- + +Cache-specific metrics include: + +* `cacheCreationInputTokens()`: Returns the number of tokens used when creating a cache entry +* `cacheReadInputTokens()`: Returns the number of tokens read from an existing cache entry + +When you first send a cached prompt: +- `cacheCreationInputTokens()` will be greater than 0 +- `cacheReadInputTokens()` will be 0 + +When you send the same cached prompt again: +- `cacheCreationInputTokens()` will be 0 +- `cacheReadInputTokens()` will be greater than 0 + +=== Best Practices + +1. **Cache Long Prompts**: Focus on caching prompts that meet the minimum token requirements (1024+ tokens for most models, 2048+ for Haiku models). + +2. **Reuse Identical Content**: Caching works best with exact matches of prompt content. +Even small changes will require a new cache entry. + +3. **Monitor Token Usage**: Use the enhanced usage statistics to track cache effectiveness and optimize your caching strategy. + +4. **Place Static Content First**: Position cached content (system instructions, context, examples) at the beginning of your prompt for optimal performance. + +5. 
**5-Minute Cache Lifetime**: Standard `EPHEMERAL` caches expire after 5 minutes of inactivity. +Each time cached content is accessed, the 5-minute timer resets. + +=== Low-level API Usage + +When using the low-level `AnthropicApi` directly, you can specify cache control through the `ContentBlock` constructor: + +[source,java] +---- +// Create content block with cache control +ContentBlock cachedContent = new ContentBlock( + "<large content to be cached>", + AnthropicCacheType.EPHEMERAL.cacheControl() +); + +AnthropicMessage message = new AnthropicMessage( + List.of(cachedContent), + Role.USER +); + +ChatCompletionRequest request = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), + List.of(message), + null, 100, 0.8, false +); + +ResponseEntity<ChatCompletionResponse> response = anthropicApi.chatCompletionEntity(request); + +// Access cache-related token usage +Usage usage = response.getBody().usage(); +System.out.println("Cache creation tokens: " + usage.cacheCreationInputTokens()); +System.out.println("Cache read tokens: " + usage.cacheReadInputTokens()); +---- + +=== Additional Configuration Options + +You can further customize cache behavior using the `AnthropicChatOptions.CacheControlConfiguration`. This configuration gives you more fine-grained control over the way the cache control blocks are applied. For example, in order to optimize for caching the largest possible content blocks, you can configure which `MessageType` should attempt to use cache control and for each `MessageType` what type of cache control to use (e.g. `EPHEMERAL` or `EPHEMERAL_1H`). + +Because Anthropic only caches content that reaches a minimum token count (1024 tokens for most models, 2048 for Haiku models), `AnthropicChatOptions.CacheControlConfiguration` has a default minimum text content length of 2000 characters to avoid creating cache control blocks for small content that cannot be cached. This can be configured to a different value if needed. Note that the character count is an approximation and the actual token count may vary. 
+ +=== Implementation Details + +Cache control is configured through `AnthropicChatOptions` rather than individual messages. +This preserves compatibility when switching between different AI providers. +The cache control gets applied during request creation in `AnthropicChatModel`. + == Thinking Anthropic Claude models support a "thinking" feature that allows the model to show its reasoning process before providing a final answer. This feature enables more transparent and detailed problem-solving, particularly for complex questions that require step-by-step reasoning. @@ -359,7 +540,7 @@ Read more about xref:api/tools.adoc[Tool Calling]. == Multimodal -Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, pdf, images, data formats. +Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, pdf, images, data formats. === Images Currently, Anthropic Claude 3 supports the `base64` source type for `images`, and the `image/jpeg`, `image/png`, `image/gif`, and `image/webp` media types. @@ -537,4 +718,4 @@ Flux response = this.anthropicApi Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java[AnthropicApi.java]'s JavaDoc for further information. === Low-level API Examples -* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/chat/api/AnthropicApiIT.java[AnthropicApiIT.java] test provides some general examples how to use the lightweight library. +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/chat/api/AnthropicApiIT.java[AnthropicApiIT.java] test provides some general examples how to use the lightweight library. \ No newline at end of file