diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java index a76fc55c258..cc0cf17e649 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java @@ -45,6 +45,11 @@ * List documents = retriever.retrieve(new Query("example query")); * } * + *

+ * The {@link #FILTER_EXPRESSION} context key can be used to provide a filter expression + * for a specific query. This key accepts either a string representation of a filter + * expression or a {@link Filter.Expression} object directly. + * * @author Thomas Vitale * @since 1.0.0 */ @@ -89,10 +94,27 @@ public List retrieve(Query query) { return this.vectorStore.similaritySearch(searchRequest); } + /** + * Computes the filter expression to use for the current request. + *

+ * The filter expression can be provided in the query context using the + * {@link #FILTER_EXPRESSION} key. This key accepts either a string representation of + * a filter expression or a {@link Filter.Expression} object directly. + *

+ * If no filter expression is provided in the context, the default filter expression + * configured for this retriever is used. + * @param query the query containing potential context with filter expression + * @return the filter expression to use for the request + */ private Filter.Expression computeRequestFilterExpression(Query query) { var contextFilterExpression = query.context().get(FILTER_EXPRESSION); - if (contextFilterExpression != null && StringUtils.hasText(contextFilterExpression.toString())) { - return new FilterExpressionTextParser().parse(contextFilterExpression.toString()); + if (contextFilterExpression != null) { + if (contextFilterExpression instanceof Filter.Expression) { + return (Filter.Expression) contextFilterExpression; + } + else if (StringUtils.hasText(contextFilterExpression.toString())) { + return new FilterExpressionTextParser().parse(contextFilterExpression.toString()); + } } return this.filterExpression.get(); } diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java index bbcc7f65602..fbd52950774 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java @@ -234,6 +234,32 @@ void retrieveWithQueryObjectAndRequestFilterExpression() { .isEqualTo(new FilterExpressionBuilder().eq("location", "Rivendell").build()); } + @Test + void retrieveWithQueryObjectAndFilterExpressionObject() { + var mockVectorStore = mock(VectorStore.class); + var documentRetriever = VectorStoreDocumentRetriever.builder().vectorStore(mockVectorStore).build(); + + // Create a Filter.Expression object directly + var filterExpression = new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell")); + + var query = Query.builder() + .text("test query") + .context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filterExpression)) + .build(); + documentRetriever.retrieve(query); + + // Verify the mock interaction + var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); + verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture()); + + // Verify the search request + var searchRequest = searchRequestCaptor.getValue(); + assertThat(searchRequest.getQuery()).isEqualTo("test query"); + assertThat(searchRequest.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); + assertThat(searchRequest.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K); + assertThat(searchRequest.getFilterExpression()).isEqualTo(filterExpression); + } + static final class TenantContextHolder { private static final ThreadLocal tenantIdentifier = new ThreadLocal<>();