1616
1717package org .springframework .ai .chat .client .advisor ;
1818
19- import java .util .ArrayList ;
2019import java .util .Arrays ;
2120import java .util .HashMap ;
2221import java .util .List ;
2322import java .util .Map ;
24- import java .util .function .Predicate ;
23+ import java .util .concurrent .CompletableFuture ;
24+ import java .util .stream .Collectors ;
2525
26- import reactor .core .publisher .Flux ;
27- import reactor .core .publisher .Mono ;
28- import reactor .core .scheduler .Schedulers ;
26+ import reactor .core .scheduler .Scheduler ;
2927
3028import org .springframework .ai .chat .client .advisor .api .AdvisedRequest ;
3129import org .springframework .ai .chat .client .advisor .api .AdvisedResponse ;
32- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisor ;
33- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisorChain ;
34- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisor ;
35- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisorChain ;
30+ import org .springframework .ai .chat .client .advisor .api .BaseAdvisor ;
3631import org .springframework .ai .chat .model .ChatResponse ;
3732import org .springframework .ai .chat .prompt .PromptTemplate ;
3833import org .springframework .ai .document .Document ;
3934import org .springframework .ai .rag .Query ;
40- import org .springframework .ai .rag .analysis .query .transformation .QueryTransformer ;
41- import org .springframework .ai .rag .augmentation .ContextualQueryAugmentor ;
42- import org .springframework .ai .rag .augmentation .QueryAugmentor ;
35+ import org .springframework .ai .rag .generation .augmentation .ContextualQueryAugmenter ;
36+ import org .springframework .ai .rag .generation .augmentation .QueryAugmenter ;
37+ import org .springframework .ai .rag .orchestration .routing .AllRetrieversQueryRouter ;
38+ import org .springframework .ai .rag .orchestration .routing .QueryRouter ;
39+ import org .springframework .ai .rag .preretrieval .query .expansion .QueryExpander ;
40+ import org .springframework .ai .rag .preretrieval .query .transformation .QueryTransformer ;
41+ import org .springframework .ai .rag .retrieval .join .ConcatenationDocumentJoiner ;
42+ import org .springframework .ai .rag .retrieval .join .DocumentJoiner ;
4343import org .springframework .ai .rag .retrieval .search .DocumentRetriever ;
44+ import org .springframework .core .task .TaskExecutor ;
45+ import org .springframework .core .task .support .ContextPropagatingTaskDecorator ;
4446import org .springframework .lang .Nullable ;
47+ import org .springframework .scheduling .concurrent .ThreadPoolTaskExecutor ;
4548import org .springframework .util .Assert ;
46- import org .springframework .util .StringUtils ;
4749
4850/**
4951 * Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
5052 * building blocks defined in the {@link org.springframework.ai.rag} package and following
5153 * the Modular RAG Architecture.
52- * <p>
53- * It's the successor of the {@link QuestionAnswerAdvisor}.
5454 *
5555 * @author Christian Tzolov
5656 * @author Thomas Vitale
5757 * @since 1.0.0
5858 * @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5959 * @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
6060 */
61- public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor , StreamAroundAdvisor {
61+ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6262
6363 public static final String DOCUMENT_CONTEXT = "rag_document_context" ;
6464
6565 private final List <QueryTransformer > queryTransformers ;
6666
67- private final DocumentRetriever documentRetriever ;
67+ @ Nullable
68+ private final QueryExpander queryExpander ;
6869
69- private final QueryAugmentor queryAugmentor ;
70+ private final QueryRouter queryRouter ;
7071
71- private final boolean protectFromBlocking ;
72+ private final DocumentJoiner documentJoiner ;
73+
74+ private final QueryAugmenter queryAugmenter ;
75+
76+ private final TaskExecutor taskExecutor ;
77+
78+ private final Scheduler scheduler ;
7279
7380 private final int order ;
7481
75- public RetrievalAugmentationAdvisor (List <QueryTransformer > queryTransformers , DocumentRetriever documentRetriever ,
76- @ Nullable QueryAugmentor queryAugmentor , @ Nullable Boolean protectFromBlocking , @ Nullable Integer order ) {
77- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
82+ public RetrievalAugmentationAdvisor (@ Nullable List <QueryTransformer > queryTransformers ,
83+ @ Nullable QueryExpander queryExpander , QueryRouter queryRouter , @ Nullable DocumentJoiner documentJoiner ,
84+ @ Nullable QueryAugmenter queryAugmenter , @ Nullable TaskExecutor taskExecutor , @ Nullable Scheduler scheduler ,
85+ @ Nullable Integer order ) {
86+ Assert .notNull (queryRouter , "queryRouter cannot be null" );
7887 Assert .noNullElements (queryTransformers , "queryTransformers cannot contain null elements" );
79- Assert .notNull (documentRetriever , "documentRetriever cannot be null" );
80- this .queryTransformers = queryTransformers ;
81- this .documentRetriever = documentRetriever ;
82- this .queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor .builder ().build ();
83- this .protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true ;
88+ this .queryTransformers = queryTransformers != null ? queryTransformers : List .of ();
89+ this .queryExpander = queryExpander ;
90+ this .queryRouter = queryRouter ;
91+ this .documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner ();
92+ this .queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter .builder ().build ();
93+ this .taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor ();
94+ this .scheduler = scheduler != null ? scheduler : BaseAdvisor .DEFAULT_SCHEDULER ;
8495 this .order = order != null ? order : 0 ;
8596 }
8697
@@ -89,41 +100,7 @@ public static Builder builder() {
89100 }
90101
91102 @ Override
92- public AdvisedResponse aroundCall (AdvisedRequest advisedRequest , CallAroundAdvisorChain chain ) {
93- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
94- Assert .notNull (chain , "chain cannot be null" );
95-
96- AdvisedRequest processedAdvisedRequest = before (advisedRequest );
97- AdvisedResponse advisedResponse = chain .nextAroundCall (processedAdvisedRequest );
98- return after (advisedResponse );
99- }
100-
101- @ Override
102- public Flux <AdvisedResponse > aroundStream (AdvisedRequest advisedRequest , StreamAroundAdvisorChain chain ) {
103- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
104- Assert .notNull (chain , "chain cannot be null" );
105-
106- // This can be executed by both blocking and non-blocking Threads
107- // E.g. a command line or Tomcat blocking Thread implementation
108- // or by a WebFlux dispatch in a non-blocking manner.
109- Flux <AdvisedResponse > advisedResponses = (this .protectFromBlocking ) ?
110- // @formatter:off
111- Mono .just (advisedRequest )
112- .publishOn (Schedulers .boundedElastic ())
113- .map (this ::before )
114- .flatMapMany (chain ::nextAroundStream )
115- : chain .nextAroundStream (before (advisedRequest ));
116- // @formatter:on
117-
118- return advisedResponses .map (ar -> {
119- if (onFinishReason ().test (ar )) {
120- ar = after (ar );
121- }
122- return ar ;
123- });
124- }
125-
126- private AdvisedRequest before (AdvisedRequest request ) {
103+ public AdvisedRequest before (AdvisedRequest request ) {
127104 Map <String , Object > context = new HashMap <>(request .adviseContext ());
128105
129106 // 0. Create a query from the user text and parameters.
@@ -135,17 +112,47 @@ private AdvisedRequest before(AdvisedRequest request) {
135112 transformedQuery = queryTransformer .apply (transformedQuery );
136113 }
137114
138- // 2. Retrieve similar documents for the original query.
139- List <Document > documents = this .documentRetriever .retrieve (transformedQuery );
115+ // 2. Expand query into one or multiple queries.
116+ List <Query > expandedQueries = this .queryExpander != null ? this .queryExpander .expand (transformedQuery )
117+ : List .of (transformedQuery );
118+
119+ // 3. Get similar documents for each query.
120+ Map <Query , List <List <Document >>> documentsForQuery = expandedQueries .stream ()
121+ .map (query -> CompletableFuture .supplyAsync (() -> getDocumentsForQuery (query ), this .taskExecutor ))
122+ .toList ()
123+ .stream ()
124+ .map (CompletableFuture ::join )
125+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
126+
127+ // 4. Combine documents retrieved based on multiple queries and from multiple data
128+ // sources.
129+ List <Document > documents = this .documentJoiner .join (documentsForQuery );
140130 context .put (DOCUMENT_CONTEXT , documents );
141131
142- // 3 . Augment user query with the document contextual data.
143- Query augmentedQuery = this .queryAugmentor .augment (transformedQuery , documents );
132+ // 5 . Augment user query with the document contextual data.
133+ Query augmentedQuery = this .queryAugmenter .augment (originalQuery , documents );
144134
135+ // 6. Update advised request with augmented prompt.
145136 return AdvisedRequest .from (request ).withUserText (augmentedQuery .text ()).withAdviseContext (context ).build ();
146137 }
147138
148- private AdvisedResponse after (AdvisedResponse advisedResponse ) {
139+ /**
140+ * Processes a single query by routing it to document retrievers and collecting
141+ * documents.
142+ */
143+ private Map .Entry <Query , List <List <Document >>> getDocumentsForQuery (Query query ) {
144+ List <DocumentRetriever > retrievers = this .queryRouter .route (query );
145+ List <List <Document >> documents = retrievers .stream ()
146+ .map (retriever -> CompletableFuture .supplyAsync (() -> retriever .retrieve (query ), this .taskExecutor ))
147+ .toList ()
148+ .stream ()
149+ .map (CompletableFuture ::join )
150+ .toList ();
151+ return Map .entry (query , documents );
152+ }
153+
154+ @ Override
155+ public AdvisedResponse after (AdvisedResponse advisedResponse ) {
149156 ChatResponse .Builder chatResponseBuilder ;
150157 if (advisedResponse .response () == null ) {
151158 chatResponseBuilder = ChatResponse .builder ();
@@ -157,66 +164,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
157164 return new AdvisedResponse (chatResponseBuilder .build (), advisedResponse .adviseContext ());
158165 }
159166
160- private Predicate <AdvisedResponse > onFinishReason () {
161- return advisedResponse -> {
162- ChatResponse chatResponse = advisedResponse .response ();
163- return chatResponse != null && chatResponse .getResults () != null
164- && chatResponse .getResults ()
165- .stream ()
166- .anyMatch (result -> result != null && result .getMetadata () != null
167- && StringUtils .hasText (result .getMetadata ().getFinishReason ()));
168- };
169- }
170-
171167 @ Override
172- public String getName () {
173- return this .getClass (). getSimpleName () ;
168+ public Scheduler getScheduler () {
169+ return this .scheduler ;
174170 }
175171
176172 @ Override
177173 public int getOrder () {
178174 return this .order ;
179175 }
180176
177+ private static TaskExecutor buildDefaultTaskExecutor () {
178+ ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor ();
179+ taskExecutor .setThreadNamePrefix ("ai-advisor-" );
180+ taskExecutor .setCorePoolSize (4 );
181+ taskExecutor .setMaxPoolSize (16 );
182+ taskExecutor .setTaskDecorator (new ContextPropagatingTaskDecorator ());
183+ taskExecutor .initialize ();
184+ return taskExecutor ;
185+ }
186+
181187 public static final class Builder {
182188
183- private final List <QueryTransformer > queryTransformers = new ArrayList <>() ;
189+ private List <QueryTransformer > queryTransformers ;
184190
185- private DocumentRetriever documentRetriever ;
191+ private QueryExpander queryExpander ;
186192
187- private QueryAugmentor queryAugmentor ;
193+ private QueryRouter queryRouter ;
188194
189- private Boolean protectFromBlocking ;
195+ private DocumentJoiner documentJoiner ;
196+
197+ private QueryAugmenter queryAugmenter ;
198+
199+ private TaskExecutor taskExecutor ;
200+
201+ private Scheduler scheduler ;
190202
191203 private Integer order ;
192204
193205 private Builder () {
194206 }
195207
196208 public Builder queryTransformers (List <QueryTransformer > queryTransformers ) {
197- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
198- this .queryTransformers .addAll (queryTransformers );
209+ this .queryTransformers = queryTransformers ;
199210 return this ;
200211 }
201212
202213 public Builder queryTransformers (QueryTransformer ... queryTransformers ) {
203- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
204- this .queryTransformers .addAll (Arrays .asList (queryTransformers ));
214+ this .queryTransformers = Arrays .asList (queryTransformers );
215+ return this ;
216+ }
217+
218+ public Builder queryExpander (QueryExpander queryExpander ) {
219+ this .queryExpander = queryExpander ;
220+ return this ;
221+ }
222+
223+ public Builder queryRouter (QueryRouter queryRouter ) {
224+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
225+ this .queryRouter = queryRouter ;
205226 return this ;
206227 }
207228
208229 public Builder documentRetriever (DocumentRetriever documentRetriever ) {
209- this .documentRetriever = documentRetriever ;
230+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
231+ this .queryRouter = AllRetrieversQueryRouter .builder ().documentRetrievers (documentRetriever ).build ();
232+ return this ;
233+ }
234+
235+ public Builder documentJoiner (DocumentJoiner documentJoiner ) {
236+ this .documentJoiner = documentJoiner ;
237+ return this ;
238+ }
239+
240+ public Builder queryAugmenter (QueryAugmenter queryAugmenter ) {
241+ this .queryAugmenter = queryAugmenter ;
210242 return this ;
211243 }
212244
213- public Builder queryAugmentor ( QueryAugmentor queryAugmentor ) {
214- this .queryAugmentor = queryAugmentor ;
245+ public Builder taskExecutor ( TaskExecutor taskExecutor ) {
246+ this .taskExecutor = taskExecutor ;
215247 return this ;
216248 }
217249
218- public Builder protectFromBlocking ( Boolean protectFromBlocking ) {
219- this .protectFromBlocking = protectFromBlocking ;
250+ public Builder scheduler ( Scheduler scheduler ) {
251+ this .scheduler = scheduler ;
220252 return this ;
221253 }
222254
@@ -226,8 +258,8 @@ public Builder order(Integer order) {
226258 }
227259
228260 public RetrievalAugmentationAdvisor build () {
229- return new RetrievalAugmentationAdvisor (this .queryTransformers , this .documentRetriever , this .queryAugmentor ,
230- this .protectFromBlocking , this .order );
261+ return new RetrievalAugmentationAdvisor (this .queryTransformers , this .queryExpander , this .queryRouter ,
262+ this .documentJoiner , this . queryAugmenter , this . taskExecutor , this . scheduler , this .order );
231263 }
232264
233265 }
0 commit comments