diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index fae7c8d785f..a9f78f985c9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -34,8 +34,6 @@ import org.springframework.ai.rag.Query; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; import org.springframework.ai.rag.generation.augmentation.QueryAugmenter; -import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter; -import org.springframework.ai.rag.orchestration.routing.QueryRouter; import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander; import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner; @@ -57,6 +55,7 @@ * @since 1.0.0 * @see arXiv:2407.21059 * @see arXiv:2312.10997 + * @see arXiv:2410.20878 */ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { @@ -67,7 +66,7 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { @Nullable private final QueryExpander queryExpander; - private final QueryRouter queryRouter; + private final DocumentRetriever documentRetriever; private final DocumentJoiner documentJoiner; @@ -80,14 +79,14 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { private final int order; public RetrievalAugmentationAdvisor(@Nullable List queryTransformers, - @Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner, - @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, - @Nullable Integer order) 
{ - Assert.notNull(queryRouter, "queryRouter cannot be null"); + @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever, + @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter, + @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) { + Assert.notNull(documentRetriever, "documentRetriever cannot be null"); Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = queryTransformers != null ? queryTransformers : List.of(); this.queryExpander = queryExpander; - this.queryRouter = queryRouter; + this.documentRetriever = documentRetriever; this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner(); this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build(); this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor(); @@ -122,7 +121,7 @@ public AdvisedRequest before(AdvisedRequest request) { .toList() .stream() .map(CompletableFuture::join) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + .collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue()))); // 4. Combine documents retrieved based on multiple queries and from multiple data // sources. @@ -140,14 +139,8 @@ public AdvisedRequest before(AdvisedRequest request) { * Processes a single query by routing it to document retrievers and collecting * documents. 
*/ - private Map.Entry>> getDocumentsForQuery(Query query) { - List retrievers = this.queryRouter.route(query); - List> documents = retrievers.stream() - .map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), this.taskExecutor)) - .toList() - .stream() - .map(CompletableFuture::join) - .toList(); + private Map.Entry> getDocumentsForQuery(Query query) { + List documents = documentRetriever.retrieve(query); return Map.entry(query, documents); } @@ -160,7 +153,7 @@ public AdvisedResponse after(AdvisedResponse advisedResponse) { else { chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); } - chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT)); + chatResponseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT)); return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); } @@ -190,7 +183,7 @@ public static final class Builder { private QueryExpander queryExpander; - private QueryRouter queryRouter; + private DocumentRetriever documentRetriever; private DocumentJoiner documentJoiner; @@ -220,15 +213,8 @@ public Builder queryExpander(QueryExpander queryExpander) { return this; } - public Builder queryRouter(QueryRouter queryRouter) { - Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter"); - this.queryRouter = queryRouter; - return this; - } - public Builder documentRetriever(DocumentRetriever documentRetriever) { - Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter"); - this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build(); + this.documentRetriever = documentRetriever; return this; } @@ -258,7 +244,7 @@ public Builder order(Integer order) { } public RetrievalAugmentationAdvisor build() { - return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter, + 
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java index 6e3d6100925..d6b5e2447ba 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java @@ -16,17 +16,78 @@ package org.springframework.ai.rag; +import org.springframework.ai.chat.messages.Message; import org.springframework.util.Assert; +import java.util.List; +import java.util.Map; + /** * Represents a query in the context of a Retrieval Augmented Generation (RAG) flow. * * @param text the text of the query + * @param history the messages in the conversation history + * @param context the context of the query * @author Thomas Vitale * @since 1.0.0 */ -public record Query(String text) { +public record Query(String text, List history, Map context) { + public Query { Assert.hasText(text, "text cannot be null or empty"); + Assert.notNull(history, "history cannot be null"); + Assert.noNullElements(history, "history elements cannot be null"); + Assert.notNull(context, "context cannot be null"); + Assert.noNullElements(context.keySet(), "context keys cannot be null"); + } + + public Query(String text) { + this(text, List.of(), Map.of()); + } + + public Builder mutate() { + return new Builder().text(this.text).history(this.history).context(this.context); } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String text; + + private List history = List.of(); + + private Map context = Map.of(); + + private Builder() { + } + + public Builder text(String text) { + this.text = text; + return this; + } + + public Builder history(List history) { + this.history = history; + return 
this; + } + + public Builder history(Message... history) { + this.history = List.of(history); + return this; + } + + public Builder context(Map context) { + this.context = context; + return this; + } + + public Query build() { + return new Query(text, history, context); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/package-info.java deleted file mode 100644 index 7ef0db0979e..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/package-info.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * RAG Module: Orchestration. - *

- * This package includes components for controlling the execution flow in a Retrieval - * Augmented Generation system. - */ -@NonNullApi -@NonNullFields -package org.springframework.ai.rag.orchestration; - -import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouter.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouter.java deleted file mode 100644 index bf18848480c..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouter.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.rag.orchestration.routing; - -import java.util.Arrays; -import java.util.List; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.rag.Query; -import org.springframework.ai.rag.retrieval.search.DocumentRetriever; -import org.springframework.util.Assert; - -/** - * Routes a query to all the defined document retrievers. 
- * - * @author Thomas Vitale - * @since 1.0.0 - */ -public class AllRetrieversQueryRouter implements QueryRouter { - - private static final Logger logger = LoggerFactory.getLogger(AllRetrieversQueryRouter.class); - - private final List documentRetrievers; - - public AllRetrieversQueryRouter(List documentRetrievers) { - Assert.notEmpty(documentRetrievers, "documentRetrievers cannot be null or empty"); - Assert.noNullElements(documentRetrievers, "documentRetrievers cannot contain null elements"); - this.documentRetrievers = documentRetrievers; - } - - @Override - public List route(Query query) { - Assert.notNull(query, "query cannot be null"); - logger.debug("Routing query to all document retrievers"); - return this.documentRetrievers; - } - - public static Builder builder() { - return new Builder(); - } - - public final static class Builder { - - private List documentRetrievers; - - private Builder() { - } - - public Builder documentRetrievers(DocumentRetriever... documentRetrievers) { - this.documentRetrievers = Arrays.asList(documentRetrievers); - return this; - } - - public Builder documentRetrievers(List documentRetrievers) { - this.documentRetrievers = documentRetrievers; - return this; - } - - public AllRetrieversQueryRouter build() { - return new AllRetrieversQueryRouter(this.documentRetrievers); - } - - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/QueryRouter.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/QueryRouter.java deleted file mode 100644 index e8b6e34b058..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/QueryRouter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.rag.orchestration.routing; - -import java.util.List; -import java.util.function.Function; - -import org.springframework.ai.rag.Query; -import org.springframework.ai.rag.retrieval.join.DocumentJoiner; -import org.springframework.ai.rag.retrieval.search.DocumentRetriever; - -/** - * A component for routing a query to one or more document retrievers. It provides a - * decision-making mechanism to support various scenarios and making the Retrieval - * Augmented Generation flow more flexible and extensible. It can be used to implement - * routing strategies using metadata, large language models, tools (the foundation of - * Agentic RAG), and other techniques. - *

- * When retrieving documents from multiple sources, you'll need to join the results before - * concluding the retrieval stage. For this purpose, you can use the - * {@link DocumentJoiner}. - * - * @author Thomas Vitale - * @since 1.0.0 - */ -public interface QueryRouter extends Function> { - - /** - * Routes a query to one or more document retrievers. - * @param query the query to route - * @return a list of document retrievers - */ - List route(Query query); - - default List apply(Query query) { - return route(query); - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/package-info.java deleted file mode 100644 index 59a8a597f40..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/orchestration/routing/package-info.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * RAG Sub-Module: Query Router. 
- */ -@NonNullApi -@NonNullFields -package org.springframework.ai.rag.orchestration.routing; - -import org.springframework.lang.NonNullApi; -import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java index b7061763599..5026b47e710 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java @@ -23,6 +23,7 @@ * * @see arXiv:2407.21059 * @see arXiv:2312.10997 + * @see arXiv:2410.20878 */ @NonNullApi @NonNullFields diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpander.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpander.java index 08f9f874b52..aa64f46cf0d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpander.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpander.java @@ -120,7 +120,10 @@ public List expand(Query query) { return List.of(query); } - var queries = queryVariants.stream().filter(StringUtils::hasText).map(Query::new).collect(Collectors.toList()); + var queries = queryVariants.stream() + .filter(StringUtils::hasText) + .map(queryText -> query.mutate().text(queryText).build()) + .collect(Collectors.toList()); if (this.includeOriginal) { logger.debug("Including the original query in the result"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/QueryExpander.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/QueryExpander.java index bb1d5d44ef5..379daa657c1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/QueryExpander.java +++ 
b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/expansion/QueryExpander.java @@ -24,7 +24,7 @@ /** * A component for expanding the input query into a list of queries, addressing challenges * such as poorly formed queries by providing alternative query formulations, or by - * breaking down complex problems into simpler sub-queries, + * breaking down complex problems into simpler sub-queries. * * @author Thomas Vitale * @since 1.0.0 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java new file mode 100644 index 00000000000..71d868ed6f5 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformer.java @@ -0,0 +1,140 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Uses a large language model to compress a conversation history and a follow-up query + * into a standalone query that captures the essence of the conversation. + *

+ * This transformer is useful when the conversation history is long and the follow-up + * query is related to the conversation context. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class CompressionQueryTransformer implements QueryTransformer { + + private static final Logger logger = LoggerFactory.getLogger(CompressionQueryTransformer.class); + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + Given the following conversation history and a follow-up query, your task is to synthesize + a concise, standalone query that incorporates the context from the history. + Ensure the standalone query is clear, specific, and maintains the user's intent. + + Conversation history: + {history} + + Follow-up query: + {query} + + Standalone query: + """); + + private final ChatClient chatClient; + + private final PromptTemplate promptTemplate; + + public CompressionQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); + + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "history", "query"); + } + + @Override + public Query transform(Query query) { + Assert.notNull(query, "query cannot be null"); + + logger.debug("Compressing conversation history and follow-up query into a standalone query"); + + var compressedQueryText = this.chatClient.prompt() + .user(user -> user.text(this.promptTemplate.getTemplate()) + .param("history", formatConversationHistory(query.history())) + .param("query", query.text())) + .options(ChatOptions.builder().temperature(0.0).build()) + .call() + .content(); + + if (!StringUtils.hasText(compressedQueryText)) { + logger.warn("Query compression result is null/empty. 
Returning the input query unchanged."); + return query; + } + + return query.mutate().text(compressedQueryText).build(); + } + + private String formatConversationHistory(List history) { + if (history.isEmpty()) { + return ""; + } + + return history.stream() + .filter(message -> message.getMessageType().equals(MessageType.USER) + || message.getMessageType().equals(MessageType.ASSISTANT)) + .map(message -> "%s: %s".formatted(message.getMessageType(), message.getText())) + .collect(Collectors.joining("\n")); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ChatClient.Builder chatClientBuilder; + + @Nullable + private PromptTemplate promptTemplate; + + private Builder() { + } + + public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) { + this.chatClientBuilder = chatClientBuilder; + return this; + } + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public CompressionQueryTransformer build() { + return new CompressionQueryTransformer(chatClientBuilder, promptTemplate); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java new file mode 100644 index 00000000000..30da5265ce2 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java @@ -0,0 +1,134 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Uses a large language model to rewrite a user query to provide better results when + * querying a target system, such as a vector store or a web search engine. + *

+ * This transformer is useful when the user query is verbose, ambiguous, or contains + * irrelevant information that may affect the quality of the search results. + * + * @author Thomas Vitale + * @since 1.0.0 + * @see arXiv:2305.14283 + */ +public class RewriteQueryTransformer implements QueryTransformer { + + private static final Logger logger = LoggerFactory.getLogger(RewriteQueryTransformer.class); + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + Given a user query, rewrite it to provide better results when querying a {target}. + Remove any irrelevant information, and ensure the query is concise and specific. + + Original query: + {query} + + Rewritten query: + """); + + private static final String DEFAULT_TARGET = "vector store"; + + private final ChatClient chatClient; + + private final PromptTemplate promptTemplate; + + private final String targetSearchSystem; + + public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + @Nullable String targetSearchSystem) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); + + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.targetSearchSystem = targetSearchSystem != null ? 
targetSearchSystem : DEFAULT_TARGET; + + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + } + + @Override + public Query transform(Query query) { + Assert.notNull(query, "query cannot be null"); + + logger.debug("Rewriting query to optimize for querying a {}.", this.targetSearchSystem); + + var rewrittenQueryText = this.chatClient.prompt() + .user(user -> user.text(this.promptTemplate.getTemplate()) + .param("target", targetSearchSystem) + .param("query", query.text())) + .options(ChatOptions.builder().temperature(0.0).build()) + .call() + .content(); + + if (!StringUtils.hasText(rewrittenQueryText)) { + logger.warn("Query rewrite result is null/empty. Returning the input query unchanged."); + return query; + } + + return query.mutate().text(rewrittenQueryText).build(); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ChatClient.Builder chatClientBuilder; + + @Nullable + private PromptTemplate promptTemplate; + + @Nullable + private String targetSearchSystem; + + private Builder() { + } + + public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) { + this.chatClientBuilder = chatClientBuilder; + return this; + } + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder targetSearchSystem(String targetSearchSystem) { + this.targetSearchSystem = targetSearchSystem; + return this; + } + + public RewriteQueryTransformer build() { + return new RewriteQueryTransformer(chatClientBuilder, promptTemplate, targetSearchSystem); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java index 37ceffcc0da..e155069b822 100644 --- 
a/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/TranslationQueryTransformer.java @@ -18,7 +18,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.PromptTemplate; @@ -87,7 +86,7 @@ public Query transform(Query query) { logger.debug("Translating query to target language: {}", this.targetLanguage); - var translatedQuery = this.chatClient.prompt() + var translatedQueryText = this.chatClient.prompt() .user(user -> user.text(this.promptTemplate.getTemplate()) .param("targetLanguage", this.targetLanguage) .param("query", query.text())) @@ -95,12 +94,12 @@ public Query transform(Query query) { .call() .content(); - if (!StringUtils.hasText(translatedQuery)) { + if (!StringUtils.hasText(translatedQueryText)) { logger.warn("Query translation result is null/empty. 
Returning the input query unchanged."); return query; } - return new Query(translatedQuery); + return query.mutate().text(translatedQueryText).build(); } public static Builder builder() { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java index a78bb169336..e697b576a2e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java @@ -16,11 +16,8 @@ package org.springframework.ai.chat.client.advisor; -import java.util.List; - import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; - import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatModel; @@ -32,6 +29,8 @@ import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; import org.springframework.ai.rag.retrieval.search.DocumentRetriever; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.BDDMockito.given; @@ -54,10 +53,10 @@ void whenQueryTransformersContainNullElementsThenThrow() { } @Test - void whenQueryRouterIsNullThenThrow() { - assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder().queryRouter(null).build()) + void whenDocumentRetrieverIsNullThenThrow() { + assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder().documentRetriever(null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("queryRouter cannot be null"); + .hasMessageContaining("documentRetriever cannot be null"); } @Test diff --git 
a/spring-ai-core/src/test/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouterTests.java deleted file mode 100644 index 639c5c2ef9b..00000000000 --- a/spring-ai-core/src/test/java/org/springframework/ai/rag/orchestration/routing/AllRetrieversQueryRouterTests.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.rag.orchestration.routing; - -import java.util.ArrayList; -import java.util.List; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.rag.Query; -import org.springframework.ai.rag.retrieval.search.DocumentRetriever; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - -/** - * Unit tests for {@link AllRetrieversQueryRouter}. 
- * - * @author Thomas Vitale - */ -class AllRetrieversQueryRouterTests { - - @Test - void whenDocumentRetrieversIsNullThenThrow() { - assertThatThrownBy( - () -> AllRetrieversQueryRouter.builder().documentRetrievers((List) null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("documentRetrievers cannot be null or empty"); - } - - @Test - void whenDocumentRetrieversIsEmptyThenThrow() { - assertThatThrownBy(() -> AllRetrieversQueryRouter.builder().documentRetrievers(List.of()).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("documentRetrievers cannot be null or empty"); - } - - @Test - void whenDocumentRetrieversContainsNullKeysThenThrow() { - var documentRetrievers = new ArrayList(); - documentRetrievers.add(null); - assertThatThrownBy(() -> AllRetrieversQueryRouter.builder().documentRetrievers(documentRetrievers).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("documentRetrievers cannot contain null elements"); - } - - @Test - void whenQueryIsNullThenThrow() { - DocumentRetriever documentRetriever = mock(DocumentRetriever.class); - QueryRouter queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build(); - assertThatThrownBy(() -> queryRouter.route(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("query cannot be null"); - } - - @Test - void routeToAllRetrievers() { - DocumentRetriever documentRetriever1 = mock(DocumentRetriever.class); - DocumentRetriever documentRetriever2 = mock(DocumentRetriever.class); - QueryRouter queryRouter = AllRetrieversQueryRouter.builder() - .documentRetrievers(documentRetriever1, documentRetriever2) - .build(); - List selectedDocumentRetrievers = queryRouter.route(new Query("test")); - assertThat(selectedDocumentRetrievers).containsAll(List.of(documentRetriever1, documentRetriever2)); - } - -} diff --git 
a/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java new file mode 100644 index 00000000000..aec1e3ddf42 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/CompressionQueryTransformerTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link CompressionQueryTransformer}. 
+ * + * @author Thomas Vitale + */ +class CompressionQueryTransformerTests { + + @Test + void whenChatClientBuilderIsNullThenThrow() { + assertThatThrownBy(() -> CompressionQueryTransformer.builder().chatClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatClientBuilder cannot be null"); + } + + @Test + void whenQueryIsNullThenThrow() { + QueryTransformer queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .build(); + assertThatThrownBy(() -> queryTransformer.transform(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("query cannot be null"); + } + + @Test + void whenPromptHasMissingHistoryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Compress {query}"); + assertThatThrownBy(() -> CompressionQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("history"); + } + + @Test + void whenPromptHasMissingQueryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Compress {history}"); + assertThatThrownBy(() -> CompressionQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("query"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java new file mode 100644 index 
00000000000..099f5060219 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link RewriteQueryTransformer}. 
+ * + * @author Thomas Vitale + */ +class RewriteQueryTransformerTests { + + @Test + void whenChatClientBuilderIsNullThenThrow() { + assertThatThrownBy(() -> RewriteQueryTransformer.builder().chatClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatClientBuilder cannot be null"); + } + + @Test + void whenQueryIsNullThenThrow() { + QueryTransformer queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .build(); + assertThatThrownBy(() -> queryTransformer.transform(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("query cannot be null"); + } + + @Test + void whenPromptHasMissingTargetPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}"); + assertThatThrownBy(() -> RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("vector store") + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("target"); + } + + @Test + void whenPromptHasMissingQueryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}"); + assertThatThrownBy(() -> RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("search engine") + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("query"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java index 
d4ea8fa12ce..5769c7922d5 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/join/ConcatenationDocumentJoinerTests.java @@ -21,7 +21,6 @@ import java.util.Map; import org.junit.jupiter.api.Test; - import org.springframework.ai.document.Document; import org.springframework.ai.rag.Query; @@ -82,15 +81,14 @@ void whenDuplicatedDocumentsThenOnlyFirstOccurrenceIsKept() { documentsForQuery.put(new Query("query1"), List.of(List.of(new Document("1", "Content 1", Map.of()), new Document("2", "Content 2", Map.of())), List.of(new Document("3", "Content 3", Map.of())))); - documentsForQuery.put(new Query("query2"), List - .of(List.of(new Document("2", "Content 2 Duplicate", Map.of()), new Document("4", "Content 4", Map.of())))); + documentsForQuery.put(new Query("query2"), + List.of(List.of(new Document("2", "Content 2", Map.of()), new Document("4", "Content 4", Map.of())))); List result = documentJoiner.join(documentsForQuery); assertThat(result).hasSize(4); assertThat(result).extracting(Document::getId).containsExactlyInAnyOrder("1", "2", "3", "4"); - assertThat(result).extracting(Document::getText).contains("Content 2"); - assertThat(result).extracting(Document::getText).doesNotContain("Content 2 Duplicate"); + assertThat(result).extracting(Document::getText).containsOnlyOnce("Content 2"); } } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 8686d67749c..d4234afe76e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -96,6 +96,7 @@ ** xref:api/vectordbs/typesense.adoc[] ** xref:api/vectordbs/weaviate.adoc[] +* xref:api/retrieval-augmented-generation.adoc[] * xref:observability/index.adoc[] * xref:api/prompt.adoc[] * xref:api/structured-output-converter.adoc[Structured Output] diff 
--git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index c89f428ef6b..fab86cece3a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -366,45 +366,7 @@ In this configuration, the `MessageChatMemoryAdvisor` will be executed first, ad === Retrieval Augmented Generation -A vector database stores data that the AI model is unaware of. -When a user question is sent to the AI model, a `QuestionAnswerAdvisor` queries the vector database for documents related to the user question. - -The response from the vector database is appended to the user text to provide context for the AI model to generate a response. - -Assuming you have already loaded data into a `VectorStore`, you can perform Retrieval Augmented Generation (RAG) by providing an instance of `QuestionAnswerAdvisor` to the `ChatClient`. - -[source,java] ----- -ChatResponse response = ChatClient.builder(chatModel) - .build().prompt() - .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) - .user(userText) - .call() - .chatResponse(); ----- - -In this example, the `SearchRequest.defaults()` will perform a similarity search over all documents in the Vector Database. -To restrict the types of documents that are searched, the `SearchRequest` takes an SQL like filter expression that is portable across all `VectorStores`. 
- -==== Dynamic Filter Expressions - -Update the `SearchRequest` filter expression at runtime using the `FILTER_EXPRESSION` advisor context parameter: - -[source,java] ----- -ChatClient chatClient = ChatClient.builder(chatModel) - .defaultAdvisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) - .build(); - -// Update filter expression at runtime -String content = this.chatClient.prompt() - .user("Please answer my question XYZ") - .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'")) - .call() - .content(); ----- - -The `FILTER_EXPRESSION` parameter allows you to dynamically filter the search results based on the provided expression. +Refer to the xref:api/retrieval-augmented-generation.adoc[Retrieval Augmented Generation] guide. === Chat Memory diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc new file mode 100644 index 00000000000..b9edd059edc --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc @@ -0,0 +1,368 @@ +[[rag]] += Retrieval Augmented Generation + +Retrieval Augmented Generation (RAG) is a technique useful to overcome the limitations of large language models +that struggle with long-form content, factual accuracy, and context-awareness. + +Spring AI supports RAG by providing a modular architecture that allows you to build custom RAG flows yourself +or use out-of-the-box RAG flows using the `Advisor` API. + +NOTE: Learn more about Retrieval Augmented Generation in the xref:concepts.adoc#concept-rag[concepts] section. + +== Advisors + +Spring AI provides out-of-the-box support for common RAG flows using the `Advisor` API. + +=== QuestionAnswerAdvisor + +A vector database stores data that the AI model is unaware of.
+When a user question is sent to the AI model, a `QuestionAnswerAdvisor` queries the vector database for documents related to the user question. + +The response from the vector database is appended to the user text to provide context for the AI model to generate a response. + +Assuming you have already loaded data into a `VectorStore`, you can perform Retrieval Augmented Generation (RAG) by providing an instance of `QuestionAnswerAdvisor` to the `ChatClient`. + +[source,java] +---- +ChatResponse response = ChatClient.builder(chatModel) + .build().prompt() + .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) + .user(userText) + .call() + .chatResponse(); +---- + +In this example, the `SearchRequest.defaults()` will perform a similarity search over all documents in the Vector Database. +To restrict the types of documents that are searched, the `SearchRequest` takes an SQL like filter expression that is portable across all `VectorStores`. + +==== Dynamic Filter Expressions + +Update the `SearchRequest` filter expression at runtime using the `FILTER_EXPRESSION` advisor context parameter: + +[source,java] +---- +ChatClient chatClient = ChatClient.builder(chatModel) + .defaultAdvisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) + .build(); + +// Update filter expression at runtime +String content = this.chatClient.prompt() + .user("Please answer my question XYZ") + .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'")) + .call() + .content(); +---- + +The `FILTER_EXPRESSION` parameter allows you to dynamically filter the search results based on the provided expression. + +=== RetrievalAugmentationAdvisor (Incubating) + +Spring AI includes a xref:api/retrieval-augmented-generation.adoc#modules[library of RAG modules] that you can use to build your own RAG flows. 
+The `RetrievalAugmentationAdvisor` is an experimental `Advisor` providing an out-of-the-box implementation for the most common RAG flows, +based on a modular architecture. + +WARNING: The `RetrievalAugmentationAdvisor` is an experimental feature and is subject to change in future releases. + +==== Sequential RAG Flows + +===== Naive RAG + +[source,java] +---- +Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder() + .similarityThreshold(0.50) + .vectorStore(vectorStore) + .build()) + .build(); + +String answer = chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(question) + .call() + .content(); +---- + +By default, the `RetrievalAugmentationAdvisor` does not allow the retrieved context to be empty. When that happens, +it instructs the model not to answer the user query. You can allow empty context as follows. + +[source,java] +---- +Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder() + .similarityThreshold(0.50) + .vectorStore(vectorStore) + .build()) + .queryAugmenter(ContextualQueryAugmenter.builder() + .allowEmptyContext(true) + .build()) + .build(); + +String answer = chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(question) + .call() + .content(); +---- + +===== Advanced RAG + +[source,java] +---- +Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .queryTransformer(RewriteQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder.build().mutate()) + .build()) + .documentRetriever(VectorStoreDocumentRetriever.builder() + .similarityThreshold(0.50) + .vectorStore(vectorStore) + .build()) + .build(); + +String answer = chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(question) + .call() + .content(); +---- + +[[modules]] +== Modules + +Spring AI implements a Modular RAG architecture inspired by the 
concept of modularity detailed in the paper +"https://arxiv.org/abs/2407.21059[Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks]". + +WARNING: Modular RAG is an experimental feature and is subject to change in future releases. + +=== Pre-Retrieval + +Pre-Retrieval modules are responsible for processing the user query to achieve the best possible retrieval results. + +==== Query Transformation + +A component for transforming the input query to make it more effective for retrieval tasks, addressing challenges +such as poorly formed queries, ambiguous terms, complex vocabulary, or unsupported languages. + +===== CompressionQueryTransformer + +A `CompressionQueryTransformer` uses a large language model to compress a conversation history and a follow-up query +into a standalone query that captures the essence of the conversation. + +This transformer is useful when the conversation history is long and the follow-up query is related +to the conversation context. + +[source,java] +---- +Query query = Query.builder() + .text("And what is its second largest city?") + .history(new UserMessage("What is the capital of Denmark?"), + new AssistantMessage("Copenhagen is the capital of Denmark.")) + .build(); + +QueryTransformer queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder) + .build(); + +Query transformedQuery = queryTransformer.transform(query); +---- + +The prompt used by this component can be customized via the `promptTemplate()` method available in the builder. + +===== RewriteQueryTransformer + +A `RewriteQueryTransformer` uses a large language model to rewrite a user query to provide better results when +querying a target system, such as a vector store or a web search engine. + +This transformer is useful when the user query is verbose, ambiguous, or contains irrelevant information +that may affect the quality of the search results.
+ +[source,java] +---- +Query query = new Query("I'm studying machine learning. What is an LLM?"); + +QueryTransformer queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder) + .build(); + +Query transformedQuery = queryTransformer.transform(query); +---- + +The prompt used by this component can be customized via the `promptTemplate()` method available in the builder. + +===== TranslationQueryTransformer + +A `TranslationQueryTransformer` uses a large language model to translate a query to a target language that is supported +by the embedding model used to generate the document embeddings. If the query is already in the target language, +it is returned unchanged. If the language of the query is unknown, it is also returned unchanged. + +This transformer is useful when the embedding model is trained on a specific language and the user query +is in a different language. + +[source,java] +---- +Query query = new Query("Hvad er Danmarks hovedstad?"); + +QueryTransformer queryTransformer = TranslationQueryTransformer.builder() + .chatClientBuilder(chatClientBuilder) + .targetLanguage("english") + .build(); + +Query transformedQuery = queryTransformer.transform(query); +---- + +The prompt used by this component can be customized via the `promptTemplate()` method available in the builder. + +==== Query Expansion + +A component for expanding the input query into a list of queries, addressing challenges such as poorly formed queries +by providing alternative query formulations, or by breaking down complex problems into simpler sub-queries. + +===== MultiQueryExpander + +A `MultiQueryExpander` uses a large language model to expand a query into multiple semantically diverse variations +to capture different perspectives, useful for retrieving additional contextual information and increasing the chances +of finding relevant results. 
+ +[source,java] +---- +MultiQueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(chatClientBuilder) + .numberOfQueries(3) + .build(); +List<Query> queries = queryExpander.expand(new Query("How to run a Spring Boot app?")); +---- + +By default, the `MultiQueryExpander` includes the original query in the list of expanded queries. You can disable this behavior +via the `includeOriginal` method in the builder. + +[source,java] +---- +MultiQueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(chatClientBuilder) + .includeOriginal(false) + .build(); +---- + +The prompt used by this component can be customized via the `promptTemplate()` method available in the builder. + +=== Retrieval + +Retrieval modules are responsible for querying data systems such as vector stores and retrieving the most relevant documents. + +==== Document Search + +Component responsible for retrieving `Documents` from an underlying data source, such as a search engine, a vector store, +a database, or a knowledge graph. + +===== VectorStoreDocumentRetriever + +A `VectorStoreDocumentRetriever` retrieves documents from a vector store that are semantically similar to the input +query. It supports filtering based on metadata, similarity threshold, and top-k results. + +[source,java] +---- +DocumentRetriever retriever = VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .similarityThreshold(0.73) + .topK(5) + .filterExpression(new FilterExpressionBuilder() + .eq("genre", "fairytale") + .build()) + .build(); +List<Document> documents = retriever.retrieve(new Query("What is the main character of the story?")); +---- + +The filter expression can be static or dynamic. For dynamic filter expressions, you can pass a `Supplier`.
+ +[source,java] +---- +DocumentRetriever retriever = VectorStoreDocumentRetriever.builder() + .vectorStore(vectorStore) + .filterExpression(() -> new FilterExpressionBuilder() + .eq("tenant", TenantContextHolder.getTenantIdentifier()) + .build()) + .build(); +List<Document> documents = retriever.retrieve(new Query("What are the KPIs for the next semester?")); +---- + +==== Document Join + +A component for combining documents retrieved based on multiple queries and from multiple data sources into +a single collection of documents. As part of the joining process, it can also handle duplicate documents and reciprocal +ranking strategies. + +===== ConcatenationDocumentJoiner + +A `ConcatenationDocumentJoiner` combines documents retrieved based on multiple queries and from multiple data sources +by concatenating them into a single collection of documents. In case of duplicate documents, the first occurrence is kept. +The score of each document is kept as is. + +[source,java] +---- +Map<Query, List<List<Document>>> documentsForQuery = ... +DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner(); +List<Document> documents = documentJoiner.join(documentsForQuery); +---- + +=== Post-Retrieval + +Post-Retrieval modules are responsible for processing the retrieved documents to achieve the best possible generation results. + +==== Document Ranking + +A component for ordering and ranking documents based on their relevance to a query to bring the most relevant documents +to the top of the list, addressing challenges such as _lost-in-the-middle_. + +Unlike `DocumentSelector`, this component does not remove entire documents from the list, but rather changes +the order/score of the documents in the list. Unlike `DocumentCompressor`, this component does not alter the content +of the documents. + +==== Document Selection + +A component for removing irrelevant or redundant documents from a list of retrieved documents, addressing challenges +such as _lost-in-the-middle_ and context length restrictions from the model.
+ +Unlike `DocumentRanker`, this component does not change the order/score of the documents in the list, but rather +removes irrelevant or redundant documents. Unlike `DocumentCompressor`, this component does not alter the content +of the documents, but rather removes entire documents. + +==== Document Compression + +A component for compressing the content of each document to reduce noise and redundancy in the retrieved information, +addressing challenges such as _lost-in-the-middle_ and context length restrictions from the model. + +Unlike `DocumentSelector`, this component does not remove entire documents from the list, but rather alters the content +of the documents. Unlike `DocumentRanker`, this component does not change the order/score of the documents in the list. + +=== Generation + +Generation modules are responsible for generating the final response based on the user query and retrieved documents. + +==== Query Augmentation + +A component for augmenting an input query with additional data, useful to provide a large language model +with the necessary context to answer the user query. + +===== ContextualQueryAugmenter + +The `ContextualQueryAugmenter` augments the user query with contextual data from the content of the provided documents. + +[source,java] +---- +QueryAugmenter queryAugmenter = ContextualQueryAugmenter.builder().build(); +---- + +By default, the `ContextualQueryAugmenter` does not allow the retrieved context to be empty. When that happens, +it instructs the model not to answer the user query. + +You can enable the `allowEmptyContext` option to allow the model to generate a response even when the retrieved context is empty. + +[source,java] +---- +QueryAugmenter queryAugmenter = ContextualQueryAugmenter.builder() + .allowEmptyContext(true) + .build(); +---- + +The prompts used by this component can be customized via the `promptTemplate()` and `emptyContextPromptTemplate()` methods +available in the builder. 
diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/CompressionQueryTransformerIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/CompressionQueryTransformerIT.java new file mode 100644 index 00000000000..c06280e3ad7 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/CompressionQueryTransformerIT.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.integration.tests.rag.preretrieval.query.transformation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link CompressionQueryTransformer}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class CompressionQueryTransformerIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void whenTransformerWithDefaults() { + Query query = Query.builder() + .text("And what is its second largest city?") + .history(new UserMessage("What is the capital of Denmark?"), + new AssistantMessage("Copenhagen is the capital of Denmark.")) + .build(); + + QueryTransformer queryTransformer = CompressionQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(this.openAiChatModel)) + .build(); + + Query transformedQuery = queryTransformer.apply(query); + + assertThat(transformedQuery).isNotNull(); + System.out.println(transformedQuery); + assertThat(transformedQuery.text()).containsIgnoringCase("Denmark"); + } + +} diff --git 
a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java new file mode 100644 index 00000000000..fddb7b4b9f6 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/preretrieval/query/transformation/RewriteQueryTransformerIT.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.integration.tests.rag.preretrieval.query.transformation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer; +import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link RewriteQueryTransformer}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class RewriteQueryTransformerIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void whenTransformerWithDefaults() { + Query query = new Query("I'm studying machine learning. What is an LLM?"); + QueryTransformer queryTransformer = RewriteQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(this.openAiChatModel)) + .build(); + + Query transformedQuery = queryTransformer.apply(query); + + assertThat(transformedQuery).isNotNull(); + System.out.println(transformedQuery); + assertThat(transformedQuery.text()).containsIgnoringCase("Large Language Model"); + } + +}