diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java new file mode 100644 index 00000000000..82188b88fa1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -0,0 +1,224 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.model.Content; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.retrieval.source.DocumentRetriever; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * This advisor implements common Retrieval Augmented Generation (RAG) flows using the + * building blocks defined in the {@link org.springframework.ai.rag} package and following + * the Modular RAG Architecture. + *

+ * It's the successor of the {@link QuestionAnswerAdvisor}. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @since 1.0.0 + * @see arXiv:2407.21059 + * @see arXiv:2312.10997 + */ +public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + public static final String DOCUMENT_CONTEXT = "rag_document_context"; + + public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + {query} + + Context information is below. Use this information to answer the user query. + + --------------------- + {context} + --------------------- + + Given the context and provided history information and not prior knowledge, + reply to the user query. If the answer is not in the context, inform + the user that you can't answer the query. + """); + + private final DocumentRetriever documentRetriever; + + private final PromptTemplate promptTemplate; + + private final boolean protectFromBlocking; + + private final int order; + + public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate, + @Nullable Boolean protectFromBlocking, @Nullable Integer order) { + Assert.notNull(documentRetriever, "documentRetriever cannot be null"); + this.documentRetriever = documentRetriever; + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false; + this.order = order != null ? order : 0; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + Assert.notNull(advisedRequest, "advisedRequest cannot be null"); + Assert.notNull(chain, "chain cannot be null"); + + AdvisedRequest processedAdvisedRequest = before(advisedRequest); + AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest); + return after(advisedResponse); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + Assert.notNull(advisedRequest, "advisedRequest cannot be null"); + Assert.notNull(chain, "chain cannot be null"); + + // This can be executed by both blocking and non-blocking Threads + // E.g. a command line or Tomcat blocking Thread implementation + // or by a WebFlux dispatch in a non-blocking manner. + Flux advisedResponses = (this.protectFromBlocking) ? + // @formatter:off + Mono.just(advisedRequest) + .publishOn(Schedulers.boundedElastic()) + .map(this::before) + .flatMapMany(chain::nextAroundStream) + : chain.nextAroundStream(before(advisedRequest)); + // @formatter:on + + return advisedResponses.map(ar -> { + if (onFinishReason().test(ar)) { + ar = after(ar); + } + return ar; + }); + } + + private AdvisedRequest before(AdvisedRequest request) { + Map context = new HashMap<>(request.adviseContext()); + + // 0. Create a query from the user text and parameters. + Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); + + // 1. Retrieve similar documents for the original query. + List documents = this.documentRetriever.retrieve(query); + context.put(DOCUMENT_CONTEXT, documents); + + // 2. Combine retrieved documents. + String documentContext = documents.stream() + .map(Content::getContent) + .collect(Collectors.joining(System.lineSeparator())); + + // 3. Define augmentation prompt parameters. + Map promptParameters = Map.of("query", query.text(), "context", documentContext); + + // 4. Augment user prompt with the context data. + UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters); + + return AdvisedRequest.from(request) + .withUserText(augmentedUserMessage.getContent()) + .withAdviseContext(context) + .build(); + } + + private AdvisedResponse after(AdvisedResponse advisedResponse) { + ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); + chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT)); + return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); + } + + private Predicate onFinishReason() { + return advisedResponse -> advisedResponse.response() + .getResults() + .stream() + .anyMatch(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())); + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; + } + + public static final class Builder { + + private DocumentRetriever documentRetriever; + + private PromptTemplate promptTemplate; + + private Boolean protectFromBlocking; + + private Integer order; + + private Builder() { + } + + public Builder documentRetriever(DocumentRetriever documentRetriever) { + this.documentRetriever = documentRetriever; + return this; + } + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder protectFromBlocking(Boolean protectFromBlocking) { + this.protectFromBlocking = protectFromBlocking; + return this; + } + + public Builder order(Integer order) { + this.order = order; + return this; + } + + public RetrievalAugmentationAdvisor build() { + return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate, + this.protectFromBlocking, this.order); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index afca774760f..08e1aa276e7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -34,6 +34,8 @@ import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -41,8 +43,6 @@ * The data of the chat client request that can be modified before the execution of the * ChatClient's call method * - * @author Christian Tzolov - * @since 1.0.0 * @param chatModel the chat model used * @param userText the text provided by the user * @param systemText the text provided by the system @@ -57,13 +57,53 @@ * @param advisorParams the map of advisor parameters * @param adviseContext the map of advise context * @param toolContext the tool context + * @author Christian Tzolov + * @author Thomas Vitale + * @since 1.0.0 */ -public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, - List media, List functionNames, List functionCallbacks, List messages, - Map userParams, Map systemParams, List advisors, - Map advisorParams, Map adviseContext, Map toolContext) { +public record AdvisedRequest( +// @formatter:off + ChatModel chatModel, + String userText, + @Nullable + String systemText, + @Nullable + ChatOptions chatOptions, + List media, + List functionNames, + List functionCallbacks, + List messages, + Map userParams, + Map systemParams, + List advisors, + Map advisorParams, + Map adviseContext, + Map toolContext +// @formatter:on +) { + + public AdvisedRequest { + Assert.notNull(chatModel, "chatModel cannot be null"); + Assert.hasText(userText, "userText cannot be null or empty"); + Assert.notNull(media, "media cannot be null"); + Assert.notNull(functionNames, "functionNames cannot be null"); + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.notNull(messages, "messages cannot be null"); + Assert.notNull(userParams, "userParams cannot be null"); + Assert.notNull(systemParams, "systemParams cannot be null"); + Assert.notNull(advisors, "advisors cannot be null"); + Assert.notNull(advisorParams, "advisorParams cannot be null"); + Assert.notNull(adviseContext, "adviseContext cannot be null"); + Assert.notNull(toolContext, "toolContext cannot be null"); + } + + public static Builder builder() { + return new Builder(); + } public static Builder from(AdvisedRequest from) { + Assert.notNull(from, "AdvisedRequest cannot be null"); + Builder builder = new Builder(); builder.chatModel = from.chatModel; builder.userText = from.userText; @@ -79,23 +119,18 @@ public static Builder from(AdvisedRequest from) { builder.advisorParams = from.advisorParams; builder.adviseContext = from.adviseContext; builder.toolContext = from.toolContext; - return builder; } - public static Builder builder() { - return new Builder(); - } - public AdvisedRequest updateContext(Function, Map> contextTransform) { + Assert.notNull(contextTransform, "contextTransform cannot be null"); return from(this) .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) .build(); } public Prompt toPrompt() { - - var messages = new ArrayList(this.messages()); + var messages = new ArrayList<>(this.messages()); String processedSystemText = this.systemText(); if (StringUtils.hasText(processedSystemText)) { @@ -111,7 +146,6 @@ public Prompt toPrompt() { ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); if (StringUtils.hasText(processedUserText)) { - Map userParams = new HashMap<>(this.userParams()); if (StringUtils.hasText(formatParam)) { userParams.put("spring_ai_soc_format", formatParam); @@ -137,17 +171,15 @@ public Prompt toPrompt() { return new Prompt(messages, this.chatOptions()); } - public static class Builder { - - public Map toolContext = Map.of(); + public static final class Builder { private ChatModel chatModel; - private String userText = ""; + private String userText; - private String systemText = ""; + private String systemText; - private ChatOptions chatOptions = null; + private ChatOptions chatOptions; private List media = List.of(); @@ -167,6 +199,11 @@ public static class Builder { private Map adviseContext = Map.of(); + public Map toolContext = Map.of(); + + private Builder() { + } + public Builder withChatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; @@ -202,11 +239,6 @@ public Builder withFunctionCallbacks(List functionCallbacks) { return this; } - public Builder withToolContext(Map toolContext) { - this.toolContext = toolContext; - return this; - } - public Builder withMessages(List messages) { this.messages = messages; return this; @@ -237,6 +269,11 @@ public Builder withAdviseContext(Map adviseContext) { return this; } + public Builder withToolContext(Map toolContext) { + this.toolContext = toolContext; + return this; + } + public AdvisedRequest build() { return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java index a03247fd6dd..8bb383c1cc5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -25,48 +25,54 @@ import org.springframework.util.Assert; /** + * The data of the chat client response that can be modified before the call returns. + * * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ public record AdvisedResponse(ChatResponse response, Map adviseContext) { + public AdvisedResponse { + Assert.notNull(response, "response cannot be null"); + Assert.notNull(adviseContext, "adviseContext cannot be null"); + } + public static Builder builder() { return new Builder(); } + public static Builder from(AdvisedResponse advisedResponse) { + Assert.notNull(advisedResponse, "advisedResponse cannot be null"); + return new Builder().withResponse(advisedResponse.response).withAdviseContext(advisedResponse.adviseContext); + } + public AdvisedResponse updateContext(Function, Map> contextTransform) { + Assert.notNull(contextTransform, "contextTransform cannot be null"); return new AdvisedResponse(this.response, Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))); } - public static class Builder { + public static final class Builder { private ChatResponse response; private Map adviseContext; - public Builder() { - } - - public static Builder from(AdvisedResponse advisedResponse) { - return new Builder().withResponse(advisedResponse.response) - .withAdviseContext(advisedResponse.adviseContext); + private Builder() { } public Builder withResponse(ChatResponse response) { - Assert.notNull(response, "the response must be non-null"); this.response = response; return this; } public Builder withAdviseContext(Map adviseContext) { - Assert.notNull(adviseContext, "the adviseContext must be non-null"); this.adviseContext = adviseContext; return this; } public AdvisedResponse build() { - Assert.notNull(this.adviseContext, "the adviseContext must be non-null"); return new AdvisedResponse(this.response, this.adviseContext); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/package-info.java new file mode 100644 index 00000000000..a4fc77b1242 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/package-info.java new file mode 100644 index 00000000000..52271abfbaa --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.client.advisor; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java similarity index 59% rename from spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java rename to spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java index 618af3fc664..cc8869bee4a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DocumentRetriever.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/Query.java @@ -14,22 +14,18 @@ * limitations under the License. */ -package org.springframework.ai.document; +package org.springframework.ai.rag; -import java.util.List; -import java.util.function.Function; +import org.springframework.util.Assert; -public interface DocumentRetriever extends Function> { - - /** - * Retrieves relevant documents however the implementation sees fit. - * @param query query string - * @return relevant documents - */ - List retrieve(String query); - - default List apply(String query) { - return retrieve(query); +/** + * Represents a query in the context of a Retrieval Augmented Generation (RAG) flow. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public record Query(String text) { + public Query { + Assert.hasText(text, "text cannot be null or empty"); } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java new file mode 100644 index 00000000000..a42ade9d8bb --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * This package contains the core interfaces and classes supporting Retrieval Augmented + * Generation. + *

+ * It's based on the Modular RAG Architecture and provides the necessary building blocks + * to define and execute RAG flows. It includes three levels of abstraction: + *

    + *
  1. Module
  2. + *
  3. Sub-Module
  4. + *
  5. Operator
  6. + *
+ * + * @see arXiv:2407.21059 + * @see arXiv:2312.10997 + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.rag; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java new file mode 100644 index 00000000000..9995f15aa7c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Module: Retrieval. + *

+ * This package includes submodules for handling the retrieval process in RAG flows. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.retrieval; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java new file mode 100644 index 00000000000..e5adc128168 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java @@ -0,0 +1,56 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.retrieval.source; + +import java.util.List; +import java.util.function.Function; + +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; + +/** + * API for retrieving {@link Document}s from an underlying data source. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface DocumentRetriever extends Function> { + + /** + * Retrieves {@link Document}s from an underlying data source using the given + * {@link Query}. + */ + List retrieve(Query query); + + /** + * Retrieves {@link Document}s from an underlying data source using the given query + * string. + */ + default List retrieve(String query) { + return retrieve(new Query(query)); + } + + /** + * Retrieves {@link Document}s from an underlying data source using the given + * {@link Query}. + */ + default List apply(Query query) { + return retrieve(query); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java new file mode 100644 index 00000000000..2d8b45fabe1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java @@ -0,0 +1,135 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.retrieval.source; + +import java.util.List; +import java.util.function.Supplier; + +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * A document retriever that uses a vector store to search for documents. It supports + * filtering based on metadata, similarity threshold, and top-k results. + * + *

+ * Example usage:

{@code
+ * VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
+ *     .vectorStore(vectorStore)
+ *     .similarityThreshold(0.73)
+ *     .topK(5)
+ *     .filterExpression(filterExpression)
+ *     .build();
+ * List documents = retriever.retrieve("example query");
+ * }
+ * + * @author Thomas Vitale + * @since 1.0.0 + * @see VectorStore + * @see Filter.Expression + */ +public class VectorStoreDocumentRetriever implements DocumentRetriever { + + private final VectorStore vectorStore; + + private final Double similarityThreshold; + + private final Integer topK; + + // Supplier to allow for lazy evaluation of the filter expression, + // which may depend on the execution content. For example, you may want to + // filter dynamically based on the current user's identity or tenant ID. + private final Supplier filterExpression; + + public VectorStoreDocumentRetriever(VectorStore vectorStore, @Nullable Double similarityThreshold, + @Nullable Integer topK, @Nullable Supplier filterExpression) { + Assert.notNull(vectorStore, "vectorStore cannot be null"); + this.vectorStore = vectorStore; + this.similarityThreshold = similarityThreshold != null ? similarityThreshold + : SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL; + this.topK = topK != null ? topK : SearchRequest.DEFAULT_TOP_K; + this.filterExpression = filterExpression != null ? filterExpression : () -> null; + } + + @Override + public List retrieve(Query query) { + Assert.notNull(query, "query cannot be null"); + var searchRequest = SearchRequest.query(query.text()) + .withFilterExpression(this.filterExpression.get()) + .withSimilarityThreshold(this.similarityThreshold) + .withTopK(this.topK); + return this.vectorStore.similaritySearch(searchRequest); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link VectorStoreDocumentRetriever}. + */ + public static final class Builder { + + private VectorStore vectorStore; + + private Double similarityThreshold; + + private Integer topK; + + private Supplier filterExpression; + + private Builder() { + } + + public Builder vectorStore(VectorStore vectorStore) { + this.vectorStore = vectorStore; + return this; + } + + public Builder similarityThreshold(Double similarityThreshold) { + this.similarityThreshold = similarityThreshold; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder filterExpression(Filter.Expression filterExpression) { + this.filterExpression = () -> filterExpression; + return this; + } + + public Builder filterExpression(Supplier filterExpression) { + this.filterExpression = filterExpression; + return this; + } + + public VectorStoreDocumentRetriever build() { + return new VectorStoreDocumentRetriever(this.vectorStore, this.similarityThreshold, this.topK, + this.filterExpression); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java new file mode 100644 index 00000000000..7d65ec54b55 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Sub-Module: Source. + *

+ * This package provides the functional building blocks for retrieving documents from a + * data source. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.retrieval.source; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 3dc853b2fa6..44824f10323 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -50,17 +50,12 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @ExtendWith(MockitoExtension.class) public class ChatClientTest { - static Function mockFunction = new Function() { - - @Override - public String apply(String s) { - return s; - } - }; + static Function mockFunction = s -> s; @Mock ChatModel chatModel; @@ -88,7 +83,7 @@ public void defaultSystemText() { var chatClient = ChatClient.builder(this.chatModel).defaultSystem("Default system text").build(); - var content = chatClient.prompt().call().content(); + var content = chatClient.prompt("What's Spring AI?").call().content(); assertThat(content).isEqualTo("response"); @@ -96,7 +91,7 @@ public void defaultSystemText() { assertThat(systemMessage.getContent()).isEqualTo("Default system text"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); - content = join(chatClient.prompt().stream().content()); + content = join(chatClient.prompt("What's Spring AI?").stream().content()); assertThat(content).isEqualTo("response"); @@ -105,7 +100,7 @@ public void defaultSystemText() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Override the default system text with prompt system - content = chatClient.prompt().system("Override default system text").call().content(); + content = chatClient.prompt("What's Spring AI?").system("Override default system text").call().content(); assertThat(content).isEqualTo("response"); systemMessage = this.promptCaptor.getValue().getInstructions().get(0); @@ -113,7 +108,8 @@ public void defaultSystemText() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Streaming - content = join(chatClient.prompt().system("Override default system text").stream().content()); + content = join( + chatClient.prompt("What's Spring AI?").system("Override default system text").stream().content()); assertThat(content).isEqualTo("response"); systemMessage = this.promptCaptor.getValue().getInstructions().get(0); @@ -140,7 +136,7 @@ public void defaultSystemTextLambda() { .param("param2", "value2")) .build(); - var content = chatClient.prompt().call().content(); + var content = chatClient.prompt("What's Spring AI?").call().content(); assertThat(content).isEqualTo("response"); @@ -149,7 +145,7 @@ public void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Streaming - content = join(chatClient.prompt().stream().content()); + content = join(chatClient.prompt("What's Spring AI?").stream().content()); assertThat(content).isEqualTo("response"); @@ -158,7 +154,7 @@ public void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Override single default system parameter - content = chatClient.prompt().system(s -> s.param("param1", "value1New")).call().content(); + content = chatClient.prompt("What's Spring AI?").system(s -> s.param("param1", "value1New")).call().content(); assertThat(content).isEqualTo("response"); systemMessage = this.promptCaptor.getValue().getInstructions().get(0); @@ -166,7 +162,8 @@ public void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // streaming - content = join(chatClient.prompt().system(s -> s.param("param1", "value1New")).stream().content()); + content = join( + chatClient.prompt("What's Spring AI?").system(s -> s.param("param1", "value1New")).stream().content()); assertThat(content).isEqualTo("response"); systemMessage = this.promptCaptor.getValue().getInstructions().get(0); @@ -174,7 +171,7 @@ public void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Override default system text - content = chatClient.prompt() + content = chatClient.prompt("What's Spring AI?") .system(s -> s.text("Override default system text {param3}").param("param3", "value3")) .call() .content(); @@ -185,7 +182,7 @@ public void defaultSystemTextLambda() { assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); // Streaming - content = join(chatClient.prompt() + content = join(chatClient.prompt("What's Spring AI?") .system(s -> s.text("Override default system text {param3}").param("param3", "value3")) .stream() .content()); @@ -489,11 +486,16 @@ public void simpleSystemPrompt() throws MalformedURLException { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - String response = ChatClient.builder(this.chatModel).build().prompt().system("System prompt").call().content(); + String response = ChatClient.builder(this.chatModel) + .build() + .prompt("What's Spring AI?") + .system("System prompt") + .call() + .content(); assertThat(response).isEqualTo("response"); - assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(1); + assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(2); Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getContent()).isEqualTo("System prompt"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java new file mode 100644 index 00000000000..74a18cd31c8 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java @@ -0,0 +1,112 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.retrieval.source.DocumentRetriever; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link RetrievalAugmentationAdvisor}. + * + * @author Thomas Vitale + */ +class RetrievalAugmentationAdvisorTests { + + @Test + void whenDocumentRetrieverIsNullThenThrow() { + assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder().documentRetriever(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("documentRetriever cannot be null"); + } + + @Test + void theOneWithTheDocumentRetriever() { + // Chat Model + var chatModel = mock(ChatModel.class); + var promptCaptor = ArgumentCaptor.forClass(Prompt.class); + given(chatModel.call(promptCaptor.capture())).willReturn(ChatResponse.builder() + .withGenerations(List.of(new Generation(new AssistantMessage("Felix Felicis")))) + .build()); + + // Document Retriever + var documentContext = List.of(Document.builder().withId("1").withContent("doc1").build(), + Document.builder().withId("2").withContent("doc2").build()); + var documentRetriever = mock(DocumentRetriever.class); + var queryCaptor = ArgumentCaptor.forClass(Query.class); + given(documentRetriever.retrieve(queryCaptor.capture())).willReturn(documentContext); + + // Advisor + var advisor = RetrievalAugmentationAdvisor.builder().documentRetriever(documentRetriever).build(); + + // Chat Client + var chatClient = ChatClient.builder(chatModel) + .defaultAdvisors(advisor) + .defaultSystem("You are a wizard!") + .build(); + + // Call + var chatResponse = chatClient.prompt() + .user(user -> user.text("What would I get if I added {ingredient1} to {ingredient2}?") + .param("ingredient1", "a pinch of Moonstone") + .param("ingredient2", "a dash of powdered Gold")) + .call() + .chatResponse(); + + // Verify + assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo("Felix Felicis"); + assertThat(chatResponse.getMetadata().>get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT)) + .containsAll(documentContext); + + var query = queryCaptor.getValue(); + assertThat(query.text()) + .isEqualTo("What would I get if I added a pinch of Moonstone to a dash of powdered Gold?"); + + var prompt = promptCaptor.getValue(); + assertThat(prompt.getContents()).contains(""" + What would I get if I added a pinch of Moonstone to a dash of powdered Gold? + + Context information is below. Use this information to answer the user query. + + --------------------- + doc1 + doc2 + --------------------- + + Given the context and provided history information and not prior knowledge, + reply to the user query. If the answer is not in the context, inform + the user that you can't answer the query. + """); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/QueryTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/QueryTests.java new file mode 100644 index 00000000000..b97c054180b --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/QueryTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link Query}. + * + * @author Thomas Vitale + */ +class QueryTests { + + @Test + void whenTextIsNullThenThrow() { + assertThatThrownBy(() -> new Query(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("text cannot be null or empty"); + } + + @Test + void whenTextIsEmptyThenThrow() { + assertThatThrownBy(() -> new Query("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("text cannot be null or empty"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java new file mode 100644 index 00000000000..030db98c38d --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.retrieval.source; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.internal.verification.Times; + +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; +import org.springframework.util.Assert; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; + +/** + * Unit tests for {@link VectorStoreDocumentRetriever}. + */ +class VectorStoreDocumentRetrieverTests { + + @Test + void whenVectorStoreIsNullThenThrow() { + assertThatThrownBy(() -> VectorStoreDocumentRetriever.builder().vectorStore(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("vectorStore cannot be null"); + } + + @Test + void searchRequestParameters() { + var mockVectorStore = mock(VectorStore.class); + var documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(mockVectorStore) + .similarityThreshold(0.73) + .topK(5) + .filterExpression(new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell"))) + .build(); + + documentRetriever.retrieve("query"); + + var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); + verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture()); + + var searchRequest = searchRequestCaptor.getValue(); + assertThat(searchRequest.getQuery()).isEqualTo("query"); + assertThat(searchRequest.getSimilarityThreshold()).isEqualTo(0.73); + assertThat(searchRequest.getTopK()).isEqualTo(5); + assertThat(searchRequest.getFilterExpression()) + .isEqualTo(new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell"))); + } + + @Test + void dynamicFilterExpressions() { + var mockVectorStore = mock(VectorStore.class); + var documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(mockVectorStore) + .filterExpression( + () -> new FilterExpressionBuilder().eq("tenantId", TenantContextHolder.getTenantIdentifier()) + .build()) + .build(); + + TenantContextHolder.setTenantIdentifier("tenant1"); + documentRetriever.retrieve("query"); + TenantContextHolder.clear(); + + TenantContextHolder.setTenantIdentifier("tenant2"); + documentRetriever.retrieve("query"); + TenantContextHolder.clear(); + + var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); + + verify(mockVectorStore, new Times(2)).similaritySearch(searchRequestCaptor.capture()); + + var searchRequest1 = searchRequestCaptor.getAllValues().get(0); + assertThat(searchRequest1.getFilterExpression()) + .isEqualTo(new Filter.Expression(EQ, new Filter.Key("tenantId"), new Filter.Value("tenant1"))); + + var searchRequest2 = searchRequestCaptor.getAllValues().get(1); + assertThat(searchRequest2.getFilterExpression()) + .isEqualTo(new Filter.Expression(EQ, new Filter.Key("tenantId"), new Filter.Value("tenant2"))); + } + + static final class TenantContextHolder { + + private static final ThreadLocal tenantIdentifier = new ThreadLocal<>(); + + private TenantContextHolder() { + } + + public static void setTenantIdentifier(String tenant) { + Assert.hasText(tenant, "tenant cannot be null or empty"); + tenantIdentifier.set(tenant); + } + + public static String getTenantIdentifier() { + return tenantIdentifier.get(); + } + + public static void clear() { + tenantIdentifier.remove(); + } + + } + +}