Skip to content

Commit 0b4f09c

Browse files
committed
feat: Added an implementation of ChatMemory based on a token-size sliding window limit: TokenWindowChatMemory.
This implementation limits the chat memory based on the total number of tokens. When the total token count exceeds the limit, the oldest messages are evicted. Messages are indivisible, meaning that if the token limit is exceeded, the entire oldest message is removed. Signed-off-by: Sun Yuhan <[email protected]>
1 parent ea995df commit 0b4f09c

File tree

2 files changed

+475
-0
lines changed

2 files changed

+475
-0
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory;
18+
19+
import org.springframework.ai.chat.messages.Message;
20+
import org.springframework.ai.chat.messages.SystemMessage;
21+
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
22+
import org.springframework.ai.tokenizer.TokenCountEstimator;
23+
import org.springframework.util.Assert;
24+
25+
import java.util.ArrayList;
26+
import java.util.HashSet;
27+
import java.util.List;
28+
import java.util.Set;
29+
30+
/**
31+
* A chat memory implementation that maintains a message window of a specified size,
32+
* ensuring that the total number of tokens does not exceed the specified limit. Messages
33+
* are treated as indivisible units; when eviction is necessary due to exceeding the token
34+
* limit, the oldest complete message is removed.
35+
* <p>
36+
* Messages of type {@link SystemMessage} are treated specially: if a new
37+
* {@link SystemMessage} is added, all previous {@link SystemMessage} instances are
38+
* removed from the memory.
39+
*
40+
* @author Sun Yuhan
41+
* @since 1.1.0
42+
*/
43+
public final class TokenWindowChatMemory implements ChatMemory {
44+
45+
private static final long DEFAULT_MAX_TOKENS = 128000L;
46+
47+
private final ChatMemoryRepository chatMemoryRepository;
48+
49+
private final TokenCountEstimator tokenCountEstimator;
50+
51+
private final long maxTokens;
52+
53+
public TokenWindowChatMemory(ChatMemoryRepository chatMemoryRepository, TokenCountEstimator tokenCountEstimator,
54+
Long maxTokens) {
55+
Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null");
56+
Assert.notNull(tokenCountEstimator, "tokenCountEstimator cannot be null");
57+
Assert.isTrue(maxTokens > 0, "maxTokens must be greater than 0");
58+
this.chatMemoryRepository = chatMemoryRepository;
59+
this.tokenCountEstimator = tokenCountEstimator;
60+
this.maxTokens = maxTokens;
61+
}
62+
63+
@Override
64+
public void add(String conversationId, List<Message> messages) {
65+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
66+
Assert.notNull(messages, "messages cannot be null");
67+
Assert.noNullElements(messages, "messages cannot contain null elements");
68+
69+
List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
70+
List<Message> processedMessages = process(memoryMessages, messages);
71+
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
72+
}
73+
74+
@Override
75+
public List<Message> get(String conversationId) {
76+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
77+
return this.chatMemoryRepository.findByConversationId(conversationId);
78+
}
79+
80+
@Override
81+
public void clear(String conversationId) {
82+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
83+
this.chatMemoryRepository.deleteByConversationId(conversationId);
84+
}
85+
86+
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
87+
List<Message> processedMessages = new ArrayList<>();
88+
89+
Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
90+
boolean hasNewSystemMessage = newMessages.stream()
91+
.filter(SystemMessage.class::isInstance)
92+
.anyMatch(message -> !memoryMessagesSet.contains(message));
93+
94+
memoryMessages.stream()
95+
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
96+
.forEach(processedMessages::add);
97+
98+
processedMessages.addAll(newMessages);
99+
100+
int tokens = processedMessages.stream()
101+
.mapToInt(processedMessage -> tokenCountEstimator.estimate(processedMessage.getText()))
102+
.sum();
103+
104+
if (tokens <= this.maxTokens) {
105+
return processedMessages;
106+
}
107+
108+
int removeMessageIndex = 0;
109+
while (tokens > this.maxTokens && !processedMessages.isEmpty()
110+
&& removeMessageIndex < processedMessages.size()) {
111+
if (processedMessages.get(removeMessageIndex) instanceof SystemMessage) {
112+
if (processedMessages.size() == 1) {
113+
break;
114+
}
115+
removeMessageIndex += 1;
116+
continue;
117+
}
118+
Message removedMessage = processedMessages.remove(removeMessageIndex);
119+
tokens -= tokenCountEstimator.estimate(removedMessage.getText());
120+
}
121+
122+
return processedMessages;
123+
}
124+
125+
public static Builder builder() {
126+
return new Builder();
127+
}
128+
129+
public static final class Builder {
130+
131+
private ChatMemoryRepository chatMemoryRepository;
132+
133+
private TokenCountEstimator tokenCountEstimator;
134+
135+
private long maxTokens = DEFAULT_MAX_TOKENS;
136+
137+
private Builder() {
138+
}
139+
140+
public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) {
141+
this.chatMemoryRepository = chatMemoryRepository;
142+
return this;
143+
}
144+
145+
public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {
146+
this.tokenCountEstimator = tokenCountEstimator;
147+
return this;
148+
}
149+
150+
public Builder maxTokens(long maxTokens) {
151+
this.maxTokens = maxTokens;
152+
return this;
153+
}
154+
155+
public TokenWindowChatMemory build() {
156+
if (this.chatMemoryRepository == null) {
157+
this.chatMemoryRepository = new InMemoryChatMemoryRepository();
158+
}
159+
if (this.tokenCountEstimator == null) {
160+
this.tokenCountEstimator = new JTokkitTokenCountEstimator();
161+
}
162+
return new TokenWindowChatMemory(this.chatMemoryRepository, this.tokenCountEstimator, this.maxTokens);
163+
}
164+
165+
}
166+
167+
}

0 commit comments

Comments
 (0)