Skip to content

Commit c2e65ff

Browse files
committed
support setting similarity threshold in VectorStoreChatMemoryAdvisor
Signed-off-by: Peter Keeler <[email protected]>
1 parent 3919204 commit c2e65ff

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
*/
5757
public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
5858

59+
public static final String SIMILARITY_THRESHOLD = "chat_memory_vector_store_similarity_threshold";
60+
5961
public static final String TOP_K = "chat_memory_vector_store_top_k";
6062

6163
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
@@ -64,6 +66,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
6466

6567
private static final int DEFAULT_TOP_K = 20;
6668

69+
private static final double DEFAULT_SIMILARITY_THRESHOLD = 0;
70+
6771
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6872
{instructions}
6973
@@ -79,6 +83,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
7983

8084
private final int defaultTopK;
8185

86+
private final double defaultSimilarityThreshold;
87+
8288
private final String defaultConversationId;
8389

8490
private final int order;
@@ -88,14 +94,17 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
8894
private final VectorStore vectorStore;
8995

9096
private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK,
91-
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
97+
double defaultSimilarityThreshold, String defaultConversationId, int order, Scheduler scheduler,
98+
VectorStore vectorStore) {
9299
Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
93100
Assert.isTrue(defaultTopK > 0, "topK must be greater than 0");
101+
Assert.isTrue(defaultSimilarityThreshold >= 0, "similarityThreshold must be equal to or greater than 0");
94102
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
95103
Assert.notNull(scheduler, "scheduler cannot be null");
96104
Assert.notNull(vectorStore, "vectorStore cannot be null");
97105
this.systemPromptTemplate = systemPromptTemplate;
98106
this.defaultTopK = defaultTopK;
107+
this.defaultSimilarityThreshold = defaultSimilarityThreshold;
99108
this.defaultConversationId = defaultConversationId;
100109
this.order = order;
101110
this.scheduler = scheduler;
@@ -121,10 +130,12 @@ public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorC
121130
String conversationId = getConversationId(request.context(), this.defaultConversationId);
122131
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
123132
int topK = getChatMemoryTopK(request.context());
133+
double similarityThreshold = getChatMemorySimilarityThreshold(request.context());
124134
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
125135
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
126136
.query(query)
127137
.topK(topK)
138+
.similarityThreshold(similarityThreshold)
128139
.filterExpression(filter)
129140
.build();
130141
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
@@ -156,6 +167,11 @@ private int getChatMemoryTopK(Map<String, Object> context) {
156167
return context.containsKey(TOP_K) ? Integer.parseInt(context.get(TOP_K).toString()) : this.defaultTopK;
157168
}
158169

170+
private double getChatMemorySimilarityThreshold(Map<String, Object> context) {
171+
return context.containsKey(SIMILARITY_THRESHOLD)
172+
? Double.parseDouble(context.get(SIMILARITY_THRESHOLD).toString()) : this.defaultSimilarityThreshold;
173+
}
174+
159175
@Override
160176
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
161177
List<Message> assistantMessages = new ArrayList<>();
@@ -221,6 +237,8 @@ public static class Builder {
221237

222238
private Integer defaultTopK = DEFAULT_TOP_K;
223239

240+
private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
241+
224242
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
225243

226244
private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
@@ -257,6 +275,17 @@ public Builder defaultTopK(int defaultTopK) {
257275
return this;
258276
}
259277

278+
/**
279+
* Set the similarity threshold for retrieving relevant documents.
280+
* @param defaultSimilarityThreshold the required similarity for documents to
281+
* retrieve
282+
* @return this builder
283+
*/
284+
public Builder defaultSimilarityThreshold(Double defaultSimilarityThreshold) {
285+
this.defaultSimilarityThreshold = defaultSimilarityThreshold;
286+
return this;
287+
}
288+
260289
/**
261290
* Set the conversation id.
262291
* @param conversationId the conversation id
@@ -287,8 +316,8 @@ public Builder order(int order) {
287316
* @return the advisor
288317
*/
289318
public VectorStoreChatMemoryAdvisor build() {
290-
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, this.conversationId,
291-
this.order, this.scheduler, this.vectorStore);
319+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK,
320+
this.defaultSimilarityThreshold, this.conversationId, this.order, this.scheduler, this.vectorStore);
292321
}
293322

294323
}

advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,14 @@ void whenDefaultTopKIsNegativeThenThrow() {
9191
.hasMessageContaining("topK must be greater than 0");
9292
}
9393

94+
@Test
95+
void whenDefaultSimilarityThresholdIsLessThanZeroThenThrow() {
96+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
97+
98+
assertThatThrownBy(
99+
() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultSimilarityThreshold(-0.1).build())
100+
.isInstanceOf(IllegalArgumentException.class)
101+
.hasMessageContaining("similarityThreshold must be equal to or greater than 0");
102+
}
103+
94104
}

0 commit comments

Comments
 (0)