1616
1717package org .springframework .ai .chat .client .advisor ;
1818
19- import java .util .ArrayList ;
2019import java .util .Arrays ;
2120import java .util .HashMap ;
2221import java .util .List ;
2322import java .util .Map ;
24- import java .util .function .Predicate ;
25-
26- import reactor .core .publisher .Flux ;
27- import reactor .core .publisher .Mono ;
28- import reactor .core .scheduler .Schedulers ;
23+ import java .util .concurrent .CompletableFuture ;
24+ import java .util .stream .Collectors ;
2925
3026import org .springframework .ai .chat .client .advisor .api .AdvisedRequest ;
3127import org .springframework .ai .chat .client .advisor .api .AdvisedResponse ;
32- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisor ;
33- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisorChain ;
34- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisor ;
35- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisorChain ;
28+ import org .springframework .ai .chat .client .advisor .api .BaseAdvisor ;
3629import org .springframework .ai .chat .model .ChatResponse ;
3730import org .springframework .ai .chat .prompt .PromptTemplate ;
3831import org .springframework .ai .document .Document ;
3932import org .springframework .ai .rag .Query ;
40- import org .springframework .ai .rag .analysis .query .transformation .QueryTransformer ;
41- import org .springframework .ai .rag .augmentation .ContextualQueryAugmentor ;
42- import org .springframework .ai .rag .augmentation .QueryAugmentor ;
33+ import org .springframework .ai .rag .generation .augmentation .ContextualQueryAugmenter ;
34+ import org .springframework .ai .rag .generation .augmentation .QueryAugmenter ;
35+ import org .springframework .ai .rag .orchestration .routing .AllRetrieversQueryRouter ;
36+ import org .springframework .ai .rag .orchestration .routing .QueryRouter ;
37+ import org .springframework .ai .rag .preretrieval .query .expansion .QueryExpander ;
38+ import org .springframework .ai .rag .preretrieval .query .transformation .QueryTransformer ;
39+ import org .springframework .ai .rag .retrieval .join .ConcatenationDocumentJoiner ;
40+ import org .springframework .ai .rag .retrieval .join .DocumentJoiner ;
4341import org .springframework .ai .rag .retrieval .search .DocumentRetriever ;
42+ import org .springframework .core .task .TaskExecutor ;
43+ import org .springframework .core .task .support .ContextPropagatingTaskDecorator ;
4444import org .springframework .lang .Nullable ;
45+ import org .springframework .scheduling .concurrent .ThreadPoolTaskExecutor ;
4546import org .springframework .util .Assert ;
46- import org . springframework . util . StringUtils ;
47+ import reactor . core . scheduler . Scheduler ;
4748
4849/**
4950 * Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
5051 * building blocks defined in the {@link org.springframework.ai.rag} package and following
5152 * the Modular RAG Architecture.
52- * <p>
53- * It's the successor of the {@link QuestionAnswerAdvisor}.
5453 *
5554 * @author Christian Tzolov
5655 * @author Thomas Vitale
5756 * @since 1.0.0
5857 * @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5958 * @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
6059 */
61- public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor , StreamAroundAdvisor {
60+ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6261
6362 public static final String DOCUMENT_CONTEXT = "rag_document_context" ;
6463
6564 private final List <QueryTransformer > queryTransformers ;
6665
67- private final DocumentRetriever documentRetriever ;
66+ @ Nullable
67+ private final QueryExpander queryExpander ;
68+
69+ private final QueryRouter queryRouter ;
70+
71+ private final DocumentJoiner documentJoiner ;
72+
73+ private final QueryAugmenter queryAugmenter ;
6874
69- private final QueryAugmentor queryAugmentor ;
75+ private final TaskExecutor taskExecutor ;
7076
71- private final boolean protectFromBlocking ;
77+ private final Scheduler scheduler ;
7278
7379 private final int order ;
7480
75- public RetrievalAugmentationAdvisor (List <QueryTransformer > queryTransformers , DocumentRetriever documentRetriever ,
76- @ Nullable QueryAugmentor queryAugmentor , @ Nullable Boolean protectFromBlocking , @ Nullable Integer order ) {
77- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
81+ public RetrievalAugmentationAdvisor (@ Nullable List <QueryTransformer > queryTransformers ,
82+ @ Nullable QueryExpander queryExpander , QueryRouter queryRouter , @ Nullable DocumentJoiner documentJoiner ,
83+ @ Nullable QueryAugmenter queryAugmenter , @ Nullable TaskExecutor taskExecutor , @ Nullable Scheduler scheduler ,
84+ @ Nullable Integer order ) {
85+ Assert .notNull (queryRouter , "queryRouter cannot be null" );
7886 Assert .noNullElements (queryTransformers , "queryTransformers cannot contain null elements" );
79- Assert .notNull (documentRetriever , "documentRetriever cannot be null" );
80- this .queryTransformers = queryTransformers ;
81- this .documentRetriever = documentRetriever ;
82- this .queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor .builder ().build ();
83- this .protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true ;
87+ this .queryTransformers = queryTransformers != null ? queryTransformers : List .of ();
88+ this .queryExpander = queryExpander ;
89+ this .queryRouter = queryRouter ;
90+ this .documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner ();
91+ this .queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter .builder ().build ();
92+ this .taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor ();
93+ this .scheduler = scheduler != null ? scheduler : BaseAdvisor .DEFAULT_SCHEDULER ;
8494 this .order = order != null ? order : 0 ;
8595 }
8696
@@ -89,41 +99,7 @@ public static Builder builder() {
8999 }
90100
91101 @ Override
92- public AdvisedResponse aroundCall (AdvisedRequest advisedRequest , CallAroundAdvisorChain chain ) {
93- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
94- Assert .notNull (chain , "chain cannot be null" );
95-
96- AdvisedRequest processedAdvisedRequest = before (advisedRequest );
97- AdvisedResponse advisedResponse = chain .nextAroundCall (processedAdvisedRequest );
98- return after (advisedResponse );
99- }
100-
101- @ Override
102- public Flux <AdvisedResponse > aroundStream (AdvisedRequest advisedRequest , StreamAroundAdvisorChain chain ) {
103- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
104- Assert .notNull (chain , "chain cannot be null" );
105-
106- // This can be executed by both blocking and non-blocking Threads
107- // E.g. a command line or Tomcat blocking Thread implementation
108- // or by a WebFlux dispatch in a non-blocking manner.
109- Flux <AdvisedResponse > advisedResponses = (this .protectFromBlocking ) ?
110- // @formatter:off
111- Mono .just (advisedRequest )
112- .publishOn (Schedulers .boundedElastic ())
113- .map (this ::before )
114- .flatMapMany (chain ::nextAroundStream )
115- : chain .nextAroundStream (before (advisedRequest ));
116- // @formatter:on
117-
118- return advisedResponses .map (ar -> {
119- if (onFinishReason ().test (ar )) {
120- ar = after (ar );
121- }
122- return ar ;
123- });
124- }
125-
126- private AdvisedRequest before (AdvisedRequest request ) {
102+ public AdvisedRequest before (AdvisedRequest request ) {
127103 Map <String , Object > context = new HashMap <>(request .adviseContext ());
128104
129105 // 0. Create a query from the user text and parameters.
@@ -135,17 +111,47 @@ private AdvisedRequest before(AdvisedRequest request) {
135111 transformedQuery = queryTransformer .apply (transformedQuery );
136112 }
137113
138- // 2. Retrieve similar documents for the original query.
139- List <Document > documents = this .documentRetriever .retrieve (transformedQuery );
114+ // 2. Expand query into one or multiple queries.
115+ List <Query > expandedQueries = queryExpander != null ? queryExpander .expand (transformedQuery )
116+ : List .of (transformedQuery );
117+
118+ // 3. Get similar documents for each query.
119+ Map <Query , List <List <Document >>> documentsForQuery = expandedQueries .stream ()
120+ .map (query -> CompletableFuture .supplyAsync (() -> getDocumentsForQuery (query ), taskExecutor ))
121+ .toList ()
122+ .stream ()
123+ .map (CompletableFuture ::join )
124+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
125+
126+ // 4. Combine documents retrieved based on multiple queries and from multiple data
127+ // sources.
128+ List <Document > documents = documentJoiner .join (documentsForQuery );
140129 context .put (DOCUMENT_CONTEXT , documents );
141130
142- // 3 . Augment user query with the document contextual data.
143- Query augmentedQuery = this . queryAugmentor . augment (transformedQuery , documents );
131+ // 5 . Augment user query with the document contextual data.
132+ Query augmentedQuery = queryAugmenter . augment (originalQuery , documents );
144133
134+ // 6. Update advised request with augmented prompt.
145135 return AdvisedRequest .from (request ).withUserText (augmentedQuery .text ()).withAdviseContext (context ).build ();
146136 }
147137
148- private AdvisedResponse after (AdvisedResponse advisedResponse ) {
138+ /**
139+ * Processes a single query by routing it to document retrievers and collecting
140+ * documents.
141+ */
142+ private Map .Entry <Query , List <List <Document >>> getDocumentsForQuery (Query query ) {
143+ List <DocumentRetriever > retrievers = queryRouter .route (query );
144+ List <List <Document >> documents = retrievers .stream ()
145+ .map (retriever -> CompletableFuture .supplyAsync (() -> retriever .retrieve (query ), taskExecutor ))
146+ .toList ()
147+ .stream ()
148+ .map (CompletableFuture ::join )
149+ .toList ();
150+ return Map .entry (query , documents );
151+ }
152+
153+ @ Override
154+ public AdvisedResponse after (AdvisedResponse advisedResponse ) {
149155 ChatResponse .Builder chatResponseBuilder ;
150156 if (advisedResponse .response () == null ) {
151157 chatResponseBuilder = ChatResponse .builder ();
@@ -157,66 +163,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
157163 return new AdvisedResponse (chatResponseBuilder .build (), advisedResponse .adviseContext ());
158164 }
159165
160- private Predicate <AdvisedResponse > onFinishReason () {
161- return advisedResponse -> {
162- ChatResponse chatResponse = advisedResponse .response ();
163- return chatResponse != null && chatResponse .getResults () != null
164- && chatResponse .getResults ()
165- .stream ()
166- .anyMatch (result -> result != null && result .getMetadata () != null
167- && StringUtils .hasText (result .getMetadata ().getFinishReason ()));
168- };
169- }
170-
171166 @ Override
172- public String getName () {
173- return this . getClass (). getSimpleName () ;
167+ public Scheduler getScheduler () {
168+ return scheduler ;
174169 }
175170
176171 @ Override
177172 public int getOrder () {
178173 return this .order ;
179174 }
180175
176+ private static TaskExecutor buildDefaultTaskExecutor () {
177+ ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor ();
178+ taskExecutor .setThreadNamePrefix ("ai-advisor-" );
179+ taskExecutor .setCorePoolSize (4 );
180+ taskExecutor .setMaxPoolSize (16 );
181+ taskExecutor .setTaskDecorator (new ContextPropagatingTaskDecorator ());
182+ taskExecutor .initialize ();
183+ return taskExecutor ;
184+ }
185+
181186 public static final class Builder {
182187
183- private final List <QueryTransformer > queryTransformers = new ArrayList <>();
188+ private List <QueryTransformer > queryTransformers ;
189+
190+ private QueryExpander queryExpander ;
191+
192+ private QueryRouter queryRouter ;
193+
194+ private DocumentJoiner documentJoiner ;
184195
185- private DocumentRetriever documentRetriever ;
196+ private QueryAugmenter queryAugmenter ;
186197
187- private QueryAugmentor queryAugmentor ;
198+ private TaskExecutor taskExecutor ;
188199
189- private Boolean protectFromBlocking ;
200+ private Scheduler scheduler ;
190201
191202 private Integer order ;
192203
193204 private Builder () {
194205 }
195206
196207 public Builder queryTransformers (List <QueryTransformer > queryTransformers ) {
197- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
198- this .queryTransformers .addAll (queryTransformers );
208+ this .queryTransformers = queryTransformers ;
199209 return this ;
200210 }
201211
202212 public Builder queryTransformers (QueryTransformer ... queryTransformers ) {
203- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
204- this .queryTransformers .addAll (Arrays .asList (queryTransformers ));
213+ this .queryTransformers = Arrays .asList (queryTransformers );
214+ return this ;
215+ }
216+
217+ public Builder queryExpander (QueryExpander queryExpander ) {
218+ this .queryExpander = queryExpander ;
219+ return this ;
220+ }
221+
222+ public Builder queryRouter (QueryRouter queryRouter ) {
223+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
224+ this .queryRouter = queryRouter ;
205225 return this ;
206226 }
207227
208228 public Builder documentRetriever (DocumentRetriever documentRetriever ) {
209- this .documentRetriever = documentRetriever ;
229+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
230+ this .queryRouter = AllRetrieversQueryRouter .builder ().documentRetrievers (documentRetriever ).build ();
231+ return this ;
232+ }
233+
234+ public Builder documentJoiner (DocumentJoiner documentJoiner ) {
235+ this .documentJoiner = documentJoiner ;
236+ return this ;
237+ }
238+
239+ public Builder queryAugmenter (QueryAugmenter queryAugmenter ) {
240+ this .queryAugmenter = queryAugmenter ;
210241 return this ;
211242 }
212243
213- public Builder queryAugmentor ( QueryAugmentor queryAugmentor ) {
214- this .queryAugmentor = queryAugmentor ;
244+ public Builder taskExecutor ( TaskExecutor taskExecutor ) {
245+ this .taskExecutor = taskExecutor ;
215246 return this ;
216247 }
217248
218- public Builder protectFromBlocking ( Boolean protectFromBlocking ) {
219- this .protectFromBlocking = protectFromBlocking ;
249+ public Builder scheduler ( Scheduler scheduler ) {
250+ this .scheduler = scheduler ;
220251 return this ;
221252 }
222253
@@ -226,8 +257,8 @@ public Builder order(Integer order) {
226257 }
227258
228259 public RetrievalAugmentationAdvisor build () {
229- return new RetrievalAugmentationAdvisor (this .queryTransformers , this .documentRetriever , this .queryAugmentor ,
230- this .protectFromBlocking , this .order );
260+ return new RetrievalAugmentationAdvisor (this .queryTransformers , this .queryExpander , this .queryRouter ,
261+ this .documentJoiner , this . queryAugmenter , this . taskExecutor , this . scheduler , this .order );
231262 }
232263
233264 }
0 commit comments