Skip to content

Commit bbaa9b4

Browse files
committed
Support DocumentPostProcessors in RAG Advisor
The DocumentPostProcessor is one of the modular RAG components introduce in M8. You can now use this API from within the RetrievalAugmentationAdvisor to post-process the retrieved documents before passing them to the model. For example, you can use such an interface to perform re-ranking of the retrieved documents based on their relevance to the query, remove irrelevant or redundant documents, or compress the content of each document to reduce noise and redundancy. Signed-off-by: Thomas Vitale <[email protected]>
1 parent 09a6a6e commit bbaa9b4

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ PromptTemplate customPromptTemplate = PromptTemplate.builder()
119119

120120
NOTE: The `QuestionAnswerAdvisor.Builder.userTextAdvise()` method is deprecated in favor of using `.promptTemplate()` for more flexible customization.
121121

122-
=== RetrievalAugmentationAdvisor (Incubating)
122+
=== RetrievalAugmentationAdvisor
123123

124124
Spring AI includes a xref:api/retrieval-augmented-generation.adoc#modules[library of RAG modules] that you can use to build your own RAG flows.
125125
The `RetrievalAugmentationAdvisor` is an `Advisor` providing an out-of-the-box implementation for the most common RAG flows,
@@ -211,6 +211,8 @@ String answer = chatClient.prompt()
211211
.content();
212212
----
213213

214+
You can also use the `DocumentPostProcessor` API to post-process the retrieved documents before passing them to the model. For example, you can use such an interface to perform re-ranking of the retrieved documents based on their relevance to the query, remove irrelevant or redundant documents, or compress the content of each document to reduce noise and redundancy.
215+
214216
[[modules]]
215217
== Modules
216218

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,32 @@ void ragWithMultiQuery() {
261261
evaluateRelevancy(question, chatResponse);
262262
}
263263

264+
@Test
265+
void ragWithDocumentPostProcessor() {
266+
String question = "Where does the adventure of Anacletus and Birba take place?";
267+
268+
RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
269+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
270+
.documentPostProcessors((query, documents) -> List
271+
.of(Document.builder().text("The adventure of Anacletus and Birba takes place in Molise").build()))
272+
.build();
273+
274+
ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel)
275+
.build()
276+
.prompt(question)
277+
.advisors(ragAdvisor)
278+
.call()
279+
.chatResponse();
280+
281+
assertThat(chatResponse).isNotNull();
282+
283+
String response = chatResponse.getResult().getOutput().getText();
284+
System.out.println(response);
285+
assertThat(response).containsIgnoringCase("Molise");
286+
287+
evaluateRelevancy(question, chatResponse);
288+
}
289+
264290
private void evaluateRelevancy(String question, ChatResponse chatResponse) {
265291
EvaluationRequest evaluationRequest = new EvaluationRequest(question,
266292
chatResponse.getMetadata().get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT),

spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.rag.Query;
3535
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
3636
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
37+
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
3738
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
3839
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
3940
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
@@ -70,6 +71,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
7071

7172
private final DocumentJoiner documentJoiner;
7273

74+
private final List<DocumentPostProcessor> documentPostProcessors;
75+
7376
private final QueryAugmenter queryAugmenter;
7477

7578
private final TaskExecutor taskExecutor;
@@ -80,14 +83,16 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
8083

8184
private RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
8285
@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
83-
@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
84-
@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
86+
@Nullable DocumentJoiner documentJoiner, @Nullable List<DocumentPostProcessor> documentPostProcessors,
87+
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
88+
@Nullable Integer order) {
8589
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
8690
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
8791
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
8892
this.queryExpander = queryExpander;
8993
this.documentRetriever = documentRetriever;
9094
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
95+
this.documentPostProcessors = documentPostProcessors != null ? documentPostProcessors : List.of();
9196
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
9297
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
9398
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
@@ -130,6 +135,11 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable A
130135
// 4. Combine documents retrieved based on multiple queries and from multiple data
131136
// sources.
132137
List<Document> documents = this.documentJoiner.join(documentsForQuery);
138+
139+
// 5. Post-process the documents.
140+
for (var documentPostProcessor : this.documentPostProcessors) {
141+
documents = documentPostProcessor.process(originalQuery, documents);
142+
}
133143
context.put(DOCUMENT_CONTEXT, documents);
134144

135145
// 5. Augment user query with the document contextual data.
@@ -197,6 +207,8 @@ public static final class Builder {
197207

198208
private DocumentJoiner documentJoiner;
199209

210+
private List<DocumentPostProcessor> documentPostProcessors;
211+
200212
private QueryAugmenter queryAugmenter;
201213

202214
private TaskExecutor taskExecutor;
@@ -209,11 +221,14 @@ private Builder() {
209221
}
210222

211223
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
224+
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
212225
this.queryTransformers = queryTransformers;
213226
return this;
214227
}
215228

216229
public Builder queryTransformers(QueryTransformer... queryTransformers) {
230+
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
231+
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
217232
this.queryTransformers = Arrays.asList(queryTransformers);
218233
return this;
219234
}
@@ -233,6 +248,19 @@ public Builder documentJoiner(DocumentJoiner documentJoiner) {
233248
return this;
234249
}
235250

251+
public Builder documentPostProcessors(List<DocumentPostProcessor> documentPostProcessors) {
252+
Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
253+
this.documentPostProcessors = documentPostProcessors;
254+
return this;
255+
}
256+
257+
public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) {
258+
Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null");
259+
Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
260+
this.documentPostProcessors = Arrays.asList(documentPostProcessors);
261+
return this;
262+
}
263+
236264
public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
237265
this.queryAugmenter = queryAugmenter;
238266
return this;
@@ -255,7 +283,8 @@ public Builder order(Integer order) {
255283

256284
public RetrievalAugmentationAdvisor build() {
257285
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
258-
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
286+
this.documentJoiner, this.documentPostProcessors, this.queryAugmenter, this.taskExecutor,
287+
this.scheduler, this.order);
259288
}
260289

261290
}

0 commit comments

Comments
 (0)