diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 0485e552584..80db41d766c 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; @@ -91,6 +92,7 @@ * @author Alexandros Pappas * @author Jonghoon Park * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ public class AnthropicChatModel implements ChatModel { @@ -481,12 +483,81 @@ private Map mergeHttpHeaders(Map runtimeHttpHead return mergedHttpHeaders; } + private static ContentBlock cacheAwareContentBlock(String text, AtomicInteger usedCacheBlocks, + AnthropicChatOptions.CacheControlConfiguration cfg, MessageType type) { + return cacheAwareContentBlock(new ContentBlock(text), usedCacheBlocks, cfg, type); + } + + private static ContentBlock cacheAwareContentBlock(ContentBlock contentBlock, AtomicInteger usedCacheBlocks, + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) { + if (cacheControlConfiguration == null) { + return contentBlock; + } + + // Only proceed if this message is eligible for caching AND we can reserve a cache + // slot + if (isCacheEligible(contentBlock, cacheControlConfiguration, messageType) + && tryReserveCacheBlock(usedCacheBlocks, cacheControlConfiguration.getMaxCacheBlocks())) { + return ContentBlock.from(contentBlock) + .cacheControl(cacheControlConfiguration.getCacheTypeForMessageType(messageType).cacheControl()) + .build(); + } + + if (logger.isDebugEnabled()) { + final Integer minCacheBlockLength = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType); + logger.debug( + "Skipping cache for messageType={}, used={}/{}; textLength={}, contentLength={}, minLength={}, cachableTypes={}", + messageType, usedCacheBlocks.get(), cacheControlConfiguration.getMaxCacheBlocks(), + safeLength(contentBlock.text()), safeLength(contentBlock.content()), minCacheBlockLength, + cacheControlConfiguration.getCachableMessageTypes()); + } + + return contentBlock; + } + + private static int safeLength(String s) { + return (s == null) ? 0 : s.length(); + } + + private static boolean isCacheEligible(ContentBlock block, + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) { + if (!cacheControlConfiguration.getCachableMessageTypes().contains(messageType)) { + return false; + } + + final int minCacheBlockLength = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType); + + return isNullOrGreaterThanLength(block.text(), minCacheBlockLength) + && isNullOrGreaterThanLength(block.content(), minCacheBlockLength); + } + + private static boolean isNullOrGreaterThanLength(String s, int min) { + return s == null || s.length() >= min; + } + + /** + * Attempts to increment the counter only if we're still under the max. Returns true + * if we successfully reserved a slot. + */ + private static boolean tryReserveCacheBlock(AtomicInteger used, int max) { + int prev = used.getAndUpdate(v -> (v < max) ? (v + 1) : v); + return prev < max; + } + ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { - // Get cache control from options AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions(); - AnthropicApi.ChatCompletionRequest.CacheControl cacheControl = (requestOptions != null) - ? requestOptions.getCacheControl() : null; + AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration = (requestOptions != null) + ? requestOptions.getCacheControlConfiguration() : null; + + AtomicInteger usedCacheBlocks = new AtomicInteger(); + + List systemPrompt = prompt.getInstructions() + .stream() + .filter(m -> m.getMessageType() == MessageType.SYSTEM) + .map(m -> cacheAwareContentBlock(m.getText(), usedCacheBlocks, cacheControlConfiguration, + MessageType.SYSTEM)) + .toList(); List userMessages = prompt.getInstructions() .stream() @@ -494,21 +565,18 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { .map(message -> { if (message.getMessageType() == MessageType.USER) { List contents = new ArrayList<>(); - - // Apply cache control if enabled for user messages - if (cacheControl != null) { - contents.add(new ContentBlock(message.getText(), cacheControl)); - } - else { - contents.add(new ContentBlock(message.getText())); - } + contents.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks, cacheControlConfiguration, + MessageType.USER)); if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { Type contentBlockType = getContentBlockTypeByMedia(media); var source = getSourceByMedia(media); return new ContentBlock(contentBlockType, source); - }).toList(); + }) + .map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.USER)) + .toList(); contents.addAll(mediaContent); } } @@ -518,12 +586,15 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { AssistantMessage assistantMessage = (AssistantMessage) message; List contentBlocks = new ArrayList<>(); if (StringUtils.hasText(message.getText())) { - contentBlocks.add(new ContentBlock(message.getText())); + contentBlocks.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks, + cacheControlConfiguration, MessageType.ASSISTANT)); } if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(), - ModelOptionsUtils.jsonToMap(toolCall.arguments()))); + ContentBlock contentBlock = new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(), + ModelOptionsUtils.jsonToMap(toolCall.arguments())); + contentBlocks.add(cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.ASSISTANT)); } } return new AnthropicMessage(contentBlocks, Role.ASSISTANT); @@ -533,6 +604,8 @@ else if (message.getMessageType() == MessageType.TOOL) { .stream() .map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(), toolResponse.responseData())) + .map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks, + cacheControlConfiguration, MessageType.TOOL)) .toList(); return new AnthropicMessage(toolResponses, Role.USER); } @@ -542,14 +615,14 @@ else if (message.getMessageType() == MessageType.TOOL) { }) .toList(); - String systemPrompt = prompt.getInstructions() - .stream() - .filter(m -> m.getMessageType() == MessageType.SYSTEM) - .map(m -> m.getText()) - .collect(Collectors.joining(System.lineSeparator())); - - ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages, - systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream); + ChatCompletionRequest request = ChatCompletionRequest.builder() + .model(this.defaultOptions.getModel()) + .messages(userMessages) + .system(systemPrompt) + .maxTokens(this.defaultOptions.getMaxTokens()) + .temperature(this.defaultOptions.getTemperature()) + .stream(stream) + .build(); request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 16421eb04d0..f328050784d 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -32,6 +32,8 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -45,6 +47,7 @@ * @author Alexandros Pappas * @author Ilayaperumal Gopinathan * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -59,20 +62,10 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { private @JsonProperty("top_p") Double topP; private @JsonProperty("top_k") Integer topK; private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking; - /** - * Cache control for user messages. When set, enables caching for user messages. - * Uses the existing CacheControl record from AnthropicApi.ChatCompletionRequest. + * Cache control configuration options for the chat completion request. */ - private @JsonProperty("cache_control") ChatCompletionRequest.CacheControl cacheControl; - - public ChatCompletionRequest.CacheControl getCacheControl() { - return this.cacheControl; - } - - public void setCacheControl(ChatCompletionRequest.CacheControl cacheControl) { - this.cacheControl = cacheControl; - } + private @JsonProperty("cache_control") CacheControlConfiguration cacheControlConfiguration; /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat @@ -126,7 +119,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) - .cacheControl(fromOptions.getCacheControl()) + .cacheControlConfiguration(fromOptions.getCacheControlConfiguration()) .build(); } @@ -275,6 +268,15 @@ public void setHttpHeaders(Map httpHeaders) { this.httpHeaders = httpHeaders; } + @JsonIgnore + public CacheControlConfiguration getCacheControlConfiguration() { + return this.cacheControlConfiguration; + } + + public void setCacheControlConfiguration(CacheControlConfiguration cacheControlConfiguration) { + this.cacheControlConfiguration = cacheControlConfiguration; + } + @Override @SuppressWarnings("unchecked") public AnthropicChatOptions copy() { @@ -299,14 +301,221 @@ public boolean equals(Object o) { && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.httpHeaders, that.httpHeaders) - && Objects.equals(this.cacheControl, that.cacheControl); + && Objects.equals(this.cacheControlConfiguration, that.cacheControlConfiguration); } @Override public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders, this.cacheControl); + this.toolContext, this.httpHeaders, this.cacheControlConfiguration); + } + + public static class CacheControlConfiguration { + + /** + * The Anthropic API allows a maximum of 4 cache blocks. By default, we will + * attempt to cache up to 4 blocks. + */ + private static final int DEFAULT_MAX_CACHE_BLOCKS = 4; + + /** + * The minimum text or content length for a message to be considered for caching. + * By default, we will only cache messages with at least 2000 characters - + * counting characters as a lightweight way to roughly estimate tokens. This helps + * to avoid caching very short messages that are unlikely to benefit from caching. + * Note: The Anthropic API has a minimum cacheable message length of 1024 tokens. + * See + * here + */ + private static final int DEFAULT_MIN_CACHE_BLOCK_LENGTH = 2000; + + /** + * The default set of message types that are considered for caching. By default, + * we will cache system, user, assistant, and tool messages. + */ + private static final Set DEFAULT_CACHABLE_MESSAGE_TYPES = Set.of(MessageType.SYSTEM, + MessageType.USER, MessageType.ASSISTANT, MessageType.TOOL); + + /** + * The default cache types to use for each message type. By default, we will use + * EPHEMERAL_1H for system messages and EPHEMERAL for user, assistant, and tool + * messages. See here + */ + private static final Map DEFAULT_MESSAGE_TYPE_CACHE_TYPES = Map.of( + MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H, MessageType.USER, AnthropicCacheType.EPHEMERAL, + MessageType.ASSISTANT, AnthropicCacheType.EPHEMERAL, MessageType.TOOL, AnthropicCacheType.EPHEMERAL); + + private int maxCacheBlocks = DEFAULT_MAX_CACHE_BLOCKS; + + private int minCacheBlockLength = DEFAULT_MIN_CACHE_BLOCK_LENGTH; + + private Set cachableMessageTypes = new HashSet<>(DEFAULT_CACHABLE_MESSAGE_TYPES); + + private Map messageTypeCacheTypes = new HashMap<>( + DEFAULT_MESSAGE_TYPE_CACHE_TYPES); + + /** + * To enable specific minimum block lengths per message type, use this map to + * override the default {@link #minCacheBlockLength} for specific message types. + * For example, you might want to set a higher minimum length for system messages + * and a lower minimum length for user messages. + */ + private Map messageTypeMinBlockLength = new HashMap<>(); + + public static CacheControlConfiguration DEFAULT = new CacheControlConfiguration(); + + public static Builder builder() { + return new Builder(); + } + + public int getMaxCacheBlocks() { + return this.maxCacheBlocks; + } + + public void setMaxCacheBlocks(int maxCacheBlocks) { + this.maxCacheBlocks = maxCacheBlocks; + } + + public int getMinCacheBlockLength() { + return this.minCacheBlockLength; + } + + public void setMinCacheBlockLength(int minCacheBlockLength) { + this.minCacheBlockLength = minCacheBlockLength; + } + + public Set getCachableMessageTypes() { + return this.cachableMessageTypes; + } + + public void setCachableMessageTypes(Set cachableMessageTypes) { + this.cachableMessageTypes = cachableMessageTypes; + } + + public Map getMessageTypeCacheTypes() { + return this.messageTypeCacheTypes; + } + + public void setMessageTypeCacheTypes(Map messageTypeCacheTypes) { + this.messageTypeCacheTypes = messageTypeCacheTypes; + } + + public Map getMessageTypeMinBlockLength() { + return this.messageTypeMinBlockLength; + } + + public void setMessageTypeMinBlockLength(Map messageTypeMinBlockLength) { + this.messageTypeMinBlockLength = messageTypeMinBlockLength; + } + + /** + * Get the cache type for a given message type. If the message type is not + * configured, return EPHEMERAL as the default. + * @param messageType the message type + * @return the cache type for the message type + */ + public AnthropicCacheType getCacheTypeForMessageType(MessageType messageType) { + return this.messageTypeCacheTypes.getOrDefault(messageType, AnthropicCacheType.EPHEMERAL); + } + + /** + * Get the minimum block length for a given message type. If the message type is + * not configured, return the default minimum block length. + * @param messageType + * @return the minimum block length for the message type + */ + public Integer getMinBlockLengthForMessageType(MessageType messageType) { + return this.messageTypeMinBlockLength.getOrDefault(messageType, this.minCacheBlockLength); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CacheControlConfiguration that)) { + return false; + } + return this.maxCacheBlocks == that.maxCacheBlocks && this.minCacheBlockLength == that.minCacheBlockLength + && Objects.equals(this.cachableMessageTypes, that.cachableMessageTypes) + && Objects.equals(this.messageTypeCacheTypes, that.messageTypeCacheTypes) + && Objects.equals(this.messageTypeMinBlockLength, that.messageTypeMinBlockLength); + } + + @Override + public int hashCode() { + return Objects.hash(this.maxCacheBlocks, this.minCacheBlockLength, this.cachableMessageTypes, + this.messageTypeCacheTypes, this.messageTypeMinBlockLength); + } + + @Override + public String toString() { + return "CacheControlConfiguration{" + "maxCacheBlocks=" + this.maxCacheBlocks + ", minCacheBlockLength=" + + this.minCacheBlockLength + ", cachableMessageTypes=" + this.cachableMessageTypes + + ", messageTypeCacheTypes=" + this.messageTypeCacheTypes + ", messageTypeMinBlockLength=" + + this.messageTypeMinBlockLength + '}'; + } + + public static class Builder { + + private final CacheControlConfiguration configuration = new CacheControlConfiguration(); + + public Builder() { + } + + public Builder maxCacheBlocks(int maxCacheBlocks) { + this.configuration.setMaxCacheBlocks(maxCacheBlocks); + return this; + } + + public Builder minCacheBlockLength(int minCacheBlockLength) { + this.configuration.setMinCacheBlockLength(minCacheBlockLength); + return this; + } + + public Builder cachableMessageTypes(Set cachableMessageTypes) { + this.configuration.setCachableMessageTypes(cachableMessageTypes); + return this; + } + + public Builder messageTypeCacheTypes(Map messageTypeCacheTypes) { + this.configuration.setMessageTypeCacheTypes(messageTypeCacheTypes); + return this; + } + + public Builder addCachableMessageType(MessageType messageType) { + if (this.configuration.getCachableMessageTypes() == null) { + this.configuration.setCachableMessageTypes(new HashSet<>()); + } + this.configuration.getCachableMessageTypes().add(messageType); + return this; + } + + public Builder addMessageTypeCacheType(MessageType messageType, AnthropicCacheType cacheType) { + if (this.configuration.getMessageTypeCacheTypes() == null) { + this.configuration.setMessageTypeCacheTypes(new HashMap<>()); + } + this.configuration.getMessageTypeCacheTypes().put(messageType, cacheType); + return this; + } + + public Builder minBlockLengthForMessageType(MessageType messageType, Integer minBlockLength) { + if (this.configuration.messageTypeMinBlockLength == null) { + this.configuration.messageTypeMinBlockLength = new HashMap<>(); + } + this.configuration.messageTypeMinBlockLength.put(messageType, minBlockLength); + return this; + } + + public CacheControlConfiguration build() { + return this.configuration; + } + + } + } public static class Builder { @@ -406,11 +615,8 @@ public Builder httpHeaders(Map httpHeaders) { return this; } - /** - * Set cache control for user messages - */ - public Builder cacheControl(ChatCompletionRequest.CacheControl cacheControl) { - this.options.cacheControl = cacheControl; + public Builder cacheControlConfiguration(CacheControlConfiguration cacheControlConfiguration) { + this.options.cacheControlConfiguration = cacheControlConfiguration; return this; } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index e7bb4d0406f..1f852ee293f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -67,6 +67,7 @@ * @author Claudio Silva Junior * @author Filip Hrisafov * @author Soby Chacko + * @author Austin Dase * @since 1.0.0 */ public final class AnthropicApi { @@ -173,14 +174,14 @@ public ResponseEntity chatCompletionEntity(ChatCompletio // @formatter:off return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletionResponse.class); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletionResponse.class); // @formatter:on } @@ -214,44 +215,44 @@ public Flux chatCompletionStream(ChatCompletionRequest c // @formatter:off return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) // @formatter:off - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) - .filter(event -> event.type() != EventType.PING) - // Detect if the chunk is part of a streaming function call. - .map(event -> { - logger.debug("Received event: {}", event); - - if (this.streamHelper.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }) - // Group all chunks belonging to the same function call. - .windowUntil(event -> { - if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - .concatMapIterable(window -> { - Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - this.streamHelper::mergeToolUseEvents); - return List.of(monoChunk); - }) - .flatMap(mono -> mono) - .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) - .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:off + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) + .filter(event -> event.type() != EventType.PING) + // Detect if the chunk is part of a streaming function call. + .map(event -> { + logger.debug("Received event: {}", event); + + if (this.streamHelper.isToolUseStart(event)) { + isInsideTool.set(true); + } + return event; + }) + // Group all chunks belonging to the same function call. + .windowUntil(event -> { + if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + .concatMapIterable(window -> { + Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), + this.streamHelper::mergeToolUseEvents); + return List.of(monoChunk); + }) + .flatMap(mono -> mono) + .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) + .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); } private void addDefaultHeadersIfMissing(HttpHeaders headers) { @@ -358,7 +359,7 @@ public enum Role { // @formatter:off /** * The user role. - */ + */ @JsonProperty("user") USER, @@ -514,28 +515,30 @@ public interface StreamEvent { @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( // @formatter:off - @JsonProperty("model") String model, - @JsonProperty("messages") List messages, - @JsonProperty("system") String system, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("metadata") Metadata metadata, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("stream") Boolean stream, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("top_k") Integer topK, - @JsonProperty("tools") List tools, - @JsonProperty("thinking") ThinkingConfig thinking) { + @JsonProperty("model") String model, + @JsonProperty("messages") List messages, + @JsonProperty("system") List system, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("metadata") Metadata metadata, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("top_k") Integer topK, + @JsonProperty("tools") List tools, + @JsonProperty("thinking") ThinkingConfig thinking) { // @formatter:on public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null); + this(model, messages, List.of(new ContentBlock(system)), maxTokens, null, null, stream, temperature, null, + null, null, null); } public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, List stopSequences, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); + this(model, messages, List.of(new ContentBlock(system)), maxTokens, null, stopSequences, stream, + temperature, null, null, null, null); } public static ChatCompletionRequestBuilder builder() { @@ -559,12 +562,20 @@ public record Metadata(@JsonProperty("user_id") String userId) { } + @JsonInclude(Include.NON_NULL) + public record System(@JsonProperty("user_id") String userId) { + + } + /** * @param type is the cache type supported by anthropic. Doc */ @JsonInclude(Include.NON_NULL) - public record CacheControl(String type) { + public record CacheControl(String type, String ttl) { + public CacheControl(String type) { + this(type, null); + } } /** @@ -587,7 +598,7 @@ public static final class ChatCompletionRequestBuilder { private List messages; - private String system; + private List system; private Integer maxTokens; @@ -641,6 +652,10 @@ public ChatCompletionRequestBuilder messages(List messages) { } public ChatCompletionRequestBuilder system(String system) { + return this.system(List.of(new ContentBlock(system))); + } + + public ChatCompletionRequestBuilder system(List system) { this.system = system; return this; } @@ -727,8 +742,8 @@ public ChatCompletionRequest build() { @JsonInclude(Include.NON_NULL) public record AnthropicMessage( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("role") Role role) { + @JsonProperty("content") List content, + @JsonProperty("role") Role role) { // @formatter:on } @@ -752,31 +767,31 @@ public record AnthropicMessage( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlock( // @formatter:off - @JsonProperty("type") Type type, - @JsonProperty("source") Source source, - @JsonProperty("text") String text, + @JsonProperty("type") Type type, + @JsonProperty("source") Source source, + @JsonProperty("text") String text, - // applicable only for streaming responses. - @JsonProperty("index") Integer index, + // applicable only for streaming responses. + @JsonProperty("index") Integer index, - // tool_use response only - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input, + // tool_use response only + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input, - // tool_result response only - @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content, + // tool_result response only + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("content") String content, - // Thinking only - @JsonProperty("signature") String signature, - @JsonProperty("thinking") String thinking, + // Thinking only + @JsonProperty("signature") String signature, + @JsonProperty("thinking") String thinking, - // Redacted Thinking only - @JsonProperty("data") String data, + // Redacted Thinking only + @JsonProperty("data") String data, - // cache object - @JsonProperty("cache_control") CacheControl cacheControl + // cache object + @JsonProperty("cache_control") CacheControl cacheControl ) { // @formatter:on @@ -852,6 +867,10 @@ public ContentBlock(Type type, String id, String name, Map input this(type, null, null, null, id, name, input, null, null, null, null, null, null); } + public static ContentBlockBuilder from(ContentBlock contentBlock) { + return new ContentBlockBuilder(contentBlock); + } + /** * The ContentBlock type. */ @@ -955,10 +974,10 @@ public String getValue() { @JsonInclude(Include.NON_NULL) public record Source( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("media_type") String mediaType, - @JsonProperty("data") String data, - @JsonProperty("url") String url) { + @JsonProperty("type") String type, + @JsonProperty("media_type") String mediaType, + @JsonProperty("data") String data, + @JsonProperty("url") String url) { // @formatter:on /** @@ -976,6 +995,122 @@ public Source(String url) { } + public static class ContentBlockBuilder { + + private Type type; + + private Source source; + + private String text; + + private Integer index; + + private String id; + + private String name; + + private Map input; + + private String toolUseId; + + private String content; + + private String signature; + + private String thinking; + + private String data; + + private CacheControl cacheControl; + + public ContentBlockBuilder(ContentBlock contentBlock) { + this.type = contentBlock.type; + this.source = contentBlock.source; + this.text = contentBlock.text; + this.index = contentBlock.index; + this.id = contentBlock.id; + this.name = contentBlock.name; + this.input = contentBlock.input; + this.toolUseId = contentBlock.toolUseId; + this.content = contentBlock.content; + this.signature = contentBlock.signature; + this.thinking = contentBlock.thinking; + this.data = contentBlock.data; + this.cacheControl = contentBlock.cacheControl; + } + + public ContentBlockBuilder type(Type type) { + this.type = type; + return this; + } + + public ContentBlockBuilder source(Source source) { + this.source = source; + return this; + } + + public ContentBlockBuilder text(String text) { + this.text = text; + return this; + } + + public ContentBlockBuilder index(Integer index) { + this.index = index; + return this; + } + + public ContentBlockBuilder id(String id) { + this.id = id; + return this; + } + + public ContentBlockBuilder name(String name) { + this.name = name; + return this; + } + + public ContentBlockBuilder input(Map input) { + this.input = input; + return this; + } + + public ContentBlockBuilder toolUseId(String toolUseId) { + this.toolUseId = toolUseId; + return this; + } + + public ContentBlockBuilder content(String content) { + this.content = content; + return this; + } + + public ContentBlockBuilder signature(String signature) { + this.signature = signature; + return this; + } + + public ContentBlockBuilder thinking(String thinking) { + this.thinking = thinking; + return this; + } + + public ContentBlockBuilder data(String data) { + this.data = data; + return this; + } + + public ContentBlockBuilder cacheControl(CacheControl cacheControl) { + this.cacheControl = cacheControl; + return this; + } + + public ContentBlock build() { + return new ContentBlock(this.type, this.source, this.text, this.index, this.id, this.name, this.input, + this.toolUseId, this.content, this.signature, this.thinking, this.data, this.cacheControl); + } + + } + } /////////////////////////////////////// @@ -992,9 +1127,9 @@ public Source(String url) { @JsonInclude(Include.NON_NULL) public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("input_schema") Map inputSchema) { + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("input_schema") Map inputSchema) { // @formatter:on } @@ -1019,14 +1154,14 @@ public record Tool( @JsonIgnoreProperties(ignoreUnknown = true) public record ChatCompletionResponse( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("type") String type, - @JsonProperty("role") Role role, - @JsonProperty("content") List content, - @JsonProperty("model") String model, - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence, - @JsonProperty("usage") Usage usage) { + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("role") Role role, + @JsonProperty("content") List content, + @JsonProperty("model") String model, + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence, + @JsonProperty("usage") Usage usage) { // @formatter:on } @@ -1042,19 +1177,19 @@ public record ChatCompletionResponse( @JsonIgnoreProperties(ignoreUnknown = true) public record Usage( // @formatter:off - @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens, - @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, - @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } - /// ECB STOP + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with * list of aggregated ContentBlockToolUse. - */ + */ public static class ToolUseAggregationEvent implements StreamEvent { private Integer index; @@ -1073,17 +1208,17 @@ public EventType type() { } /** - * Get tool content blocks. - * @return The tool content blocks. - */ + * Get tool content blocks. + * @return The tool content blocks. + */ public List getToolContentBlocks() { return this.toolContentBlocks; } /** - * Check if the event is empty. - * @return True if the event is empty, false otherwise. - */ + * Check if the event is empty. + * @return True if the event is empty, false otherwise. + */ public boolean isEmpty() { return (this.index == null || this.id == null || this.name == null); } @@ -1121,30 +1256,30 @@ void squashIntoContentBlock() { @Override public String toString() { return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson=" - + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; + + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; } } - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// - // MESSAGE START EVENT + // MESSAGE START EVENT /** * Content block start event. * @param type The event type. * @param index The index of the content block. * @param contentBlock The content block body. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStartEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1158,31 +1293,31 @@ public interface ContentBlockBody { } /** - * Tool use content block. - * @param type The content block type. - * @param id The tool use id. - * @param name The tool use name. - * @param input The tool use input. - */ + * Tool use content block. + * @param type The content block type. + * @param id The tool use id. + * @param name The tool use name. + * @param input The tool use input. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockToolUse( - @JsonProperty("type") String type, - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input) implements ContentBlockBody { } /** - * Text content block. - * @param type The content block type. - * @param text The text content. - */ + * Text content block. + * @param type The content block type. + * @param text The text content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockBody { } /** @@ -1192,9 +1327,9 @@ public record ContentBlockText( */ @JsonInclude(Include.NON_NULL) public record ContentBlockThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking, - @JsonProperty("signature") String signature) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking, + @JsonProperty("signature") String signature) implements ContentBlockBody { } } // @formatter:on @@ -1212,9 +1347,9 @@ public record ContentBlockThinking( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1231,24 +1366,24 @@ public interface ContentBlockDeltaBody { * Text content block delta. * @param type The content block type. * @param text The text content. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockDeltaBody { } /** - * JSON content block delta. - * @param type The content block type. - * @param partialJson The partial JSON content. - */ + * JSON content block delta. + * @param type The content block type. + * @param partialJson The partial JSON content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaJson( - @JsonProperty("type") String type, - @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { } /** @@ -1259,8 +1394,8 @@ public record ContentBlockDeltaJson( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { } /** @@ -1271,8 +1406,8 @@ public record ContentBlockDeltaThinking( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaSignature( - @JsonProperty("type") String type, - @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { } } // @formatter:on @@ -1289,8 +1424,8 @@ public record ContentBlockDeltaSignature( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStopEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index) implements StreamEvent { } // @formatter:on @@ -1303,8 +1438,8 @@ public record ContentBlockStopEvent( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStartEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { } // @formatter:on @@ -1319,29 +1454,29 @@ public record MessageStartEvent(// @formatter:off @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("delta") MessageDelta delta, - @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("delta") MessageDelta delta, + @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { /** - * @param stopReason The stop reason. - * @param stopSequence The stop sequence. - */ + * @param stopReason The stop reason. + * @param stopSequence The stop sequence. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDelta( - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence) { + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence) { } /** * Message delta usage. * @param outputTokens The output tokens. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaUsage( - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens) { } } // @formatter:on @@ -1355,7 +1490,7 @@ public record MessageDeltaUsage( @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStopEvent( //@formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on @@ -1372,19 +1507,19 @@ public record MessageStopEvent( @JsonIgnoreProperties(ignoreUnknown = true) public record ErrorEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("error") Error error) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("error") Error error) implements StreamEvent { /** * Error body. * @param type The error type. * @param message The error message. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record Error( - @JsonProperty("type") String type, - @JsonProperty("message") String message) { + @JsonProperty("type") String type, + @JsonProperty("message") String message) { } } // @formatter:on @@ -1401,7 +1536,7 @@ public record Error( @JsonIgnoreProperties(ignoreUnknown = true) public record PingEvent( // @formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java index 0348670573a..74ced8490be 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -32,13 +32,19 @@ * Caching * @author Claudio Silva Junior * @author Soby Chacko + * @author Austin Dase */ public enum AnthropicCacheType { /** * Ephemeral cache with 5-minute lifetime, refreshed on each use. */ - EPHEMERAL(() -> new CacheControl("ephemeral")); + EPHEMERAL(() -> new CacheControl("ephemeral")), + + /** + * Ephemeral cache with 1-hour lifetime, refreshed on each use. + */ + EPHEMERAL_1H(() -> new CacheControl("ephemeral", "1h")); private final Supplier value; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index c522f75cf4b..f8088270f44 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -32,7 +32,6 @@ import reactor.core.publisher.Flux; import org.springframework.ai.anthropic.api.AnthropicApi; -import org.springframework.ai.anthropic.api.AnthropicCacheType; import org.springframework.ai.anthropic.api.tool.MockWeatherService; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; @@ -407,7 +406,7 @@ void thinkingTest() { .temperature(1.0) // temperature should be set to 1 when thinking is enabled .maxTokens(8192) .thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && < - // max_tokens + // max_tokens .build(); ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); @@ -439,7 +438,7 @@ void thinkingWithStreamingTest() { .temperature(1.0) // Temperature should be set to 1 when thinking is enabled .maxTokens(8192) .thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && < - // max_tokens + // max_tokens .build(); Flux responseFlux = this.streamingChatModel @@ -507,7 +506,7 @@ void chatWithPromptCacheViaOptions() { ChatResponse firstResponse = this.chatModel.call(new Prompt(List.of(new UserMessage(largeContent)), AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue()) - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .maxTokens(100) .temperature(0.8) .build())); @@ -523,7 +522,7 @@ void chatWithPromptCacheViaOptions() { ChatResponse secondResponse = this.chatModel.call(new Prompt(List.of(new UserMessage(largeContent)), AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue()) - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .maxTokens(100) .temperature(0.8) .build())); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java index 6cc4c689022..cb6b1a7bfa6 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -22,9 +22,10 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; -import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; +import org.springframework.ai.anthropic.AnthropicChatOptions.CacheControlConfiguration; import org.springframework.ai.anthropic.api.AnthropicCacheType; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; import static org.assertj.core.api.Assertions.assertThat; @@ -33,6 +34,7 @@ * * @author Alexandros Pappas * @author Soby Chacko + * @author Austin Dase */ class AnthropicChatOptionsTests { @@ -46,10 +48,14 @@ void testBuilderWithAllFields() { .topP(0.8) .topK(50) .metadata(new Metadata("userId_123")) + .cacheControlConfiguration(CacheControlConfiguration.DEFAULT) .build(); - assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") - .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + assertThat(options) + .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", + "cacheControlConfiguration") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), + CacheControlConfiguration.DEFAULT); } @Test @@ -63,6 +69,7 @@ void testCopy() { .topK(50) .metadata(new Metadata("userId_123")) .toolContext(Map.of("key1", "value1")) + .cacheControlConfiguration(CacheControlConfiguration.builder().minCacheBlockLength(100).build()) .build(); AnthropicChatOptions copied = original.copy(); @@ -71,6 +78,7 @@ void testCopy() { // Ensure deep copy assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); } @Test @@ -83,6 +91,7 @@ void testSetters() { options.setTopP(0.8); options.setStopSequences(List.of("stop1", "stop2")); options.setMetadata(new Metadata("userId_123")); + options.setCacheControlConfiguration(CacheControlConfiguration.DEFAULT); assertThat(options.getModel()).isEqualTo("test-model"); assertThat(options.getMaxTokens()).isEqualTo(100); @@ -91,6 +100,7 @@ void testSetters() { assertThat(options.getTopP()).isEqualTo(0.8); assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123")); + assertThat(options.getCacheControlConfiguration()).isEqualTo(CacheControlConfiguration.DEFAULT); } @Test @@ -103,6 +113,7 @@ void testDefaultValues() { assertThat(options.getTopP()).isNull(); assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); + assertThat(options.getCacheControlConfiguration()).isNull(); } @Test @@ -136,6 +147,7 @@ void testCopyWithEmptyOptions() { assertThat(copiedOptions.getModel()).isNull(); assertThat(copiedOptions.getMaxTokens()).isNull(); assertThat(copiedOptions.getTemperature()).isNull(); + assertThat(copiedOptions.getCacheControlConfiguration()).isNull(); } @Test @@ -202,6 +214,8 @@ void testChainedBuilderMethods() { .stopSequences(List.of("stop")) .metadata(new Metadata("user_456")) .toolContext(Map.of("context", "value")) + .cacheControlConfiguration( + CacheControlConfiguration.builder().minCacheBlockLength(50).maxCacheBlocks(10).build()) .build(); // Verify all chained methods worked @@ -213,6 +227,9 @@ void testChainedBuilderMethods() { assertThat(options.getStopSequences()).containsExactly("stop"); assertThat(options.getMetadata()).isEqualTo(new Metadata("user_456")); assertThat(options.getToolContext()).containsEntry("context", "value"); + assertThat(options.getCacheControlConfiguration()).isNotNull(); + assertThat(options.getCacheControlConfiguration().getMinCacheBlockLength()).isEqualTo(50); + assertThat(options.getCacheControlConfiguration().getMaxCacheBlocks()).isEqualTo(10); } @Test @@ -227,6 +244,7 @@ void testSettersWithNullValues() { options.setStopSequences(null); options.setMetadata(null); options.setToolContext(null); + options.setCacheControlConfiguration(null); assertThat(options.getModel()).isNull(); assertThat(options.getMaxTokens()).isNull(); @@ -236,6 +254,7 @@ void testSettersWithNullValues() { assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); assertThat(options.getToolContext()).isNull(); + assertThat(options.getCacheControlConfiguration()).isNull(); } @Test @@ -302,6 +321,8 @@ void testCopyPreservesAllFields() { .topK(60) .metadata(new Metadata("comprehensive_test")) .toolContext(Map.of("key1", "value1", "key2", "value2")) + .cacheControlConfiguration( + CacheControlConfiguration.builder().minCacheBlockLength(200).maxCacheBlocks(5).build()) .build(); AnthropicChatOptions copied = original.copy(); @@ -315,6 +336,7 @@ void testCopyPreservesAllFields() { assertThat(copied.getTopK()).isEqualTo(original.getTopK()); assertThat(copied.getMetadata()).isEqualTo(original.getMetadata()); assertThat(copied.getToolContext()).isEqualTo(original.getToolContext()); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); // Ensure deep copy for collections assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); @@ -475,36 +497,37 @@ void testSetterOverwriteBehavior() { } @Test - void testCacheControlBuilder() { - CacheControl cacheControl = AnthropicCacheType.EPHEMERAL.cacheControl(); + void testCacheControlConfigurationBuilder() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); AnthropicChatOptions options = AnthropicChatOptions.builder() .model("test-model") - .cacheControl(cacheControl) + .cacheControlConfiguration(config) .build(); - assertThat(options.getCacheControl()).isEqualTo(cacheControl); - assertThat(options.getCacheControl().type()).isEqualTo("ephemeral"); + assertThat(options.getCacheControlConfiguration()).isEqualTo(config); + // Default max cache blocks is 4 per configuration defaults + assertThat(options.getCacheControlConfiguration().getMaxCacheBlocks()).isEqualTo(4); } @Test void testCacheControlDefaultValue() { AnthropicChatOptions options = new AnthropicChatOptions(); - assertThat(options.getCacheControl()).isNull(); + assertThat(options.getCacheControlConfiguration()).isNull(); } @Test - void testCacheControlEqualsAndHashCode() { - CacheControl cacheControl = AnthropicCacheType.EPHEMERAL.cacheControl(); + void testCacheControlConfigurationEqualsAndHashCode() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); AnthropicChatOptions options1 = AnthropicChatOptions.builder() .model("test-model") - .cacheControl(cacheControl) + .cacheControlConfiguration(config) .build(); AnthropicChatOptions options2 = AnthropicChatOptions.builder() .model("test-model") - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(config) .build(); AnthropicChatOptions options3 = AnthropicChatOptions.builder().model("test-model").build(); @@ -517,31 +540,35 @@ void testCacheControlEqualsAndHashCode() { } @Test - void testCacheControlCopy() { - CacheControl originalCacheControl = AnthropicCacheType.EPHEMERAL.cacheControl(); + void testCacheControlConfigurationCopy() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); AnthropicChatOptions original = AnthropicChatOptions.builder() .model("test-model") - .cacheControl(originalCacheControl) + .cacheControlConfiguration(config) .build(); AnthropicChatOptions copied = original.copy(); assertThat(copied).isNotSameAs(original).isEqualTo(original); - assertThat(copied.getCacheControl()).isEqualTo(original.getCacheControl()); - assertThat(copied.getCacheControl()).isEqualTo(originalCacheControl); + assertThat(copied.getCacheControlConfiguration()).isEqualTo(original.getCacheControlConfiguration()); + // copy() preserves the same configuration instance + assertThat(copied.getCacheControlConfiguration()).isSameAs(config); } @Test - void testCacheControlWithNullValue() { - AnthropicChatOptions options = AnthropicChatOptions.builder().model("test-model").cacheControl(null).build(); + void testCacheControlConfigurationWithNullValue() { + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .cacheControlConfiguration(null) + .build(); - assertThat(options.getCacheControl()).isNull(); + assertThat(options.getCacheControlConfiguration()).isNull(); } @Test - void testBuilderWithAllFieldsIncludingCacheControl() { - CacheControl cacheControl = AnthropicCacheType.EPHEMERAL.cacheControl(); + void testBuilderWithAllFieldsIncludingCacheControlConfiguration() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); AnthropicChatOptions options = AnthropicChatOptions.builder() .model("test-model") @@ -551,32 +578,121 @@ void testBuilderWithAllFieldsIncludingCacheControl() { .topP(0.8) .topK(50) .metadata(new Metadata("userId_123")) - .cacheControl(cacheControl) + .cacheControlConfiguration(config) .build(); assertThat(options) .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", - "cacheControl") + "cacheControlConfiguration") .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), - cacheControl); + config); } @Test - void testCacheControlMutationDoesNotAffectOriginal() { - CacheControl originalCacheControl = AnthropicCacheType.EPHEMERAL.cacheControl(); + void testCacheControlConfigurationMutationDoesNotAffectOriginal() { + CacheControlConfiguration config = CacheControlConfiguration.builder().build(); AnthropicChatOptions original = AnthropicChatOptions.builder() .model("original-model") - .cacheControl(originalCacheControl) + .cacheControlConfiguration(config) .build(); AnthropicChatOptions copy = original.copy(); - copy.setCacheControl(null); + copy.setCacheControlConfiguration(null); // Original should remain unchanged - assertThat(original.getCacheControl()).isEqualTo(originalCacheControl); - // Copy should have null cache control - assertThat(copy.getCacheControl()).isNull(); + assertThat(original.getCacheControlConfiguration()).isEqualTo(config); + // Copy should have null cache control configuration + assertThat(copy.getCacheControlConfiguration()).isNull(); + } + + @Test + void testCacheControlConfigurationDefaults() { + CacheControlConfiguration defaults = new CacheControlConfiguration(); + + assertThat(defaults.getMaxCacheBlocks()).isEqualTo(4); + assertThat(defaults.getMinCacheBlockLength()).isEqualTo(2000); + assertThat(defaults.getCachableMessageTypes()).containsExactlyInAnyOrder(MessageType.SYSTEM, MessageType.USER, + MessageType.ASSISTANT, MessageType.TOOL); + assertThat(defaults.getMessageTypeCacheTypes()) + .containsEntry(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H) + .containsEntry(MessageType.USER, AnthropicCacheType.EPHEMERAL) + .containsEntry(MessageType.ASSISTANT, AnthropicCacheType.EPHEMERAL) + .containsEntry(MessageType.TOOL, AnthropicCacheType.EPHEMERAL); + + // Static DEFAULT matches a fresh instance by value + assertThat(CacheControlConfiguration.DEFAULT).isEqualTo(defaults); + } + + @Test + void testCacheTypeLookupDefaultAndOverride() { + // Start from empty mapping then add specific override + CacheControlConfiguration config = CacheControlConfiguration.builder() + .messageTypeCacheTypes(null) // force builder to initialize map on demand + .addMessageTypeCacheType(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H) + .build(); + + // Unmapped types default to EPHEMERAL + assertThat(config.getCacheTypeForMessageType(MessageType.USER)).isEqualTo(AnthropicCacheType.EPHEMERAL); + assertThat(config.getCacheTypeForMessageType(MessageType.ASSISTANT)).isEqualTo(AnthropicCacheType.EPHEMERAL); + assertThat(config.getCacheTypeForMessageType(MessageType.TOOL)).isEqualTo(AnthropicCacheType.EPHEMERAL); + + // Mapped type returns configured value + assertThat(config.getCacheTypeForMessageType(MessageType.SYSTEM)).isEqualTo(AnthropicCacheType.EPHEMERAL_1H); + } + + @Test + void testMinBlockLengthLookupDefaultAndOverride() { + CacheControlConfiguration config = CacheControlConfiguration.builder() + .minCacheBlockLength(3000) + .minBlockLengthForMessageType(MessageType.SYSTEM, 1500) + .build(); + + // Override applies for SYSTEM + assertThat(config.getMinBlockLengthForMessageType(MessageType.SYSTEM)).isEqualTo(1500); + // Others use global default + assertThat(config.getMinBlockLengthForMessageType(MessageType.USER)).isEqualTo(3000); + assertThat(config.getMinBlockLengthForMessageType(MessageType.ASSISTANT)).isEqualTo(3000); + assertThat(config.getMinBlockLengthForMessageType(MessageType.TOOL)).isEqualTo(3000); + } + + @Test + void testBuilderAddersInitializeNullCollections() { + CacheControlConfiguration config = CacheControlConfiguration.builder() + .cachableMessageTypes(null) + .addCachableMessageType(MessageType.USER) + .messageTypeCacheTypes(null) + .addMessageTypeCacheType(MessageType.USER, AnthropicCacheType.EPHEMERAL) + .minBlockLengthForMessageType(MessageType.USER, 1234) + .build(); + + assertThat(config.getCachableMessageTypes()).contains(MessageType.USER); + assertThat(config.getMessageTypeCacheTypes()).containsEntry(MessageType.USER, AnthropicCacheType.EPHEMERAL); + assertThat(config.getMinBlockLengthForMessageType(MessageType.USER)).isEqualTo(1234); + } + + @Test + void testCacheControlConfigurationEqualityAcrossInstances() { + CacheControlConfiguration c1 = CacheControlConfiguration.builder() + .maxCacheBlocks(2) + .minCacheBlockLength(1111) + .cachableMessageTypes(new java.util.HashSet<>(java.util.List.of(MessageType.USER, MessageType.SYSTEM))) + .messageTypeCacheTypes(java.util.Map.of(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H)) + .minBlockLengthForMessageType(MessageType.SYSTEM, 999) + .build(); + + CacheControlConfiguration c2 = CacheControlConfiguration.builder() + .maxCacheBlocks(2) + .minCacheBlockLength(1111) + .cachableMessageTypes(new java.util.HashSet<>(java.util.List.of(MessageType.SYSTEM, MessageType.USER))) // different + // order + .messageTypeCacheTypes( + new java.util.HashMap<>(java.util.Map.of(MessageType.SYSTEM, AnthropicCacheType.EPHEMERAL_1H))) + .minBlockLengthForMessageType(MessageType.SYSTEM, 999) + .build(); + + assertThat(c1).isEqualTo(c2); + assertThat(c1.hashCode()).isEqualTo(c2.hashCode()); } } diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 256fe679427..668c1e5a0d7 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -479,6 +479,8 @@ Prompt buildRequestPrompt(Prompt prompt) { runtimeOptions.getGoogleSearchRetrieval(), this.defaultOptions.getGoogleSearchRetrieval())); requestOptions.setSafetySettings(ModelOptionsUtils.mergeOption(runtimeOptions.getSafetySettings(), this.defaultOptions.getSafetySettings())); + requestOptions + .setLabels(ModelOptionsUtils.mergeOption(runtimeOptions.getLabels(), this.defaultOptions.getLabels())); } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); @@ -488,6 +490,7 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval()); requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings()); + requestOptions.setLabels(this.defaultOptions.getLabels()); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); @@ -680,6 +683,9 @@ GeminiRequest createGeminiRequest(Prompt prompt) { configBuilder .thinkingConfig(ThinkingConfig.builder().thinkingBudget(requestOptions.getThinkingBudget()).build()); } + if (requestOptions.getLabels() != null && !requestOptions.getLabels().isEmpty()) { + configBuilder.labels(requestOptions.getLabels()); + } // Add safety settings if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) { diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 4d5eb076166..7e05e5fc921 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -144,6 +144,9 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Map labels = new HashMap<>(); // @formatter:on public static Builder builder() { @@ -170,6 +173,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); options.setThinkingBudget(fromOptions.getThinkingBudget()); + options.setLabels(fromOptions.getLabels()); return options; } @@ -332,6 +336,15 @@ public void setSafetySettings(List safetySettings) { this.safetySettings = safetySettings; } + public Map getLabels() { + return this.labels; + } + + public void setLabels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.labels = labels; + } + @Override public Map getToolContext() { return this.toolContext; @@ -363,7 +376,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) - && Objects.equals(this.toolContext, that.toolContext); + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels); } @Override @@ -371,7 +384,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, - this.safetySettings, this.internalToolExecutionEnabled, this.toolContext); + this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels); } @Override @@ -382,7 +395,8 @@ public String toString() { + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}'; + + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels + + '}'; } @Override @@ -510,6 +524,12 @@ public Builder thinkingBudget(Integer thinkingBudget) { return this; } + public Builder labels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.options.labels = labels; + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java index 4d8d45cbd11..3521213bfb5 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java @@ -16,6 +16,8 @@ package org.springframework.ai.google.genai; +import java.util.Map; + import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -104,6 +106,29 @@ public void testEqualsAndHashCodeWithThinkingBudget() { assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); } + @Test + public void testEqualsAndHashCodeWithLabels() { + GoogleGenAiChatOptions options1 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options2 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options3 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "other-org")) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + assertThat(options1).isNotEqualTo(options3); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + @Test public void testToStringWithThinkingBudget() { GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() @@ -116,4 +141,16 @@ public void testToStringWithThinkingBudget() { assertThat(toString).contains("test-model"); } + @Test + public void testToStringWithLabels() { + GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + String toString = options.toString(); + assertThat(toString).contains("labels={org=my-org}"); + assertThat(toString).contains("test-model"); + } + } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index f8d08b31e8a..9d1163168fc 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -208,10 +208,11 @@ Prompt caching is currently supported on Claude Opus 4, Claude Sonnet 4, Claude Spring AI supports Anthropic's cache types through the `AnthropicCacheType` enum: * `EPHEMERAL`: Temporary caching suitable for short-term reuse within a session +* `EPHEMERAL_1H`: Extended ephemeral caching with a 1-hour lifetime. Note that this cache type incurs higher costs compared to standard ephemeral caching. === Enabling Prompt Caching -To enable prompt caching, use the `cacheControl()` method in `AnthropicChatOptions`: +To enable prompt caching, use the `cacheControlConfiguration()` method in `AnthropicChatOptions.Builder`: ==== Basic Usage @@ -222,8 +223,7 @@ ChatResponse response = chatModel.call( new Prompt( List.of(new UserMessage("Large content to be cached...")), AnthropicChatOptions.builder() - .model("claude-3-5-sonnet-latest") - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .model("claude-3-5-sonnet-latest").cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .build() ) ); @@ -238,7 +238,7 @@ String response = ChatClient.create(chatModel) .user("Analyze this large document: " + document) .options(AnthropicChatOptions.builder() .model("claude-3-5-sonnet-latest") - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .build()) .call() .content(); @@ -259,7 +259,7 @@ ChatResponse firstResponse = chatModel.call( List.of(new UserMessage(largeContent)), AnthropicChatOptions.builder() .model("claude-3-haiku-20240307") - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .maxTokens(100) .build() ) @@ -278,7 +278,7 @@ ChatResponse secondResponse = chatModel.call( List.of(new UserMessage(largeContent)), AnthropicChatOptions.builder() .model("claude-3-haiku-20240307") - .cacheControl(AnthropicCacheType.EPHEMERAL.cacheControl()) + .cacheControlConfiguration(AnthropicChatOptions.CacheControlConfiguration.DEFAULT) .maxTokens(100) .build() ) @@ -360,6 +360,12 @@ System.out.println("Cache creation tokens: " + usage.cacheCreationInputTokens()) System.out.println("Cache read tokens: " + usage.cacheReadInputTokens()); ---- +=== Additional Configuration Options + +You can further customize cache behavior using the `AnthropicChatOptions.CacheControlConfiguration`. This configuration gives you more fine-grained control over the way the cache control blocks are applied. For example, in order to optimize for caching the largest possible content blocks, you can configure which `MessageType` should attempt to use cache control and for each `MessageType` what type of cache control to use (e.g. `EPHEMERAL` or `EPHEMERAL_1H`). + +Because Anthropic only allows caching of content > 1024 tokens, `AnthropicChatOptions.CacheControlConfiguration` has a default minimum text content length of 2000 characters to avoid creating cache control blocks for small content that cannot be cached. This can be configured to a different value if needed. Note that the character count is an approximation and the actual token count may vary. + === Implementation Details Cache control is configured through `AnthropicChatOptions` rather than individual messages. @@ -534,7 +540,7 @@ Read more about xref:api/tools.adoc[Tool Calling]. == Multimodal -Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, pdf, images, data formats. +Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, pdf, images, data formats. === Images Currently, Anthropic Claude 3 supports the `base64` source type for `images`, and the `image/jpeg`, `image/png`, `image/gif`, and `image/webp` media types. @@ -712,4 +718,4 @@ Flux response = this.anthropicApi Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java[AnthropicApi.java]'s JavaDoc for further information. === Low-level API Examples -* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/chat/api/AnthropicApiIT.java[AnthropicApiIT.java] test provides some general examples how to use the lightweight library. +* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/chat/api/AnthropicApiIT.java[AnthropicApiIT.java] test provides some general examples how to use the lightweight library. \ No newline at end of file