Skip to content

Commit 839ef8c

Browse files
committed
Merge branch 'main' into FilterExpressionDsl
2 parents 02cc1f5 + d5d907b commit 839ef8c

File tree

119 files changed

+3126
-1486
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+3126
-1486
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 137 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,23 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
2423

25-
import reactor.core.publisher.Flux;
24+
import org.springframework.util.Assert;
25+
import reactor.core.scheduler.Scheduler;
2626

27-
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
2827
import org.springframework.ai.chat.client.ChatClientRequest;
2928
import org.springframework.ai.chat.client.ChatClientResponse;
30-
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
31-
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
29+
import org.springframework.ai.chat.client.advisor.api.Advisor;
30+
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
31+
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
32+
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
33+
import org.springframework.ai.chat.memory.ChatMemory;
3234
import org.springframework.ai.chat.messages.AssistantMessage;
3335
import org.springframework.ai.chat.messages.Message;
3436
import org.springframework.ai.chat.messages.MessageType;
35-
import org.springframework.ai.chat.messages.SystemMessage;
3637
import org.springframework.ai.chat.messages.UserMessage;
37-
import org.springframework.ai.chat.model.MessageAggregator;
3838
import org.springframework.ai.chat.prompt.PromptTemplate;
3939
import org.springframework.ai.document.Document;
40-
import org.springframework.ai.vectorstore.SearchRequest;
4140
import org.springframework.ai.vectorstore.VectorStore;
4241

4342
/**
@@ -48,14 +47,19 @@
4847
* @author Christian Tzolov
4948
* @author Thomas Vitale
5049
* @author Oganes Bozoyan
50+
* @author Mark Pollack
5151
* @since 1.0.0
5252
*/
53-
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
53+
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
54+
55+
public static final String TOP_K = "chat_memory_vector_store_top_k";
5456

5557
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
5658

5759
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
5860

61+
private static final int DEFAULT_TOP_K = 20;
62+
5963
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6064
{instructions}
6165
@@ -69,71 +73,87 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6973

7074
private final PromptTemplate systemPromptTemplate;
7175

72-
private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
73-
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
74-
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
76+
private final int defaultTopK;
77+
78+
private final String defaultConversationId;
79+
80+
private final int order;
81+
82+
private final Scheduler scheduler;
83+
84+
private final VectorStore vectorStore;
85+
86+
private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK,
87+
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
88+
Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
89+
Assert.isTrue(defaultTopK > 0, "topK must be greater than 0");
90+
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
91+
Assert.notNull(scheduler, "scheduler cannot be null");
92+
Assert.notNull(vectorStore, "vectorStore cannot be null");
7593
this.systemPromptTemplate = systemPromptTemplate;
94+
this.defaultTopK = defaultTopK;
95+
this.defaultConversationId = defaultConversationId;
96+
this.order = order;
97+
this.scheduler = scheduler;
98+
this.vectorStore = vectorStore;
7699
}
77100

78101
public static Builder builder(VectorStore chatMemory) {
79102
return new Builder(chatMemory);
80103
}
81104

82105
@Override
83-
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
84-
chatClientRequest = this.before(chatClientRequest);
85-
86-
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
87-
88-
this.after(chatClientResponse);
89-
90-
return chatClientResponse;
106+
public int getOrder() {
107+
return order;
91108
}
92109

93110
@Override
94-
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
95-
StreamAdvisorChain streamAdvisorChain) {
96-
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
97-
streamAdvisorChain, this::before);
98-
99-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
111+
public Scheduler getScheduler() {
112+
return this.scheduler;
100113
}
101114

102-
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
103-
String conversationId = this.doGetConversationId(chatClientRequest.context());
104-
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
105-
106-
// 1. Retrieve the chat memory for the current conversation.
107-
var searchRequest = SearchRequest.builder()
108-
.query(chatClientRequest.prompt().getUserMessage().getText())
109-
.topK(chatMemoryRetrieveSize)
110-
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
115+
@Override
116+
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
117+
String conversationId = getConversationId(request.context(), this.defaultConversationId);
118+
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
119+
int topK = getChatMemoryTopK(request.context());
120+
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
121+
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
122+
.query(query)
123+
.topK(topK)
124+
.filterExpression(filter)
111125
.build();
126+
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
127+
.similaritySearch(searchRequest);
112128

113-
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
114-
115-
// 2. Processed memory messages as a string.
116129
String longTermMemory = documents == null ? ""
117-
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
130+
: documents.stream()
131+
.map(org.springframework.ai.document.Document::getText)
132+
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
118133

119-
// 2. Augment the system message.
120-
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
134+
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
121135
String augmentedSystemText = this.systemPromptTemplate
122-
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
136+
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
123137

124-
// 3. Create a new request with the augmented system message.
125-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
126-
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
138+
ChatClientRequest processedChatClientRequest = request.mutate()
139+
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
127140
.build();
128141

129-
// 4. Add the new user message to the conversation memory.
130-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
131-
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
142+
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
143+
.getUserMessage();
144+
if (userMessage != null) {
145+
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
146+
}
132147

133148
return processedChatClientRequest;
134149
}
135150

136-
private void after(ChatClientResponse chatClientResponse) {
151+
private int getChatMemoryTopK(Map<String, Object> context) {
152+
return context.containsKey(TOP_K) ? Integer.parseInt(context.get(TOP_K).toString()) : this.defaultTopK;
153+
}
154+
155+
@Override
156+
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
137157
List<Message> assistantMessages = new ArrayList<>();
138158
if (chatClientResponse.chatResponse() != null) {
139159
assistantMessages = chatClientResponse.chatResponse()
@@ -142,8 +162,9 @@ private void after(ChatClientResponse chatClientResponse) {
142162
.map(g -> (Message) g.getOutput())
143163
.toList();
144164
}
145-
this.getChatMemoryStore()
146-
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
165+
this.vectorStore.write(toDocuments(assistantMessages,
166+
this.getConversationId(chatClientResponse.context(), this.defaultConversationId)));
167+
return chatClientResponse;
147168
}
148169

149170
private List<Document> toDocuments(List<Message> messages, String conversationId) {
@@ -173,28 +194,83 @@ else if (message instanceof AssistantMessage assistantMessage) {
173194
return docs;
174195
}
175196

176-
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
197+
/**
198+
* Builder for VectorStoreChatMemoryAdvisor.
199+
*/
200+
public static class Builder {
177201

178202
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
179203

180-
protected Builder(VectorStore chatMemory) {
181-
super(chatMemory);
182-
}
204+
private Integer defaultTopK = DEFAULT_TOP_K;
183205

184-
public Builder systemTextAdvise(String systemTextAdvise) {
185-
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
186-
return this;
206+
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
207+
208+
private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
209+
210+
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
211+
212+
private VectorStore vectorStore;
213+
214+
/**
215+
* Creates a new builder instance.
216+
* @param vectorStore the vector store to use
217+
*/
218+
protected Builder(VectorStore vectorStore) {
219+
this.vectorStore = vectorStore;
187220
}
188221

222+
/**
223+
* Set the system prompt template.
224+
* @param systemPromptTemplate the system prompt template
225+
* @return this builder
226+
*/
189227
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
190228
this.systemPromptTemplate = systemPromptTemplate;
191229
return this;
192230
}
193231

194-
@Override
232+
/**
233+
* Set the chat memory retrieve size.
234+
* @param defaultTopK the chat memory retrieve size
235+
* @return this builder
236+
*/
237+
public Builder defaultTopK(int defaultTopK) {
238+
this.defaultTopK = defaultTopK;
239+
return this;
240+
}
241+
242+
/**
243+
* Set the conversation id.
244+
* @param conversationId the conversation id
245+
* @return the builder
246+
*/
247+
public Builder conversationId(String conversationId) {
248+
this.conversationId = conversationId;
249+
return this;
250+
}
251+
252+
public Builder scheduler(Scheduler scheduler) {
253+
this.scheduler = scheduler;
254+
return this;
255+
}
256+
257+
/**
258+
* Set the order.
259+
* @param order the order
260+
* @return the builder
261+
*/
262+
public Builder order(int order) {
263+
this.order = order;
264+
return this;
265+
}
266+
267+
/**
268+
* Build the advisor.
269+
* @return the advisor
270+
*/
195271
public VectorStoreChatMemoryAdvisor build() {
196-
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
197-
this.protectFromBlocking, this.systemPromptTemplate, this.order);
272+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, this.conversationId,
273+
this.order, this.scheduler, this.vectorStore);
198274
}
199275

200276
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package org.springframework.ai.chat.client.advisor.vectorstore;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.mockito.Mockito;
5+
import org.springframework.ai.vectorstore.VectorStore;
6+
7+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
8+
9+
/**
10+
* Unit tests for {@link VectorStoreChatMemoryAdvisor}.
11+
*
12+
* @author Thomas Vitale
13+
*/
14+
class VectorStoreChatMemoryAdvisorTests {
15+
16+
@Test
17+
void whenVectorStoreIsNullThenThrow() {
18+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(null).build())
19+
.isInstanceOf(IllegalArgumentException.class)
20+
.hasMessageContaining("vectorStore cannot be null");
21+
}
22+
23+
@Test
24+
void whenDefaultConversationIdIsNullThenThrow() {
25+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
26+
27+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).conversationId(null).build())
28+
.isInstanceOf(IllegalArgumentException.class)
29+
.hasMessageContaining("defaultConversationId cannot be null or empty");
30+
}
31+
32+
@Test
33+
void whenDefaultConversationIdIsEmptyThenThrow() {
34+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
35+
36+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).conversationId(null).build())
37+
.isInstanceOf(IllegalArgumentException.class)
38+
.hasMessageContaining("defaultConversationId cannot be null or empty");
39+
}
40+
41+
@Test
42+
void whenSchedulerIsNullThenThrow() {
43+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
44+
45+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).scheduler(null).build())
46+
.isInstanceOf(IllegalArgumentException.class)
47+
.hasMessageContaining("scheduler cannot be null");
48+
}
49+
50+
@Test
51+
void whenSystemPromptTemplateIsNullThenThrow() {
52+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
53+
54+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).systemPromptTemplate(null).build())
55+
.isInstanceOf(IllegalArgumentException.class)
56+
.hasMessageContaining("systemPromptTemplate cannot be null");
57+
}
58+
59+
@Test
60+
void whenDefaultTopKIsZeroThenThrow() {
61+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
62+
63+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultTopK(0).build())
64+
.isInstanceOf(IllegalArgumentException.class)
65+
.hasMessageContaining("topK must be greater than 0");
66+
}
67+
68+
@Test
69+
void whenDefaultTopKIsNegativeThenThrow() {
70+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
71+
72+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultTopK(-1).build())
73+
.isInstanceOf(IllegalArgumentException.class)
74+
.hasMessageContaining("topK must be greater than 0");
75+
}
76+
77+
}

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
<dependency>
2727
<groupId>org.springframework.ai</groupId>
28-
<artifactId>spring-ai-model-chat-memory-cassandra</artifactId>
28+
<artifactId>spring-ai-model-chat-memory-repository-cassandra</artifactId>
2929
<version>${project.parent.version}</version>
3030
</dependency>
3131

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import com.datastax.oss.driver.api.core.CqlSession;
2020

21-
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepositoryConfig;
22-
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository;
21+
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
22+
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepository;
2323
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
2424
import org.springframework.boot.autoconfigure.AutoConfiguration;
2525
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
@@ -49,8 +49,7 @@ public CassandraChatMemoryRepository cassandraChatMemoryRepository(
4949

5050
builder = builder.withKeyspaceName(properties.getKeyspace())
5151
.withTableName(properties.getTable())
52-
.withAssistantColumnName(properties.getAssistantColumn())
53-
.withUserColumnName(properties.getUserColumn());
52+
.withMessagesColumnName(properties.getMessagesColumn());
5453

5554
if (!properties.isInitializeSchema()) {
5655
builder = builder.disallowSchemaChanges();

0 commit comments

Comments
 (0)