Skip to content

Commit a1b1be6

Browse files
committed
Request-time filter expressions for RAG
When using the RetrievalAugmentationAdvisor with the VectorStoreDocumentRetriever, it’s now possible to provide a filter expression at request-time as an advisor context variable with key VectorStoreDocumentRetriever.FILTER_EXPRESSION. Fixes gh-1776 Signed-off-by: Thomas Vitale <[email protected]>
1 parent bed1db3 commit a1b1be6

File tree

7 files changed

+133
-16
lines changed

7 files changed

+133
-16
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public AdvisedRequest before(AdvisedRequest request) {
106106
Query originalQuery = Query.builder()
107107
.text(new PromptTemplate(request.userText(), request.userParams()).render())
108108
.history(request.messages())
109+
.context(context)
109110
.build();
110111

111112
// 1. Transform original user query based on a chain of query transformers.

spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import org.springframework.ai.vectorstore.SearchRequest;
2525
import org.springframework.ai.vectorstore.VectorStore;
2626
import org.springframework.ai.vectorstore.filter.Filter;
27+
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
2728
import org.springframework.lang.Nullable;
2829
import org.springframework.util.Assert;
30+
import org.springframework.util.StringUtils;
2931

3032
/**
3133
* Retrieves documents from a vector store that are semantically similar to the input
@@ -48,6 +50,8 @@
4850
*/
4951
public final class VectorStoreDocumentRetriever implements DocumentRetriever {
5052

53+
public static final String FILTER_EXPRESSION = "vector_store_filter_expression";
54+
5155
private final VectorStore vectorStore;
5256

5357
private final Double similarityThreshold;
@@ -75,15 +79,24 @@ public VectorStoreDocumentRetriever(VectorStore vectorStore, @Nullable Double si
7579
@Override
7680
public List<Document> retrieve(Query query) {
7781
Assert.notNull(query, "query cannot be null");
82+
var requestFilterExpression = computeRequestFilterExpression(query);
7883
var searchRequest = SearchRequest.builder()
7984
.query(query.text())
80-
.filterExpression(this.filterExpression.get())
85+
.filterExpression(requestFilterExpression)
8186
.similarityThreshold(this.similarityThreshold)
8287
.topK(this.topK)
8388
.build();
8489
return this.vectorStore.similaritySearch(searchRequest);
8590
}
8691

92+
private Filter.Expression computeRequestFilterExpression(Query query) {
93+
var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
94+
if (contextFilterExpression != null && StringUtils.hasText(contextFilterExpression.toString())) {
95+
return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
96+
}
97+
return this.filterExpression.get();
98+
}
99+
87100
public static Builder builder() {
88101
return new Builder();
89102
}

spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,30 @@ void retrieveWithQueryObjectAndDefaultValues() {
210210
assertThat(result).hasSize(2).containsExactlyElementsOf(mockDocuments);
211211
}
212212

213+
@Test
214+
void retrieveWithQueryObjectAndRequestFilterExpression() {
215+
var mockVectorStore = mock(VectorStore.class);
216+
var documentRetriever = VectorStoreDocumentRetriever.builder().vectorStore(mockVectorStore).build();
217+
218+
var query = Query.builder()
219+
.text("test query")
220+
.context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "location == 'Rivendell'"))
221+
.build();
222+
documentRetriever.retrieve(query);
223+
224+
// Verify the mock interaction
225+
var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
226+
verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture());
227+
228+
// Verify the search request
229+
var searchRequest = searchRequestCaptor.getValue();
230+
assertThat(searchRequest.getQuery()).isEqualTo("test query");
231+
assertThat(searchRequest.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
232+
assertThat(searchRequest.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K);
233+
assertThat(searchRequest.getFilterExpression())
234+
.isEqualTo(new FilterExpressionBuilder().eq("location", "Rivendell").build());
235+
}
236+
213237
static final class TenantContextHolder {
214238

215239
private static final ThreadLocal<String> tenantIdentifier = new ThreadLocal<>();

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,12 @@ This filter expression can be configured when creating the `QuestionAnswerAdviso
3939

4040
Here is how to create an instance of `QuestionAnswerAdvisor` where the threshold is `0.8` and to return the top `6` reulsts.
4141

42-
4342
[source,java]
4443
----
4544
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore,
4645
SearchRequest.builder().similarityThreshold(0.8d).topK(6).build());
4746
----
4847

49-
50-
5148
==== Dynamic Filter Expressions
5249

5350
Update the `SearchRequest` filter expression at runtime using the `FILTER_EXPRESSION` advisor context parameter:
@@ -118,6 +115,29 @@ String answer = chatClient.prompt()
118115
.content();
119116
----
120117

118+
The `VectorStoreDocumentRetriever` accepts a `FilterExpression` to filter the search results based on metadata.
119+
You can provide one when instantiating the `VectorStoreDocumentRetriever` or at runtime per request,
120+
using the `FILTER_EXPRESSION` advisor context parameter.
121+
122+
[source,java]
123+
----
124+
Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
125+
.documentRetriever(VectorStoreDocumentRetriever.builder()
126+
.similarityThreshold(0.50)
127+
.vectorStore(vectorStore)
128+
.build())
129+
.build();
130+
131+
String answer = chatClient.prompt()
132+
.advisors(retrievalAugmentationAdvisor)
133+
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "type == 'Spring'"))
134+
.user(question)
135+
.call()
136+
.content();
137+
----
138+
139+
See xref:api/retrieval-augmented-generation.adoc#_vectorstoredocumentretriever for more information.
140+
121141
===== Advanced RAG
122142

123143
[source,java]
@@ -298,6 +318,18 @@ DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
298318
List<Document> documents = retriever.retrieve(new Query("What are the KPIs for the next semester?"));
299319
----
300320

321+
You can also provide a request-specific filter expression via the `Query` API, using the `FILTER_EXPRESSION` parameter.
322+
If both the request-specific and the retriever-specific filter expressions are provided, the request-specific filter expression takes precedence.
323+
324+
[source,java]
325+
----
326+
Query query = Query.builder()
327+
.text("Who is Anacletus?")
328+
.context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "location == 'Whispering Woods'"))
329+
.build();
330+
List<Document> retrievedDocuments = documentRetriever.retrieve(query);
331+
----
332+
301333
==== Document Join
302334

303335
A component for combining documents retrieved based on multiple queries and from multiple data sources into

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,29 @@ void ragBasic() {
108108
evaluateRelevancy(question, chatResponse);
109109
}
110110

111+
@Test
112+
void ragWithRequestFilter() {
113+
String question = "Where does the adventure of Anacletus and Birba take place?";
114+
115+
RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
116+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
117+
.build();
118+
119+
ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel)
120+
.build()
121+
.prompt(question)
122+
.advisors(ragAdvisor)
123+
.advisors(a -> a.param(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "location == 'Italy'"))
124+
.call()
125+
.chatResponse();
126+
127+
assertThat(chatResponse).isNotNull();
128+
// No documents retrieved since the filter expression matches none of the
129+
// documents in the vector store.
130+
assertThat((String) chatResponse.getResult().getMetadata().get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT))
131+
.isNull();
132+
}
133+
111134
@Test
112135
void ragWithCompression() {
113136
MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class RewriteQueryTransformerIT {
4343

4444
@Test
4545
void whenTransformerWithDefaults() {
46-
Query query = new Query("I'm studying machine learning. What is an LLM?");
46+
Query query = new Query("What are the main tourist attractions in L.A.?");
4747
QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
4848
.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
4949
.build();
@@ -52,7 +52,7 @@ void whenTransformerWithDefaults() {
5252

5353
assertThat(transformedQuery).isNotNull();
5454
System.out.println(transformedQuery);
55-
assertThat(transformedQuery.text()).containsIgnoringCase("model");
55+
assertThat(transformedQuery.text()).containsIgnoringCase("Angeles");
5656
}
5757

5858
}

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,21 @@
4646
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
4747
class VectorStoreDocumentRetrieverIT {
4848

49-
private static final Map<String, Document> documents = Map.of("1", new Document(
50-
"Anacletus was a majestic snowy owl with unusually bright golden eyes and distinctive black speckles across his wings.",
51-
Map.of("location", "Whispering Woods")), "2",
52-
new Document(
49+
// @formatter:off
50+
private static final Map<String, Document> documents = Map.of(
51+
"1", new Document(
52+
"Anacletus was a majestic snowy owl with unusually bright golden eyes and distinctive black speckles across his wings.",
53+
Map.of("location", "Whispering Woods")),
54+
"2", new Document(
5355
"Anacletus made his home in an ancient hollow oak tree deep within the Whispering Woods, where local villagers often heard his haunting calls at midnight.",
5456
Map.of("location", "Whispering Woods")),
55-
"3",
56-
new Document(
57+
"3", new Document(
5758
"Despite being a nocturnal hunter like other owls, Anacletus had developed a peculiar habit of collecting shiny objects, especially lost coins and jewelry that glinted in the moonlight.",
5859
Map.of()),
59-
"4",
60-
new Document(
60+
"4", new Document(
6161
"Birba was a plump Siamese cat with mismatched eyes - one blue and one green - who spent her days lounging on velvet cushions and judging everyone with a perpetual look of disdain.",
6262
Map.of("location", "Alfea")));
63+
// @formatter:on
6364

6465
@Autowired
6566
PgVectorStore pgVectorStore;
@@ -75,7 +76,7 @@ void tearDown() {
7576
}
7677

7778
@Test
78-
void withFilter() {
79+
void withBuildFilter() {
7980
DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder()
8081
.vectorStore(this.pgVectorStore)
8182
.similarityThreshold(0.50)
@@ -95,7 +96,7 @@ void withFilter() {
9596
}
9697

9798
@Test
98-
void withNoFilter() {
99+
void withNoBuildFilter() {
99100
DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder()
100101
.vectorStore(this.pgVectorStore)
101102
.similarityThreshold(0.50)
@@ -110,4 +111,27 @@ void withNoFilter() {
110111
assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("3").getId()));
111112
}
112113

114+
@Test
115+
void withRequestFilter() {
116+
DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder()
117+
.vectorStore(this.pgVectorStore)
118+
.similarityThreshold(0.50)
119+
.topK(3)
120+
.build();
121+
122+
Query query = Query.builder()
123+
.text("Who is Anacletus?")
124+
.context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, "location == 'Whispering Woods'"))
125+
.build();
126+
List<Document> retrievedDocuments = documentRetriever.retrieve(query);
127+
128+
assertThat(retrievedDocuments).hasSize(2);
129+
assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("1").getId()));
130+
assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("2").getId()));
131+
132+
// No request filter expression applied, so full access to all documents.
133+
retrievedDocuments = documentRetriever.retrieve(new Query("Who is Birba?"));
134+
assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("4").getId()));
135+
}
136+
113137
}

0 commit comments

Comments
 (0)