Skip to content

Commit 8b1b897

Browse files
committed
Modular RAG: Orchestration and Post-Retrieval
Pre-Retrieval: * Consolidated naming and documentation Retrieval: * Consolidated naming and documentation * Introduced DocumentJoiner sub-module and CompositionDocumentJoiner operator Post-Retrieval: * Introduced main interfaces for sub-modules. Implementation waiting for missing features in Document APIs Orchestration: * Introduced QueryRouter sub-module and AllDocumentRetrieversQueryRouter operator Generation: * Consolidated naming and documentation Advisor: * Introduced BaseAdvisor to reduce boilerplate when implementing Advisors * Extended RetrievalAugmentationAdvisor to include the new sub-modules Relates to #gh-1603
1 parent d2e9e55 commit 8b1b897

File tree

43 files changed

+1171
-244
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1171
-244
lines changed

spring-ai-core/pom.xml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@
9696
<artifactId>micrometer-core</artifactId>
9797
</dependency>
9898

99+
<dependency>
100+
<groupId>io.micrometer</groupId>
101+
<artifactId>context-propagation</artifactId>
102+
</dependency>
103+
99104
<dependency>
100105
<groupId>io.micrometer</groupId>
101106
<artifactId>micrometer-tracing-bridge-otel</artifactId>
@@ -195,4 +200,4 @@
195200
</profiles>
196201

197202

198-
</project>
203+
</project>

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 127 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -16,71 +16,81 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19-
import java.util.ArrayList;
2019
import java.util.Arrays;
2120
import java.util.HashMap;
2221
import java.util.List;
2322
import java.util.Map;
24-
import java.util.function.Predicate;
25-
26-
import reactor.core.publisher.Flux;
27-
import reactor.core.publisher.Mono;
28-
import reactor.core.scheduler.Schedulers;
23+
import java.util.concurrent.CompletableFuture;
24+
import java.util.stream.Collectors;
2925

3026
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3127
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
32-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
33-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
34-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
35-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
28+
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
3629
import org.springframework.ai.chat.model.ChatResponse;
3730
import org.springframework.ai.chat.prompt.PromptTemplate;
3831
import org.springframework.ai.document.Document;
3932
import org.springframework.ai.rag.Query;
40-
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
41-
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
42-
import org.springframework.ai.rag.augmentation.QueryAugmentor;
33+
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
34+
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
35+
import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter;
36+
import org.springframework.ai.rag.orchestration.routing.QueryRouter;
37+
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
38+
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
39+
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
40+
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
4341
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
42+
import org.springframework.core.task.TaskExecutor;
43+
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
4444
import org.springframework.lang.Nullable;
45+
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
4546
import org.springframework.util.Assert;
46-
import org.springframework.util.StringUtils;
47+
import reactor.core.scheduler.Scheduler;
4748

4849
/**
4950
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
5051
* building blocks defined in the {@link org.springframework.ai.rag} package and following
5152
* the Modular RAG Architecture.
52-
* <p>
53-
* It's the successor of the {@link QuestionAnswerAdvisor}.
5453
*
5554
* @author Christian Tzolov
5655
* @author Thomas Vitale
5756
* @since 1.0.0
5857
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5958
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
6059
*/
61-
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
60+
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6261

6362
public static final String DOCUMENT_CONTEXT = "rag_document_context";
6463

6564
private final List<QueryTransformer> queryTransformers;
6665

67-
private final DocumentRetriever documentRetriever;
66+
@Nullable
67+
private final QueryExpander queryExpander;
68+
69+
private final QueryRouter queryRouter;
70+
71+
private final DocumentJoiner documentJoiner;
72+
73+
private final QueryAugmenter queryAugmenter;
6874

69-
private final QueryAugmentor queryAugmentor;
75+
private final TaskExecutor taskExecutor;
7076

71-
private final boolean protectFromBlocking;
77+
private final Scheduler scheduler;
7278

7379
private final int order;
7480

75-
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever,
76-
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
77-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
81+
public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
82+
@Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner,
83+
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
84+
@Nullable Integer order) {
85+
Assert.notNull(queryRouter, "queryRouter cannot be null");
7886
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
79-
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
80-
this.queryTransformers = queryTransformers;
81-
this.documentRetriever = documentRetriever;
82-
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
83-
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true;
87+
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
88+
this.queryExpander = queryExpander;
89+
this.queryRouter = queryRouter;
90+
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
91+
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
92+
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
93+
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
8494
this.order = order != null ? order : 0;
8595
}
8696

@@ -89,41 +99,7 @@ public static Builder builder() {
8999
}
90100

91101
@Override
92-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
93-
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
94-
Assert.notNull(chain, "chain cannot be null");
95-
96-
AdvisedRequest processedAdvisedRequest = before(advisedRequest);
97-
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
98-
return after(advisedResponse);
99-
}
100-
101-
@Override
102-
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
103-
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
104-
Assert.notNull(chain, "chain cannot be null");
105-
106-
// This can be executed by both blocking and non-blocking Threads
107-
// E.g. a command line or Tomcat blocking Thread implementation
108-
// or by a WebFlux dispatch in a non-blocking manner.
109-
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
110-
// @formatter:off
111-
Mono.just(advisedRequest)
112-
.publishOn(Schedulers.boundedElastic())
113-
.map(this::before)
114-
.flatMapMany(chain::nextAroundStream)
115-
: chain.nextAroundStream(before(advisedRequest));
116-
// @formatter:on
117-
118-
return advisedResponses.map(ar -> {
119-
if (onFinishReason().test(ar)) {
120-
ar = after(ar);
121-
}
122-
return ar;
123-
});
124-
}
125-
126-
private AdvisedRequest before(AdvisedRequest request) {
102+
public AdvisedRequest before(AdvisedRequest request) {
127103
Map<String, Object> context = new HashMap<>(request.adviseContext());
128104

129105
// 0. Create a query from the user text and parameters.
@@ -135,17 +111,47 @@ private AdvisedRequest before(AdvisedRequest request) {
135111
transformedQuery = queryTransformer.apply(transformedQuery);
136112
}
137113

138-
// 2. Retrieve similar documents for the original query.
139-
List<Document> documents = this.documentRetriever.retrieve(transformedQuery);
114+
// 2. Expand query into one or multiple queries.
115+
List<Query> expandedQueries = queryExpander != null ? queryExpander.expand(transformedQuery)
116+
: List.of(transformedQuery);
117+
118+
// 3. Get similar documents for each query.
119+
Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
120+
.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), taskExecutor))
121+
.toList()
122+
.stream()
123+
.map(CompletableFuture::join)
124+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
125+
126+
// 4. Combine documents retrieved based on multiple queries and from multiple data
127+
// sources.
128+
List<Document> documents = documentJoiner.join(documentsForQuery);
140129
context.put(DOCUMENT_CONTEXT, documents);
141130

142-
// 3. Augment user query with the document contextual data.
143-
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents);
131+
// 5. Augment user query with the document contextual data.
132+
Query augmentedQuery = queryAugmenter.augment(originalQuery, documents);
144133

134+
// 6. Update advised request with augmented prompt.
145135
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
146136
}
147137

148-
private AdvisedResponse after(AdvisedResponse advisedResponse) {
138+
/**
139+
* Processes a single query by routing it to document retrievers and collecting
140+
* documents.
141+
*/
142+
private Map.Entry<Query, List<List<Document>>> getDocumentsForQuery(Query query) {
143+
List<DocumentRetriever> retrievers = queryRouter.route(query);
144+
List<List<Document>> documents = retrievers.stream()
145+
.map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), taskExecutor))
146+
.toList()
147+
.stream()
148+
.map(CompletableFuture::join)
149+
.toList();
150+
return Map.entry(query, documents);
151+
}
152+
153+
@Override
154+
public AdvisedResponse after(AdvisedResponse advisedResponse) {
149155
ChatResponse.Builder chatResponseBuilder;
150156
if (advisedResponse.response() == null) {
151157
chatResponseBuilder = ChatResponse.builder();
@@ -157,66 +163,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
157163
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
158164
}
159165

160-
private Predicate<AdvisedResponse> onFinishReason() {
161-
return advisedResponse -> {
162-
ChatResponse chatResponse = advisedResponse.response();
163-
return chatResponse != null && chatResponse.getResults() != null
164-
&& chatResponse.getResults()
165-
.stream()
166-
.anyMatch(result -> result != null && result.getMetadata() != null
167-
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
168-
};
169-
}
170-
171166
@Override
172-
public String getName() {
173-
return this.getClass().getSimpleName();
167+
public Scheduler getScheduler() {
168+
return scheduler;
174169
}
175170

176171
@Override
177172
public int getOrder() {
178173
return this.order;
179174
}
180175

176+
private static TaskExecutor buildDefaultTaskExecutor() {
177+
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
178+
taskExecutor.setThreadNamePrefix("ai-advisor-");
179+
taskExecutor.setCorePoolSize(4);
180+
taskExecutor.setMaxPoolSize(16);
181+
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
182+
taskExecutor.initialize();
183+
return taskExecutor;
184+
}
185+
181186
public static final class Builder {
182187

183-
private final List<QueryTransformer> queryTransformers = new ArrayList<>();
188+
private List<QueryTransformer> queryTransformers;
189+
190+
private QueryExpander queryExpander;
191+
192+
private QueryRouter queryRouter;
193+
194+
private DocumentJoiner documentJoiner;
184195

185-
private DocumentRetriever documentRetriever;
196+
private QueryAugmenter queryAugmenter;
186197

187-
private QueryAugmentor queryAugmentor;
198+
private TaskExecutor taskExecutor;
188199

189-
private Boolean protectFromBlocking;
200+
private Scheduler scheduler;
190201

191202
private Integer order;
192203

193204
private Builder() {
194205
}
195206

196207
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
197-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
198-
this.queryTransformers.addAll(queryTransformers);
208+
this.queryTransformers = queryTransformers;
199209
return this;
200210
}
201211

202212
public Builder queryTransformers(QueryTransformer... queryTransformers) {
203-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
204-
this.queryTransformers.addAll(Arrays.asList(queryTransformers));
213+
this.queryTransformers = Arrays.asList(queryTransformers);
214+
return this;
215+
}
216+
217+
public Builder queryExpander(QueryExpander queryExpander) {
218+
this.queryExpander = queryExpander;
219+
return this;
220+
}
221+
222+
public Builder queryRouter(QueryRouter queryRouter) {
223+
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
224+
this.queryRouter = queryRouter;
205225
return this;
206226
}
207227

208228
public Builder documentRetriever(DocumentRetriever documentRetriever) {
209-
this.documentRetriever = documentRetriever;
229+
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
230+
this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build();
231+
return this;
232+
}
233+
234+
public Builder documentJoiner(DocumentJoiner documentJoiner) {
235+
this.documentJoiner = documentJoiner;
236+
return this;
237+
}
238+
239+
public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
240+
this.queryAugmenter = queryAugmenter;
210241
return this;
211242
}
212243

213-
public Builder queryAugmentor(QueryAugmentor queryAugmentor) {
214-
this.queryAugmentor = queryAugmentor;
244+
public Builder taskExecutor(TaskExecutor taskExecutor) {
245+
this.taskExecutor = taskExecutor;
215246
return this;
216247
}
217248

218-
public Builder protectFromBlocking(Boolean protectFromBlocking) {
219-
this.protectFromBlocking = protectFromBlocking;
249+
public Builder scheduler(Scheduler scheduler) {
250+
this.scheduler = scheduler;
220251
return this;
221252
}
222253

@@ -226,8 +257,8 @@ public Builder order(Integer order) {
226257
}
227258

228259
public RetrievalAugmentationAdvisor build() {
229-
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor,
230-
this.protectFromBlocking, this.order);
260+
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter,
261+
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
231262
}
232263

233264
}

0 commit comments

Comments
 (0)