Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -91,6 +92,7 @@
* @author Alexandros Pappas
* @author Jonghoon Park
* @author Soby Chacko
* @author Austin Dase
* @since 1.0.0
*/
public class AnthropicChatModel implements ChatModel {
Expand Down Expand Up @@ -481,34 +483,100 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
return mergedHttpHeaders;
}

/**
 * Convenience overload for plain text: wraps {@code text} in a new
 * {@link ContentBlock} and hands it to the {@code ContentBlock}-based overload, which
 * decides whether a cache-control marker is attached.
 * @param text the text to wrap in a content block
 * @param usedCacheBlocks counter of cache slots already reserved for this request
 * @param cfg the cache-control configuration, forwarded unchanged (may be
 * {@code null})
 * @param type the type of the message the text belongs to
 * @return whatever the {@code ContentBlock}-based overload returns for the wrapped
 * text
 */
private static ContentBlock cacheAwareContentBlock(String text, AtomicInteger usedCacheBlocks,
		AnthropicChatOptions.CacheControlConfiguration cfg, MessageType type) {
	ContentBlock textBlock = new ContentBlock(text);
	return cacheAwareContentBlock(textBlock, usedCacheBlocks, cfg, type);
}

/**
 * Returns a copy of {@code contentBlock} carrying a cache-control marker when the
 * block is eligible for caching and a cache slot could be reserved; otherwise returns
 * {@code contentBlock} itself, unchanged.
 * @param contentBlock the block to (possibly) mark as cacheable
 * @param usedCacheBlocks counter of cache slots already reserved; incremented when
 * this call reserves one
 * @param cacheControlConfiguration the caching rules; when {@code null} the block is
 * returned as-is
 * @param messageType the type of the message the block belongs to
 * @return the cache-marked copy, or the original block when caching does not apply
 */
private static ContentBlock cacheAwareContentBlock(ContentBlock contentBlock, AtomicInteger usedCacheBlocks,
		AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) {
	if (cacheControlConfiguration == null) {
		return contentBlock;
	}

	// Eligibility is evaluated first so that a slot is only consumed by a block that
	// will actually be cached.
	boolean slotReserved = isCacheEligible(contentBlock, cacheControlConfiguration, messageType)
			&& tryReserveCacheBlock(usedCacheBlocks, cacheControlConfiguration.getMaxCacheBlocks());

	if (!slotReserved) {
		if (logger.isDebugEnabled()) {
			final Integer minLength = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType);
			logger.debug(
					"Skipping cache for messageType={}, used={}/{}; textLength={}, contentLength={}, minLength={}, cachableTypes={}",
					messageType, usedCacheBlocks.get(), cacheControlConfiguration.getMaxCacheBlocks(),
					safeLength(contentBlock.text()), safeLength(contentBlock.content()), minLength,
					cacheControlConfiguration.getCachableMessageTypes());
		}
		return contentBlock;
	}

	return ContentBlock.from(contentBlock)
		.cacheControl(cacheControlConfiguration.getCacheTypeForMessageType(messageType).cacheControl())
		.build();
}

/**
 * Null-tolerant string length.
 * @param s the string to measure, may be {@code null}
 * @return {@code 0} for {@code null}, otherwise {@code s.length()}
 */
private static int safeLength(String s) {
	if (s == null) {
		return 0;
	}
	return s.length();
}

/**
 * Decides whether a block may be cached under the given configuration.
 * <p>
 * The block is eligible when {@code messageType} is among the configured cachable
 * message types and both {@code block.text()} and {@code block.content()} are either
 * {@code null} or at least as long as the minimum block length configured for that
 * message type.
 * @param block the block being considered
 * @param cacheControlConfiguration the caching rules (must not be {@code null})
 * @param messageType the type of the message the block belongs to
 * @return {@code true} when the block satisfies the type and length requirements
 */
private static boolean isCacheEligible(ContentBlock block,
		AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration, MessageType messageType) {
	boolean typeIsCachable = cacheControlConfiguration.getCachableMessageTypes().contains(messageType);
	if (typeIsCachable) {
		final int threshold = cacheControlConfiguration.getMinBlockLengthForMessageType(messageType);
		if (!isNullOrGreaterThanLength(block.text(), threshold)) {
			return false;
		}
		return isNullOrGreaterThanLength(block.content(), threshold);
	}
	return false;
}

/**
 * Length check that treats a missing string as passing.
 * <p>
 * Despite the method name, the comparison is inclusive: a string whose length equals
 * {@code min} passes.
 * @param s the string to check, may be {@code null}
 * @param min the minimum acceptable length
 * @return {@code true} when {@code s} is {@code null} or {@code s.length() >= min}
 */
private static boolean isNullOrGreaterThanLength(String s, int min) {
	if (s == null) {
		return true;
	}
	return s.length() >= min;
}

/**
 * Reserves one cache slot by incrementing {@code used}, but only while its value is
 * still below {@code max}.
 * @param used counter of slots reserved so far; incremented by one on success and
 * left untouched on failure
 * @param max the maximum number of slots that may be reserved
 * @return {@code true} if a slot was reserved by this call, {@code false} if the
 * counter had already reached {@code max}
 */
private static boolean tryReserveCacheBlock(AtomicInteger used, int max) {
	while (true) {
		int current = used.get();
		if (current >= max) {
			return false;
		}
		// Retry when another thread changed the counter between the read and the
		// swap.
		if (used.compareAndSet(current, current + 1)) {
			return true;
		}
	}
}

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

// Get cache control from options
AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions();
AnthropicApi.ChatCompletionRequest.CacheControl cacheControl = (requestOptions != null)
? requestOptions.getCacheControl() : null;
AnthropicChatOptions.CacheControlConfiguration cacheControlConfiguration = (requestOptions != null)
? requestOptions.getCacheControlConfiguration() : null;

AtomicInteger usedCacheBlocks = new AtomicInteger();

List<ContentBlock> systemPrompt = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(m -> cacheAwareContentBlock(m.getText(), usedCacheBlocks, cacheControlConfiguration,
MessageType.SYSTEM))
.toList();

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>();

// Apply cache control if enabled for user messages
if (cacheControl != null) {
contents.add(new ContentBlock(message.getText(), cacheControl));
}
else {
contents.add(new ContentBlock(message.getText()));
}
contents.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks, cacheControlConfiguration,
MessageType.USER));
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Type contentBlockType = getContentBlockTypeByMedia(media);
var source = getSourceByMedia(media);
return new ContentBlock(contentBlockType, source);
}).toList();
})
.map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks,
cacheControlConfiguration, MessageType.USER))
.toList();
contents.addAll(mediaContent);
}
}
Expand All @@ -518,12 +586,15 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(new ContentBlock(message.getText()));
contentBlocks.add(cacheAwareContentBlock(message.getText(), usedCacheBlocks,
cacheControlConfiguration, MessageType.ASSISTANT));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
ContentBlock contentBlock = new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments()));
contentBlocks.add(cacheAwareContentBlock(contentBlock, usedCacheBlocks,
cacheControlConfiguration, MessageType.ASSISTANT));
}
}
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
Expand All @@ -533,6 +604,8 @@ else if (message.getMessageType() == MessageType.TOOL) {
.stream()
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
toolResponse.responseData()))
.map(contentBlock -> cacheAwareContentBlock(contentBlock, usedCacheBlocks,
cacheControlConfiguration, MessageType.TOOL))
.toList();
return new AnthropicMessage(toolResponses, Role.USER);
}
Expand All @@ -542,14 +615,14 @@ else if (message.getMessageType() == MessageType.TOOL) {
})
.toList();

String systemPrompt = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(m -> m.getText())
.collect(Collectors.joining(System.lineSeparator()));

ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
ChatCompletionRequest request = ChatCompletionRequest.builder()
.model(this.defaultOptions.getModel())
.messages(userMessages)
.system(systemPrompt)
.maxTokens(this.defaultOptions.getMaxTokens())
.temperature(this.defaultOptions.getTemperature())
.stream(stream)
.build();

request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);

Expand Down
Loading