2020import java .util .HashMap ;
2121import java .util .List ;
2222import java .util .Map ;
23- import java .util .stream .Collectors ;
2423
25- import reactor .core .publisher .Flux ;
26-
27- import org .springframework .ai .chat .client .advisor .AbstractChatMemoryAdvisor ;
2824import org .springframework .ai .chat .client .ChatClientRequest ;
2925import org .springframework .ai .chat .client .ChatClientResponse ;
30- import org .springframework .ai .chat .client .advisor .api . CallAdvisorChain ;
31- import org .springframework .ai .chat .client .advisor .api .StreamAdvisorChain ;
26+ import org .springframework .ai .chat .client .advisor .AbstractChatMemoryAdvisor ;
27+ import org .springframework .ai .chat .client .advisor .api .AdvisorChain ;
3228import org .springframework .ai .chat .messages .AssistantMessage ;
3329import org .springframework .ai .chat .messages .Message ;
3430import org .springframework .ai .chat .messages .MessageType ;
35- import org .springframework .ai .chat .messages .SystemMessage ;
3631import org .springframework .ai .chat .messages .UserMessage ;
37- import org .springframework .ai .chat .model .MessageAggregator ;
3832import org .springframework .ai .chat .prompt .PromptTemplate ;
3933import org .springframework .ai .document .Document ;
40- import org .springframework .ai .vectorstore .SearchRequest ;
4134import org .springframework .ai .vectorstore .VectorStore ;
4235
4336/**
4841 * @author Christian Tzolov
4942 * @author Thomas Vitale
5043 * @author Oganes Bozoyan
44+ * @author Mark Pollack
5145 * @since 1.0.0
5246 */
5347public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor <VectorStore > {
5448
49+ public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size" ;
50+
5551 private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId" ;
5652
5753 private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType" ;
5854
55+ /**
56+ * The default chat memory retrieve size to use when no retrieve size is provided.
57+ */
58+ public static final int DEFAULT_CHAT_MEMORY_RESPONSE_SIZE = 100 ;
59+
5960 private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate ("""
6061 {instructions}
6162
@@ -69,71 +70,64 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6970
7071 private final PromptTemplate systemPromptTemplate ;
7172
72- private VectorStoreChatMemoryAdvisor (VectorStore vectorStore , String defaultConversationId ,
73- int chatHistoryWindowSize , boolean protectFromBlocking , PromptTemplate systemPromptTemplate , int order ) {
74- super (vectorStore , defaultConversationId , chatHistoryWindowSize , protectFromBlocking , order );
73+ protected final int defaultChatMemoryRetrieveSize ;
74+
75+ public VectorStoreChatMemoryAdvisor (VectorStore chatMemory , String defaultConversationId ,
76+ int defaultChatMemoryRetrieveSize , boolean protectFromBlocking , PromptTemplate systemPromptTemplate ,
77+ int order ) {
78+ super (chatMemory , defaultConversationId , protectFromBlocking , order );
7579 this .systemPromptTemplate = systemPromptTemplate ;
80+ this .defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize ;
7681 }
7782
7883 public static Builder builder (VectorStore chatMemory ) {
7984 return new Builder (chatMemory );
8085 }
8186
8287 @ Override
83- public ChatClientResponse adviseCall (ChatClientRequest chatClientRequest , CallAdvisorChain callAdvisorChain ) {
84- chatClientRequest = this .before (chatClientRequest );
85-
86- ChatClientResponse chatClientResponse = callAdvisorChain .nextCall (chatClientRequest );
87-
88- this .after (chatClientResponse );
89-
90- return chatClientResponse ;
91- }
92-
93- @ Override
94- public Flux <ChatClientResponse > adviseStream (ChatClientRequest chatClientRequest ,
95- StreamAdvisorChain streamAdvisorChain ) {
96- Flux <ChatClientResponse > chatClientResponses = this .doNextWithProtectFromBlockingBefore (chatClientRequest ,
97- streamAdvisorChain , this ::before );
98-
99- return new MessageAggregator ().aggregateChatClientResponse (chatClientResponses , this ::after );
100- }
101-
102- private ChatClientRequest before (ChatClientRequest chatClientRequest ) {
103- String conversationId = this .doGetConversationId (chatClientRequest .context ());
104- int chatMemoryRetrieveSize = this .doGetChatMemoryRetrieveSize (chatClientRequest .context ());
105-
106- // 1. Retrieve the chat memory for the current conversation.
107- var searchRequest = SearchRequest .builder ()
108- .query (chatClientRequest .prompt ().getUserMessage ().getText ())
109- .topK (chatMemoryRetrieveSize )
110- .filterExpression (DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" )
88+ public ChatClientRequest before (ChatClientRequest request , AdvisorChain advisorChain ) {
89+ String conversationId = doGetConversationId (request .context ());
90+ String query = request .prompt ().getUserMessage () != null ? request .prompt ().getUserMessage ().getText () : "" ;
91+ int topK = doGetChatMemoryRetrieveSize (request .context ());
92+ String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" ;
93+ var searchRequest = org .springframework .ai .vectorstore .SearchRequest .builder ()
94+ .query (query )
95+ .topK (topK )
96+ .filterExpression (filter )
11197 .build ();
98+ java .util .List <org .springframework .ai .document .Document > documents = this .getChatMemoryStore ()
99+ .similaritySearch (searchRequest );
112100
113- List <Document > documents = this .getChatMemoryStore ().similaritySearch (searchRequest );
114-
115- // 2. Processed memory messages as a string.
116101 String longTermMemory = documents == null ? ""
117- : documents .stream ().map (Document ::getText ).collect (Collectors .joining (System .lineSeparator ()));
102+ : documents .stream ()
103+ .map (org .springframework .ai .document .Document ::getText )
104+ .collect (java .util .stream .Collectors .joining (System .lineSeparator ()));
118105
119- // 2. Augment the system message.
120- SystemMessage systemMessage = chatClientRequest .prompt ().getSystemMessage ();
106+ org .springframework .ai .chat .messages .SystemMessage systemMessage = request .prompt ().getSystemMessage ();
121107 String augmentedSystemText = this .systemPromptTemplate
122- .render (Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
108+ .render (java . util . Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
123109
124- // 3. Create a new request with the augmented system message.
125- ChatClientRequest processedChatClientRequest = chatClientRequest .mutate ()
126- .prompt (chatClientRequest .prompt ().augmentSystemMessage (augmentedSystemText ))
110+ ChatClientRequest processedChatClientRequest = request .mutate ()
111+ .prompt (request .prompt ().augmentSystemMessage (augmentedSystemText ))
127112 .build ();
128113
129- // 4. Add the new user message to the conversation memory.
130- UserMessage userMessage = processedChatClientRequest .prompt ().getUserMessage ();
131- this .getChatMemoryStore ().write (toDocuments (List .of (userMessage ), conversationId ));
114+ org .springframework .ai .chat .messages .UserMessage userMessage = processedChatClientRequest .prompt ()
115+ .getUserMessage ();
116+ if (userMessage != null ) {
117+ this .getChatMemoryStore ().write (toDocuments (java .util .List .of (userMessage ), conversationId ));
118+ }
132119
133120 return processedChatClientRequest ;
134121 }
135122
136- private void after (ChatClientResponse chatClientResponse ) {
123+ protected int doGetChatMemoryRetrieveSize (Map <String , Object > context ) {
124+ return context .containsKey (CHAT_MEMORY_RETRIEVE_SIZE_KEY )
125+ ? Integer .parseInt (context .get (CHAT_MEMORY_RETRIEVE_SIZE_KEY ).toString ())
126+ : this .defaultChatMemoryRetrieveSize ;
127+ }
128+
129+ @ Override
130+ public ChatClientResponse after (ChatClientResponse chatClientResponse , AdvisorChain advisorChain ) {
137131 List <Message > assistantMessages = new ArrayList <>();
138132 if (chatClientResponse .chatResponse () != null ) {
139133 assistantMessages = chatClientResponse .chatResponse ()
@@ -144,6 +138,7 @@ private void after(ChatClientResponse chatClientResponse) {
144138 }
145139 this .getChatMemoryStore ()
146140 .write (toDocuments (assistantMessages , this .doGetConversationId (chatClientResponse .context ())));
141+ return chatClientResponse ;
147142 }
148143
149144 private List <Document > toDocuments (List <Message > messages , String conversationId ) {
@@ -173,22 +168,56 @@ else if (message instanceof AssistantMessage assistantMessage) {
173168 return docs ;
174169 }
175170
176- public static class Builder extends AbstractChatMemoryAdvisor .AbstractBuilder <VectorStore > {
171+ /**
172+ * Builder for VectorStoreChatMemoryAdvisor.
173+ */
174+ public static class Builder extends AbstractChatMemoryAdvisor .AbstractBuilder <VectorStore , Builder > {
177175
178176 private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE ;
179177
180- protected Builder (VectorStore chatMemory ) {
181- super (chatMemory );
178+ private Integer chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE ;
179+
180+ /**
181+ * Creates a new builder instance.
182+ * @param vectorStore the vector store to use
183+ */
184+ protected Builder (VectorStore vectorStore ) {
185+ super (vectorStore );
182186 }
183187
184- public Builder systemTextAdvise ( String systemTextAdvise ) {
185- this . systemPromptTemplate = new PromptTemplate ( systemTextAdvise );
188+ @ Override
189+ protected Builder self () {
186190 return this ;
187191 }
188192
193+ /**
194+ * Set the system prompt template.
195+ * @param systemPromptTemplate the system prompt template
196+ * @return this builder
197+ */
189198 public Builder systemPromptTemplate (PromptTemplate systemPromptTemplate ) {
190199 this .systemPromptTemplate = systemPromptTemplate ;
191- return this ;
200+ return self ();
201+ }
202+
203+ /**
204+ * Set the system prompt template using a text template.
205+ * @param systemTextAdvise the system prompt text template
206+ * @return this builder
207+ */
208+ public Builder systemTextAdvise (String systemTextAdvise ) {
209+ this .systemPromptTemplate = new PromptTemplate (systemTextAdvise );
210+ return self ();
211+ }
212+
213+ /**
214+ * Set the chat memory retrieve size.
215+ * @param chatMemoryRetrieveSize the chat memory retrieve size
216+ * @return this builder
217+ */
218+ public Builder chatMemoryRetrieveSize (int chatMemoryRetrieveSize ) {
219+ this .chatMemoryRetrieveSize = chatMemoryRetrieveSize ;
220+ return self ();
192221 }
193222
194223 @ Override
0 commit comments