Skip to content

feat(anthropic): add support for prompt caching with fix #4100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@

package org.springframework.ai.anthropic;

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -30,6 +25,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
Expand All @@ -42,10 +38,7 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.anthropic.api.AnthropicCacheType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
Expand Down Expand Up @@ -482,12 +475,35 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<Message> userMessagesList = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() == MessageType.USER)
.toList();
Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1);

List<Message> assistantMessageList = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() == MessageType.ASSISTANT)
.toList();
Message lastAssistantMessage = assistantMessageList.isEmpty() ? null
: assistantMessageList.get(assistantMessageList.size() - 1);

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
AbstractMessage abstractMessage = (AbstractMessage) message;
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
List<ContentBlock> contents;
boolean isLastItem = message.equals(lastUserMessage);
if (isLastItem && abstractMessage.getCache() != null) {
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
contents = new ArrayList<>(
List.of(new ContentBlock(message.getText(), cacheType.cacheControl())));
}
else {
contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
}
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Expand All @@ -503,8 +519,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
boolean isLastItem = message.equals(lastAssistantMessage);
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(new ContentBlock(message.getText()));
if (isLastItem && abstractMessage.getCache() != null) {
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl()));
}
else {
contentBlocks.add(new ContentBlock(message.getText()));
}
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
Expand Down Expand Up @@ -543,6 +566,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
// Add the tool definitions to the request's tools parameter.
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
if (!CollectionUtils.isEmpty(toolDefinitions)) {
var tool = getFunctionTools(toolDefinitions);
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@
import java.util.function.Consumer;
import java.util.function.Predicate;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
Expand All @@ -46,14 +55,6 @@
import org.springframework.web.reactive.function.client.WebClient;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* The Anthropic API client.
Expand Down Expand Up @@ -94,6 +95,8 @@ public static Builder builder() {

private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta";

public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31";

private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;

private final String completionsPath;
Expand Down Expand Up @@ -538,25 +541,30 @@ public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Stri
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null);
}

public static ChatCompletionRequestBuilder builder() {
return new ChatCompletionRequestBuilder();
}

public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) {
return new ChatCompletionRequestBuilder(request);
}

/**
* Metadata about the request.
*
* <p>
* The single component is mapped to the JSON property {@code user_id}; because of
* {@code @JsonInclude(Include.NON_NULL)} a {@code null} user id is left out of the
* serialized JSON instead of being written as an explicit null.
*
* @param userId An external identifier for the user who is associated with the
* request. This should be a uuid, hash value, or other opaque identifier.
* Anthropic may use this id to help detect abuse. Do not include any identifying
* information such as name, email address, or phone number.
*/
@JsonInclude(Include.NON_NULL)
public record Metadata(@JsonProperty("user_id") String userId) {
}

/**
* Cache-control marker that can be attached to a content block (it is the type of the
* {@code cache_control} component of {@code ContentBlock}) to request prompt caching.
*
* <p>
* A {@code null} {@code type} is omitted from the serialized JSON because of
* {@code @JsonInclude(Include.NON_NULL)}.
*
* @param type The cache type supported by Anthropic, for example {@code "ephemeral"}.
* See the <a href=
* "https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations">Anthropic
* prompt caching documentation</a>.
*/
@JsonInclude(Include.NON_NULL)
public record CacheControl(String type) {
}

/**
* Create a new {@link ChatCompletionRequestBuilder} using its no-argument
* constructor.
* @return a new builder instance
*/
public static ChatCompletionRequestBuilder builder() {
return new ChatCompletionRequestBuilder();
}

/**
* Create a new {@link ChatCompletionRequestBuilder} from an existing request.
* <p>
* NOTE(review): the request is handed to the builder's constructor; presumably the
* builder starts out pre-populated with the request's values so callers can override
* individual fields before {@code build()} - confirm against the builder.
* @param request the request passed to the builder constructor
* @return a new builder instance
*/
public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) {
return new ChatCompletionRequestBuilder(request);
}

/**
Expand Down Expand Up @@ -760,7 +768,10 @@ public record ContentBlock(
@JsonProperty("tool_use_id") String toolUseId,
@JsonProperty("content") String content,

// Thinking only
// cache object
@JsonProperty("cache_control") CacheControl cacheControl,

// Thinking only
@JsonProperty("signature") String signature,
@JsonProperty("thinking") String thinking,

Expand All @@ -784,23 +795,27 @@ public ContentBlock(String mediaType, String data) {
* @param source The source of the content.
*/
public ContentBlock(Type type, Source source) {
this(type, source, null, null, null, null, null, null, null, null, null, null);
this(type, source, null, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param source The source of the content.
*/
public ContentBlock(Source source) {
this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null);
this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param text The text of the content.
*/
public ContentBlock(String text) {
this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null);
this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create a text content block that carries a cache-control marker.
* @param text The text of the content.
* @param cache The cache control to attach to the block (serialized as
* {@code cache_control}); may be {@code null}, in which case this is equivalent to
* the text-only constructor.
*/
public ContentBlock(String text, CacheControl cache) {
// Type is fixed to TEXT; only the text and cache-control components are set.
this(Type.TEXT, null, text, null, null, null, null, null, null, cache, null, null, null);
}

// Tool result
Expand All @@ -811,7 +826,7 @@ public ContentBlock(String text) {
* @param content The content of the tool result.
*/
public ContentBlock(Type type, String toolUseId, String content) {
this(type, null, null, null, null, null, null, toolUseId, content, null, null, null);
this(type, null, null, null, null, null, null, toolUseId, content, null, null, null, null);
}

/**
Expand All @@ -822,7 +837,7 @@ public ContentBlock(Type type, String toolUseId, String content) {
* @param index The index of the content block.
*/
public ContentBlock(Type type, Source source, String text, Integer index) {
this(type, source, text, index, null, null, null, null, null, null, null, null);
this(type, source, text, index, null, null, null, null, null, null, null, null, null);
}

// Tool use input JSON delta streaming
Expand All @@ -834,7 +849,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) {
* @param input The input of the tool use.
*/
public ContentBlock(Type type, String id, String name, Map<String, Object> input) {
this(type, null, null, null, id, name, input, null, null, null, null, null);
this(type, null, null, null, id, name, input, null, null, null, null, null, null);
}

/**
Expand Down Expand Up @@ -1028,7 +1043,9 @@ public record ChatCompletionResponse(
public record Usage(
// @formatter:off
@JsonProperty("input_tokens") Integer inputTokens,
@JsonProperty("output_tokens") Integer outputTokens) {
@JsonProperty("output_tokens") Integer outputTokens,
@JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens,
@JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) {
// @formatter:off
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.springframework.ai.anthropic.api;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;

import java.util.function.Supplier;

/**
 * Cache types that can be requested for Anthropic prompt caching. Each constant knows
 * how to produce the {@link CacheControl} value that is attached to a content block.
 */
public enum AnthropicCacheType {

	/**
	 * Ephemeral cache; produces a {@link CacheControl} whose type is
	 * {@code "ephemeral"}.
	 */
	EPHEMERAL(() -> new CacheControl("ephemeral"));

	// Enum constants are shared singletons, so the factory must never be reassigned.
	private final Supplier<CacheControl> value;

	AnthropicCacheType(Supplier<CacheControl> value) {
		this.value = value;
	}

	/**
	 * Create the cache control for this cache type.
	 * @return a {@link CacheControl} obtained from this constant's supplier; a new
	 * instance is created on every call
	 */
	public CacheControl cacheControl() {
		return this.value.get();
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) {
}
else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) {
ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null,
null, null, null, null, null, thinkingBlock.thinking(), null);
null, null, null, null, null, null, thinkingBlock.thinking(), null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
Expand All @@ -176,12 +176,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) {
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) {
ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, null, thinking.thinking(), null);
null, null, null, null, null, null, null, thinking.thinking(), null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) {
ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, sig.signature(), null, null);
null, null, null, null, null, null, sig.signature(), null, null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
Expand All @@ -204,8 +204,10 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
}

if (messageDeltaEvent.usage() != null) {
Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens());
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens(),
contentBlockReference.get().usage.cacheCreationInputTokens(),
contentBlockReference.get().usage.cacheReadInputTokens());
contentBlockReference.get().withUsage(totalUsage);
}
}
Expand Down
Loading