Skip to content

Commit 002beef

Browse files
committed
support setting a similarity threshold in VectorStoreChatMemoryAdvisor
1 parent 3919204 commit 002beef

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
6464

6565
private static final int DEFAULT_TOP_K = 20;
6666

67+
private static final double DEFAULT_SIMILARITY_THRESHOLD = 0;
68+
6769
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6870
{instructions}
6971
@@ -79,6 +81,8 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
7981

8082
private final int defaultTopK;
8183

84+
private final double defaultSimilarityThreshold;
85+
8286
private final String defaultConversationId;
8387

8488
private final int order;
@@ -88,14 +92,17 @@ public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor
8892
private final VectorStore vectorStore;
8993

9094
private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK,
91-
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
95+
double defaultSimilarityThreshold, String defaultConversationId, int order, Scheduler scheduler,
96+
VectorStore vectorStore) {
9297
Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
9398
Assert.isTrue(defaultTopK > 0, "topK must be greater than 0");
99+
Assert.isTrue(defaultSimilarityThreshold >= 0, "similarityThreshold must be equal to or greater than 0");
94100
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
95101
Assert.notNull(scheduler, "scheduler cannot be null");
96102
Assert.notNull(vectorStore, "vectorStore cannot be null");
97103
this.systemPromptTemplate = systemPromptTemplate;
98104
this.defaultTopK = defaultTopK;
105+
this.defaultSimilarityThreshold = defaultSimilarityThreshold;
99106
this.defaultConversationId = defaultConversationId;
100107
this.order = order;
101108
this.scheduler = scheduler;
@@ -221,6 +228,8 @@ public static class Builder {
221228

222229
private Integer defaultTopK = DEFAULT_TOP_K;
223230

231+
private Double defaultSimilarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
232+
224233
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
225234

226235
private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
@@ -257,6 +266,17 @@ public Builder defaultTopK(int defaultTopK) {
257266
return this;
258267
}
259268

269+
/**
270+
* Set the similarity threshold for retrieving relevant documents.
271+
* @param defaultSimilarityThreshold the required similarity for documents to
272+
* retrieve
273+
* @return this builder
274+
*/
275+
public Builder defaultSimilarityThreshold(Double defaultSimilarityThreshold) {
276+
this.defaultSimilarityThreshold = defaultSimilarityThreshold;
277+
return this;
278+
}
279+
260280
/**
261281
* Set the conversation id.
262282
* @param conversationId the conversation id
@@ -287,8 +307,8 @@ public Builder order(int order) {
287307
* @return the advisor
288308
*/
289309
public VectorStoreChatMemoryAdvisor build() {
290-
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, this.conversationId,
291-
this.order, this.scheduler, this.vectorStore);
310+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK,
311+
this.defaultSimilarityThreshold, this.conversationId, this.order, this.scheduler, this.vectorStore);
292312
}
293313

294314
}

advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisorTests.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,14 @@ void whenDefaultTopKIsNegativeThenThrow() {
9191
.hasMessageContaining("topK must be greater than 0");
9292
}
9393

94+
@Test
95+
void whenDefaultSimilarityThresholdIsLessThanZeroThenThrow() {
96+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
97+
98+
assertThatThrownBy(
99+
() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultSimilarityThreshold(-0.1).build())
100+
.isInstanceOf(IllegalArgumentException.class)
101+
.hasMessageContaining("similarityThreshold must be equal to or greater than 0");
102+
}
103+
94104
}

0 commit comments

Comments
 (0)