Skip to content

Commit d732617

Browse files
committed
added cache support to assistant message
1 parent f17ee20 commit d732617

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -475,19 +475,26 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
475475

476476
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
477477

478-
Optional<Message> lastMessage = prompt.getInstructions()
478+
List<Message> userMessagesList = prompt.getInstructions()
479479
.stream()
480-
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
481-
.findFirst();
480+
.filter(message -> message.getMessageType() == MessageType.USER)
481+
.toList();
482+
Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1);
483+
484+
List<Message> assistantMessageList = prompt.getInstructions()
485+
.stream()
486+
.filter(message -> message.getMessageType() == MessageType.ASSISTANT)
487+
.toList();
488+
Message lastAssistantMessage = assistantMessageList.isEmpty() ? null : assistantMessageList.get(assistantMessageList.size() - 1);
482489

483490
List<AnthropicMessage> userMessages = prompt.getInstructions()
484491
.stream()
485492
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
486493
.map(message -> {
494+
AbstractMessage abstractMessage = (AbstractMessage) message;
487495
if (message.getMessageType() == MessageType.USER) {
488-
AbstractMessage abstractMessage = (AbstractMessage) message;
489496
List<ContentBlock> contents;
490-
boolean isLastItem = lastMessage.filter(message::equals).isPresent();
497+
boolean isLastItem = message.equals(lastUserMessage);
491498
if (isLastItem && abstractMessage.getCache() != null) {
492499
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
493500
contents = new ArrayList<>(
@@ -511,8 +518,14 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
511518
else if (message.getMessageType() == MessageType.ASSISTANT) {
512519
AssistantMessage assistantMessage = (AssistantMessage) message;
513520
List<ContentBlock> contentBlocks = new ArrayList<>();
521+
boolean isLastItem = message.equals(lastAssistantMessage);
514522
if (StringUtils.hasText(message.getText())) {
515-
contentBlocks.add(new ContentBlock(message.getText()));
523+
if (isLastItem && abstractMessage.getCache() != null) {
524+
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
525+
contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl()));
526+
} else {
527+
contentBlocks.add(new ContentBlock(message.getText()));
528+
}
516529
}
517530
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
518531
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
@@ -551,6 +564,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
551564
// Add the tool definitions to the request's tools parameter.
552565
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
553566
if (!CollectionUtils.isEmpty(toolDefinitions)) {
567+
var tool = getFunctionTools(toolDefinitions);
554568
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
555569
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
556570
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.Random;
2122
import java.util.stream.Collectors;
2223

2324
import org.junit.jupiter.api.Test;
@@ -73,6 +74,39 @@ public class AnthropicApiIT {
7374
}
7475
""")));
7576

77+
@Test
78+
void chatWithPromptCacheInAssistantMessage() {
79+
String assistantMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which "
80+
+ "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" "
81+
+ "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English "
82+
+ "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". "
83+
+ "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\"";
84+
85+
AnthropicMessage chatCompletionMessage = new AnthropicMessage(
86+
List.of(new ContentBlock(assistantMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())),
87+
Role.ASSISTANT);
88+
89+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
90+
.model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU)
91+
.messages(List.of(chatCompletionMessage))
92+
.maxTokens(1500)
93+
.temperature(0.8)
94+
.stream(false)
95+
.build();
96+
97+
AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest)
98+
.getBody()
99+
.usage();
100+
101+
assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0);
102+
assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0);
103+
104+
AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage();
105+
106+
assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0);
107+
assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0);
108+
}
109+
76110
@Test
77111
void chatWithPromptCache() {
78112
String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which "

0 commit comments

Comments
 (0)