3434import org .springframework .ai .rag .Query ;
3535import org .springframework .ai .rag .generation .augmentation .ContextualQueryAugmenter ;
3636import org .springframework .ai .rag .generation .augmentation .QueryAugmenter ;
37- import org .springframework .ai .rag .orchestration .routing .AllRetrieversQueryRouter ;
38- import org .springframework .ai .rag .orchestration .routing .QueryRouter ;
3937import org .springframework .ai .rag .preretrieval .query .expansion .QueryExpander ;
4038import org .springframework .ai .rag .preretrieval .query .transformation .QueryTransformer ;
4139import org .springframework .ai .rag .retrieval .join .ConcatenationDocumentJoiner ;
5755 * @since 1.0.0
5856 * @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5957 * @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
58+ * @see <a href="https://export.arxiv.org/abs/2410.20878">arXiv:2410.20878</a>
6059 */
6160public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6261
@@ -67,7 +66,7 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6766 @ Nullable
6867 private final QueryExpander queryExpander ;
6968
70- private final QueryRouter queryRouter ;
69+ private final DocumentRetriever documentRetriever ;
7170
7271 private final DocumentJoiner documentJoiner ;
7372
@@ -80,14 +79,14 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
8079 private final int order ;
8180
8281 public RetrievalAugmentationAdvisor (@ Nullable List <QueryTransformer > queryTransformers ,
83- @ Nullable QueryExpander queryExpander , QueryRouter queryRouter , @ Nullable DocumentJoiner documentJoiner ,
84- @ Nullable QueryAugmenter queryAugmenter , @ Nullable TaskExecutor taskExecutor , @ Nullable Scheduler scheduler ,
85- @ Nullable Integer order ) {
86- Assert .notNull (queryRouter , "queryRouter cannot be null" );
82+ @ Nullable QueryExpander queryExpander , DocumentRetriever documentRetriever ,
83+ @ Nullable DocumentJoiner documentJoiner , @ Nullable QueryAugmenter queryAugmenter ,
84+ @ Nullable TaskExecutor taskExecutor , @ Nullable Scheduler scheduler , @ Nullable Integer order ) {
85+ Assert .notNull (documentRetriever , "documentRetriever cannot be null" );
8786 Assert .noNullElements (queryTransformers , "queryTransformers cannot contain null elements" );
8887 this .queryTransformers = queryTransformers != null ? queryTransformers : List .of ();
8988 this .queryExpander = queryExpander ;
90- this .queryRouter = queryRouter ;
89+ this .documentRetriever = documentRetriever ;
9190 this .documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner ();
9291 this .queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter .builder ().build ();
9392 this .taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor ();
@@ -122,7 +121,7 @@ public AdvisedRequest before(AdvisedRequest request) {
122121 .toList ()
123122 .stream ()
124123 .map (CompletableFuture ::join )
125- .collect (Collectors .toMap (Map .Entry ::getKey , Map . Entry :: getValue ));
124+ .collect (Collectors .toMap (Map .Entry ::getKey , entry -> List . of ( entry . getValue ()) ));
126125
127126 // 4. Combine documents retrieved based on multiple queries and from multiple data
128127 // sources.
@@ -140,14 +139,8 @@ public AdvisedRequest before(AdvisedRequest request) {
140139 * Processes a single query by routing it to document retrievers and collecting
141140 * documents.
142141 */
143- private Map .Entry <Query , List <List <Document >>> getDocumentsForQuery (Query query ) {
144- List <DocumentRetriever > retrievers = this .queryRouter .route (query );
145- List <List <Document >> documents = retrievers .stream ()
146- .map (retriever -> CompletableFuture .supplyAsync (() -> retriever .retrieve (query ), this .taskExecutor ))
147- .toList ()
148- .stream ()
149- .map (CompletableFuture ::join )
150- .toList ();
142+ private Map .Entry <Query , List <Document >> getDocumentsForQuery (Query query ) {
143+ List <Document > documents = documentRetriever .retrieve (query );
151144 return Map .entry (query , documents );
152145 }
153146
@@ -160,7 +153,7 @@ public AdvisedResponse after(AdvisedResponse advisedResponse) {
160153 else {
161154 chatResponseBuilder = ChatResponse .builder ().from (advisedResponse .response ());
162155 }
163- chatResponseBuilder .withMetadata (DOCUMENT_CONTEXT , advisedResponse .adviseContext ().get (DOCUMENT_CONTEXT ));
156+ chatResponseBuilder .metadata (DOCUMENT_CONTEXT , advisedResponse .adviseContext ().get (DOCUMENT_CONTEXT ));
164157 return new AdvisedResponse (chatResponseBuilder .build (), advisedResponse .adviseContext ());
165158 }
166159
@@ -190,7 +183,7 @@ public static final class Builder {
190183
191184 private QueryExpander queryExpander ;
192185
193- private QueryRouter queryRouter ;
186+ private DocumentRetriever documentRetriever ;
194187
195188 private DocumentJoiner documentJoiner ;
196189
@@ -220,15 +213,8 @@ public Builder queryExpander(QueryExpander queryExpander) {
220213 return this ;
221214 }
222215
223- public Builder queryRouter (QueryRouter queryRouter ) {
224- Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
225- this .queryRouter = queryRouter ;
226- return this ;
227- }
228-
229216 public Builder documentRetriever (DocumentRetriever documentRetriever ) {
230- Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
231- this .queryRouter = AllRetrieversQueryRouter .builder ().documentRetrievers (documentRetriever ).build ();
217+ this .documentRetriever = documentRetriever ;
232218 return this ;
233219 }
234220
@@ -258,7 +244,7 @@ public Builder order(Integer order) {
258244 }
259245
260246 public RetrievalAugmentationAdvisor build () {
261- return new RetrievalAugmentationAdvisor (this .queryTransformers , this .queryExpander , this .queryRouter ,
247+ return new RetrievalAugmentationAdvisor (this .queryTransformers , this .queryExpander , this .documentRetriever ,
262248 this .documentJoiner , this .queryAugmenter , this .taskExecutor , this .scheduler , this .order );
263249 }
264250
0 commit comments