| 
 | 1 | +/*  | 
 | 2 | + * Copyright 2023-2025 the original author or authors.  | 
 | 3 | + *  | 
 | 4 | + * Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 5 | + * you may not use this file except in compliance with the License.  | 
 | 6 | + * You may obtain a copy of the License at  | 
 | 7 | + *  | 
 | 8 | + *      https://www.apache.org/licenses/LICENSE-2.0  | 
 | 9 | + *  | 
 | 10 | + * Unless required by applicable law or agreed to in writing, software  | 
 | 11 | + * distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 13 | + * See the License for the specific language governing permissions and  | 
 | 14 | + * limitations under the License.  | 
 | 15 | + */  | 
 | 16 | + | 
 | 17 | +package org.springframework.ai.chat.memory;  | 
 | 18 | + | 
 | 19 | +import org.springframework.ai.chat.messages.Message;  | 
 | 20 | +import org.springframework.ai.chat.messages.SystemMessage;  | 
 | 21 | +import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;  | 
 | 22 | +import org.springframework.ai.tokenizer.TokenCountEstimator;  | 
 | 23 | +import org.springframework.util.Assert;  | 
 | 24 | + | 
 | 25 | +import java.util.ArrayList;  | 
 | 26 | +import java.util.HashSet;  | 
 | 27 | +import java.util.List;  | 
 | 28 | +import java.util.Set;  | 
 | 29 | + | 
 | 30 | +/**  | 
 | 31 | + * A chat memory implementation that maintains a message window of a specified size,  | 
 | 32 | + * ensuring that the total number of tokens does not exceed the specified limit. Messages  | 
 | 33 | + * are treated as indivisible units; when eviction is necessary due to exceeding the token  | 
 | 34 | + * limit, the oldest complete message is removed.  | 
 | 35 | + * <p>  | 
 | 36 | + * Messages of type {@link SystemMessage} are treated specially: if a new  | 
 | 37 | + * {@link SystemMessage} is added, all previous {@link SystemMessage} instances are  | 
 | 38 | + * removed from the memory.  | 
 | 39 | + *  | 
 | 40 | + * @author Sun Yuhan  | 
 | 41 | + * @since 1.1.0  | 
 | 42 | + */  | 
 | 43 | +public final class TokenWindowChatMemory implements ChatMemory {  | 
 | 44 | + | 
 | 45 | +	private static final long DEFAULT_MAX_TOKENS = 128000L;  | 
 | 46 | + | 
 | 47 | +	private final ChatMemoryRepository chatMemoryRepository;  | 
 | 48 | + | 
 | 49 | +	private final TokenCountEstimator tokenCountEstimator;  | 
 | 50 | + | 
 | 51 | +	private final long maxTokens;  | 
 | 52 | + | 
 | 53 | +	public TokenWindowChatMemory(ChatMemoryRepository chatMemoryRepository, TokenCountEstimator tokenCountEstimator,  | 
 | 54 | +			Long maxTokens) {  | 
 | 55 | +		Assert.notNull(chatMemoryRepository, "chatMemoryRepository cannot be null");  | 
 | 56 | +		Assert.notNull(tokenCountEstimator, "tokenCountEstimator cannot be null");  | 
 | 57 | +		Assert.isTrue(maxTokens > 0, "maxTokens must be greater than 0");  | 
 | 58 | +		this.chatMemoryRepository = chatMemoryRepository;  | 
 | 59 | +		this.tokenCountEstimator = tokenCountEstimator;  | 
 | 60 | +		this.maxTokens = maxTokens;  | 
 | 61 | +	}  | 
 | 62 | + | 
 | 63 | +	@Override  | 
 | 64 | +	public void add(String conversationId, List<Message> messages) {  | 
 | 65 | +		Assert.hasText(conversationId, "conversationId cannot be null or empty");  | 
 | 66 | +		Assert.notNull(messages, "messages cannot be null");  | 
 | 67 | +		Assert.noNullElements(messages, "messages cannot contain null elements");  | 
 | 68 | + | 
 | 69 | +		List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);  | 
 | 70 | +		List<Message> processedMessages = process(memoryMessages, messages);  | 
 | 71 | +		this.chatMemoryRepository.saveAll(conversationId, processedMessages);  | 
 | 72 | +	}  | 
 | 73 | + | 
 | 74 | +	@Override  | 
 | 75 | +	public List<Message> get(String conversationId) {  | 
 | 76 | +		Assert.hasText(conversationId, "conversationId cannot be null or empty");  | 
 | 77 | +		return this.chatMemoryRepository.findByConversationId(conversationId);  | 
 | 78 | +	}  | 
 | 79 | + | 
 | 80 | +	@Override  | 
 | 81 | +	public void clear(String conversationId) {  | 
 | 82 | +		Assert.hasText(conversationId, "conversationId cannot be null or empty");  | 
 | 83 | +		this.chatMemoryRepository.deleteByConversationId(conversationId);  | 
 | 84 | +	}  | 
 | 85 | + | 
 | 86 | +	private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {  | 
 | 87 | +		List<Message> processedMessages = new ArrayList<>();  | 
 | 88 | + | 
 | 89 | +		Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);  | 
 | 90 | +		boolean hasNewSystemMessage = newMessages.stream()  | 
 | 91 | +			.filter(SystemMessage.class::isInstance)  | 
 | 92 | +			.anyMatch(message -> !memoryMessagesSet.contains(message));  | 
 | 93 | + | 
 | 94 | +		memoryMessages.stream()  | 
 | 95 | +			.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))  | 
 | 96 | +			.forEach(processedMessages::add);  | 
 | 97 | + | 
 | 98 | +		processedMessages.addAll(newMessages);  | 
 | 99 | + | 
 | 100 | +		int tokens = processedMessages.stream()  | 
 | 101 | +			.mapToInt(processedMessage -> tokenCountEstimator.estimate(processedMessage.getText()))  | 
 | 102 | +			.sum();  | 
 | 103 | + | 
 | 104 | +		if (tokens <= this.maxTokens) {  | 
 | 105 | +			return processedMessages;  | 
 | 106 | +		}  | 
 | 107 | + | 
 | 108 | +		int removeMessageIndex = 0;  | 
 | 109 | +		while (tokens > this.maxTokens && !processedMessages.isEmpty()  | 
 | 110 | +				&& removeMessageIndex < processedMessages.size()) {  | 
 | 111 | +			if (processedMessages.get(removeMessageIndex) instanceof SystemMessage) {  | 
 | 112 | +				if (processedMessages.size() == 1) {  | 
 | 113 | +					break;  | 
 | 114 | +				}  | 
 | 115 | +				removeMessageIndex += 1;  | 
 | 116 | +				continue;  | 
 | 117 | +			}  | 
 | 118 | +			Message removedMessage = processedMessages.remove(removeMessageIndex);  | 
 | 119 | +			tokens -= tokenCountEstimator.estimate(removedMessage.getText());  | 
 | 120 | +		}  | 
 | 121 | + | 
 | 122 | +		return processedMessages;  | 
 | 123 | +	}  | 
 | 124 | + | 
 | 125 | +	public static Builder builder() {  | 
 | 126 | +		return new Builder();  | 
 | 127 | +	}  | 
 | 128 | + | 
 | 129 | +	public static final class Builder {  | 
 | 130 | + | 
 | 131 | +		private ChatMemoryRepository chatMemoryRepository;  | 
 | 132 | + | 
 | 133 | +		private TokenCountEstimator tokenCountEstimator;  | 
 | 134 | + | 
 | 135 | +		private long maxTokens = DEFAULT_MAX_TOKENS;  | 
 | 136 | + | 
 | 137 | +		private Builder() {  | 
 | 138 | +		}  | 
 | 139 | + | 
 | 140 | +		public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) {  | 
 | 141 | +			this.chatMemoryRepository = chatMemoryRepository;  | 
 | 142 | +			return this;  | 
 | 143 | +		}  | 
 | 144 | + | 
 | 145 | +		public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {  | 
 | 146 | +			this.tokenCountEstimator = tokenCountEstimator;  | 
 | 147 | +			return this;  | 
 | 148 | +		}  | 
 | 149 | + | 
 | 150 | +		public Builder maxTokens(long maxTokens) {  | 
 | 151 | +			this.maxTokens = maxTokens;  | 
 | 152 | +			return this;  | 
 | 153 | +		}  | 
 | 154 | + | 
 | 155 | +		public TokenWindowChatMemory build() {  | 
 | 156 | +			if (this.chatMemoryRepository == null) {  | 
 | 157 | +				this.chatMemoryRepository = new InMemoryChatMemoryRepository();  | 
 | 158 | +			}  | 
 | 159 | +			if (this.tokenCountEstimator == null) {  | 
 | 160 | +				this.tokenCountEstimator = new JTokkitTokenCountEstimator();  | 
 | 161 | +			}  | 
 | 162 | +			return new TokenWindowChatMemory(this.chatMemoryRepository, this.tokenCountEstimator, this.maxTokens);  | 
 | 163 | +		}  | 
 | 164 | + | 
 | 165 | +	}  | 
 | 166 | + | 
 | 167 | +}  | 
0 commit comments