diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java
new file mode 100644
index 00000000000..82188b88fa1
--- /dev/null
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chat.client.advisor;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
+
+import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
+import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
+import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
+import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
+import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
+import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.PromptTemplate;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.model.Content;
+import org.springframework.ai.rag.Query;
+import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+/**
+ * This advisor implements common Retrieval Augmented Generation (RAG) flows using the
+ * building blocks defined in the {@link org.springframework.ai.rag} package and following
+ * the Modular RAG Architecture.
+ *
+ * It's the successor of the {@link QuestionAnswerAdvisor}.
+ *
+ * @author Christian Tzolov
+ * @author Thomas Vitale
+ * @since 1.0.0
+ * @see arXiv:2407.21059
+ * @see arXiv:2312.10997
+ */
+public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
+
+ public static final String DOCUMENT_CONTEXT = "rag_document_context";
+
+ public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
+ {query}
+
+ Context information is below. Use this information to answer the user query.
+
+ ---------------------
+ {context}
+ ---------------------
+
+ Given the context and provided history information and not prior knowledge,
+ reply to the user query. If the answer is not in the context, inform
+ the user that you can't answer the query.
+ """);
+
+ private final DocumentRetriever documentRetriever;
+
+ private final PromptTemplate promptTemplate;
+
+ private final boolean protectFromBlocking;
+
+ private final int order;
+
+ public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
+ @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
+ Assert.notNull(documentRetriever, "documentRetriever cannot be null");
+ this.documentRetriever = documentRetriever;
+ this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
+ this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
+ this.order = order != null ? order : 0;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Override
+ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
+ Assert.notNull(advisedRequest, "advisedRequest cannot be null");
+ Assert.notNull(chain, "chain cannot be null");
+
+ AdvisedRequest processedAdvisedRequest = before(advisedRequest);
+ AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
+ return after(advisedResponse);
+ }
+
+ @Override
+ public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
+ Assert.notNull(advisedRequest, "advisedRequest cannot be null");
+ Assert.notNull(chain, "chain cannot be null");
+
+ // This can be executed by both blocking and non-blocking Threads
+ // E.g. a command line or Tomcat blocking Thread implementation
+ // or by a WebFlux dispatch in a non-blocking manner.
+ Flux advisedResponses = (this.protectFromBlocking) ?
+ // @formatter:off
+ Mono.just(advisedRequest)
+ .publishOn(Schedulers.boundedElastic())
+ .map(this::before)
+ .flatMapMany(chain::nextAroundStream)
+ : chain.nextAroundStream(before(advisedRequest));
+ // @formatter:on
+
+ return advisedResponses.map(ar -> {
+ if (onFinishReason().test(ar)) {
+ ar = after(ar);
+ }
+ return ar;
+ });
+ }
+
+ private AdvisedRequest before(AdvisedRequest request) {
+ Map context = new HashMap<>(request.adviseContext());
+
+ // 0. Create a query from the user text and parameters.
+ Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
+
+ // 1. Retrieve similar documents for the original query.
+ List documents = this.documentRetriever.retrieve(query);
+ context.put(DOCUMENT_CONTEXT, documents);
+
+ // 2. Combine retrieved documents.
+ String documentContext = documents.stream()
+ .map(Content::getContent)
+ .collect(Collectors.joining(System.lineSeparator()));
+
+ // 3. Define augmentation prompt parameters.
+ Map promptParameters = Map.of("query", query.text(), "context", documentContext);
+
+ // 4. Augment user prompt with the context data.
+ UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters);
+
+ return AdvisedRequest.from(request)
+ .withUserText(augmentedUserMessage.getContent())
+ .withAdviseContext(context)
+ .build();
+ }
+
+ private AdvisedResponse after(AdvisedResponse advisedResponse) {
+ ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
+ chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
+ return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
+ }
+
+ private Predicate onFinishReason() {
+ return advisedResponse -> advisedResponse.response()
+ .getResults()
+ .stream()
+ .anyMatch(result -> result != null && result.getMetadata() != null
+ && StringUtils.hasText(result.getMetadata().getFinishReason()));
+ }
+
+ @Override
+ public String getName() {
+ return this.getClass().getSimpleName();
+ }
+
+ @Override
+ public int getOrder() {
+ return this.order;
+ }
+
+ public static final class Builder {
+
+ private DocumentRetriever documentRetriever;
+
+ private PromptTemplate promptTemplate;
+
+ private Boolean protectFromBlocking;
+
+ private Integer order;
+
+ private Builder() {
+ }
+
+ public Builder documentRetriever(DocumentRetriever documentRetriever) {
+ this.documentRetriever = documentRetriever;
+ return this;
+ }
+
+ public Builder promptTemplate(PromptTemplate promptTemplate) {
+ this.promptTemplate = promptTemplate;
+ return this;
+ }
+
+ public Builder protectFromBlocking(Boolean protectFromBlocking) {
+ this.protectFromBlocking = protectFromBlocking;
+ return this;
+ }
+
+ public Builder order(Integer order) {
+ this.order = order;
+ return this;
+ }
+
+ public RetrievalAugmentationAdvisor build() {
+ return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate,
+ this.protectFromBlocking, this.order);
+ }
+
+ }
+
+}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java
index afca774760f..08e1aa276e7 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java
@@ -34,6 +34,8 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -41,8 +43,6 @@
* The data of the chat client request that can be modified before the execution of the
* ChatClient's call method
*
- * @author Christian Tzolov
- * @since 1.0.0
* @param chatModel the chat model used
* @param userText the text provided by the user
* @param systemText the text provided by the system
@@ -57,13 +57,53 @@
* @param advisorParams the map of advisor parameters
* @param adviseContext the map of advise context
* @param toolContext the tool context
+ * @author Christian Tzolov
+ * @author Thomas Vitale
+ * @since 1.0.0
*/
-public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions,
- List media, List functionNames, List functionCallbacks, List messages,
- Map userParams, Map systemParams, List advisors,
- Map advisorParams, Map adviseContext, Map toolContext) {
+public record AdvisedRequest(
+// @formatter:off
+ ChatModel chatModel,
+ String userText,
+ @Nullable
+ String systemText,
+ @Nullable
+ ChatOptions chatOptions,
+ List media,
+ List functionNames,
+ List functionCallbacks,
+ List messages,
+ Map userParams,
+ Map systemParams,
+ List advisors,
+ Map advisorParams,
+ Map adviseContext,
+ Map toolContext
+// @formatter:on
+) {
+
+ public AdvisedRequest {
+ Assert.notNull(chatModel, "chatModel cannot be null");
+ Assert.hasText(userText, "userText cannot be null or empty");
+ Assert.notNull(media, "media cannot be null");
+ Assert.notNull(functionNames, "functionNames cannot be null");
+ Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
+ Assert.notNull(messages, "messages cannot be null");
+ Assert.notNull(userParams, "userParams cannot be null");
+ Assert.notNull(systemParams, "systemParams cannot be null");
+ Assert.notNull(advisors, "advisors cannot be null");
+ Assert.notNull(advisorParams, "advisorParams cannot be null");
+ Assert.notNull(adviseContext, "adviseContext cannot be null");
+ Assert.notNull(toolContext, "toolContext cannot be null");
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
public static Builder from(AdvisedRequest from) {
+ Assert.notNull(from, "AdvisedRequest cannot be null");
+
Builder builder = new Builder();
builder.chatModel = from.chatModel;
builder.userText = from.userText;
@@ -79,23 +119,18 @@ public static Builder from(AdvisedRequest from) {
builder.advisorParams = from.advisorParams;
builder.adviseContext = from.adviseContext;
builder.toolContext = from.toolContext;
-
return builder;
}
- public static Builder builder() {
- return new Builder();
- }
-
public AdvisedRequest updateContext(Function