Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,25 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;

/**
Expand All @@ -48,14 +49,24 @@
* @author Christian Tzolov
* @author Thomas Vitale
* @author Oganes Bozoyan
* @author Mark Pollack
* @since 1.0.0
*/
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {

private static final Logger logger = LoggerFactory.getLogger(VectorStoreChatMemoryAdvisor.class);

public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could consider renaming this to something more explicit about what it's used for

Suggested change
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
public static final String TOP_K = "chat_memory_vector_store_top_k";

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, also maybe move it into the VectorStore interface so it is easier to discover. Let's consider tomorrow.


private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";

private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";

/**
* The default chat memory retrieve size to use when no retrieve size is provided.
*/
public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should probably be renamed to something like DEFAULT_TOP_K ?

Also, isn't 100 a bit too much? Should it be something lower? Like 20? (the value we use for messages)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ha, did it already!


private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
{instructions}

Expand All @@ -69,71 +80,93 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect

private final PromptTemplate systemPromptTemplate;

private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
protected final int defaultChatMemoryRetrieveSize;

private final String defaultConversationId;

private final int order;

private final Scheduler scheduler;

private VectorStore vectorStore;

public VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultChatMemoryRetrieveSize,
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
this.systemPromptTemplate = systemPromptTemplate;
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
this.defaultConversationId = defaultConversationId;
this.order = order;
this.scheduler = scheduler;
this.vectorStore = vectorStore;
}

public static Builder builder(VectorStore chatMemory) {
return new Builder(chatMemory);
}

@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
chatClientRequest = this.before(chatClientRequest);

ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);

this.after(chatClientResponse);

return chatClientResponse;
public int getOrder() {
return order;
}

@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
streamAdvisorChain, this::before);

return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
public Scheduler getScheduler() {
return this.scheduler;
}

private ChatClientRequest before(ChatClientRequest chatClientRequest) {
String conversationId = this.doGetConversationId(chatClientRequest.context());
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());

// 1. Retrieve the chat memory for the current conversation.
var searchRequest = SearchRequest.builder()
.query(chatClientRequest.prompt().getUserMessage().getText())
.topK(chatMemoryRetrieveSize)
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
@Override
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
String conversationId = doGetConversationId(request.context());
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
int topK = doGetChatMemoryRetrieveSize(request.context());
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
.query(query)
.topK(topK)
.filterExpression(filter)
.build();
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
.similaritySearch(searchRequest);

List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);

// 2. Processed memory messages as a string.
String longTermMemory = documents == null ? ""
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
: documents.stream()
.map(org.springframework.ai.document.Document::getText)
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));

// 2. Augment the system message.
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
String augmentedSystemText = this.systemPromptTemplate
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));

// 3. Create a new request with the augmented system message.
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
ChatClientRequest processedChatClientRequest = request.mutate()
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
.build();

// 4. Add the new user message to the conversation memory.
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
.getUserMessage();
if (userMessage != null) {
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
}

return processedChatClientRequest;
}

private void after(ChatClientResponse chatClientResponse) {
protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could maybe be inlined where it's used?

or else, it should be private and maybe renamed to something like "getChatMemoryTopK()" or something like that

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed. Seems easier to parse in the code where it is being used vs. inlined.

return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
: this.defaultChatMemoryRetrieveSize;
}

protected String doGetConversationId(Map<String, Object> context) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be removed and replaced by calling getConversationId() from BaseChatMemoryAdvisor.

if (context == null || !context.containsKey(ChatMemory.CHAT_MEMORY_CONVERSATION_ID_KEY)) {
logger.warn("No conversation ID found in context; using defaultConversationId '{}'.",
this.defaultConversationId);
}
return context != null && context.containsKey(ChatMemory.CHAT_MEMORY_CONVERSATION_ID_KEY)
? context.get(ChatMemory.CHAT_MEMORY_CONVERSATION_ID_KEY).toString() : this.defaultConversationId;
}

@Override
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
List<Message> assistantMessages = new ArrayList<>();
if (chatClientResponse.chatResponse() != null) {
assistantMessages = chatClientResponse.chatResponse()
Expand All @@ -142,8 +175,8 @@ private void after(ChatClientResponse chatClientResponse) {
.map(g -> (Message) g.getOutput())
.toList();
}
this.getChatMemoryStore()
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
this.vectorStore.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
return chatClientResponse;
}

private List<Document> toDocuments(List<Message> messages, String conversationId) {
Expand Down Expand Up @@ -173,28 +206,103 @@ else if (message instanceof AssistantMessage assistantMessage) {
return docs;
}

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
/**
* Builder for VectorStoreChatMemoryAdvisor.
*/
public static class Builder {

private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;

protected Builder(VectorStore chatMemory) {
super(chatMemory);
private Integer chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE;

private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;

private Scheduler scheduler;

private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;

private VectorStore vectorStore;

/**
* Creates a new builder instance.
* @param vectorStore the vector store to use
*/
protected Builder(VectorStore vectorStore) {
this.vectorStore = vectorStore;
}

/**
* Set the system prompt template.
* @param systemPromptTemplate the system prompt template
* @return this builder
*/
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
this.systemPromptTemplate = systemPromptTemplate;
return this;
}

/**
* Set the system text advice.
* @param systemTextAdvise the system text advice
* @return this builder
*/
public Builder systemTextAdvise(String systemTextAdvise) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the amount of changes, I would remove this since it's been superseded by systemPromptTemplate(). This was kept for backward compatibility, but things will break not matter what.

this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
return this;
}

public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
this.systemPromptTemplate = systemPromptTemplate;
/**
* Set the chat memory retrieve size.
* @param chatMemoryRetrieveSize the chat memory retrieve size
* @return this builder
*/
public Builder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe rename to topK() or something like that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. I dhaalso renamed the field to be public static final int DEFAULT_TOP_K = 100;

this.chatMemoryRetrieveSize = chatMemoryRetrieveSize;
return this;
}

/**
* Set the conversation id.
* @param conversationId the conversation id
* @return the builder
*/
public Builder conversationId(String conversationId) {
this.conversationId = conversationId;
return this;
}

/**
* Set whether to protect from blocking.
* @param protectFromBlocking whether to protect from blocking
* @return the builder
*/
public Builder protectFromBlocking(boolean protectFromBlocking) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would remove this in favour of scheduler()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as before.

this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
return this;
}

public Builder scheduler(Scheduler scheduler) {
this.scheduler = scheduler;
return this;
}

/**
* Set the order.
* @param order the order
* @return the builder
*/
public Builder order(int order) {
this.order = order;
return this;
}

@Override
/**
* Build the advisor.
* @return the advisor
*/
public VectorStoreChatMemoryAdvisor build() {
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
this.protectFromBlocking, this.systemPromptTemplate, this.order);
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.chatMemoryRetrieveSize,
this.conversationId, this.order, this.scheduler, this.vectorStore);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void messageChatMemoryAdvisor_withPromptMessages_throwsException() {
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();
MessageChatMemoryAdvisor advisor = new MessageChatMemoryAdvisor(chatMemory);
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();

ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();

Expand Down
Loading