Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter;
import org.springframework.ai.rag.orchestration.routing.QueryRouter;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
Expand All @@ -57,6 +55,7 @@
* @since 1.0.0
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
* @see <a href="https://export.arxiv.org/abs/2410.20878">arXiv:2410.20878</a>
*/
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

Expand All @@ -67,7 +66,7 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
@Nullable
private final QueryExpander queryExpander;

private final QueryRouter queryRouter;
private final DocumentRetriever documentRetriever;

private final DocumentJoiner documentJoiner;

Expand All @@ -80,14 +79,14 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
private final int order;

public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
@Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner,
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
@Nullable Integer order) {
Assert.notNull(queryRouter, "queryRouter cannot be null");
@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
this.queryExpander = queryExpander;
this.queryRouter = queryRouter;
this.documentRetriever = documentRetriever;
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
Expand Down Expand Up @@ -122,7 +121,7 @@ public AdvisedRequest before(AdvisedRequest request) {
.toList()
.stream()
.map(CompletableFuture::join)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
.collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue())));

// 4. Combine documents retrieved based on multiple queries and from multiple data
// sources.
Expand All @@ -140,14 +139,8 @@ public AdvisedRequest before(AdvisedRequest request) {
* Processes a single query by routing it to document retrievers and collecting
* documents.
*/
private Map.Entry<Query, List<List<Document>>> getDocumentsForQuery(Query query) {
List<DocumentRetriever> retrievers = this.queryRouter.route(query);
List<List<Document>> documents = retrievers.stream()
.map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), this.taskExecutor))
.toList()
.stream()
.map(CompletableFuture::join)
.toList();
private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
List<Document> documents = documentRetriever.retrieve(query);
return Map.entry(query, documents);
}

Expand All @@ -160,7 +153,7 @@ public AdvisedResponse after(AdvisedResponse advisedResponse) {
else {
chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
}
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
chatResponseBuilder.metadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
}

Expand Down Expand Up @@ -190,7 +183,7 @@ public static final class Builder {

private QueryExpander queryExpander;

private QueryRouter queryRouter;
private DocumentRetriever documentRetriever;

private DocumentJoiner documentJoiner;

Expand Down Expand Up @@ -220,15 +213,8 @@ public Builder queryExpander(QueryExpander queryExpander) {
return this;
}

public Builder queryRouter(QueryRouter queryRouter) {
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
this.queryRouter = queryRouter;
return this;
}

public Builder documentRetriever(DocumentRetriever documentRetriever) {
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build();
this.documentRetriever = documentRetriever;
return this;
}

Expand Down Expand Up @@ -258,7 +244,7 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter,
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,78 @@

package org.springframework.ai.rag;

import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Map;

/**
* Represents a query in the context of a Retrieval Augmented Generation (RAG) flow.
*
* @param text the text of the query
* @param history the messages in the conversation history
* @param context the context of the query
* @author Thomas Vitale
* @since 1.0.0
*/
public record Query(String text) {
public record Query(String text, List<Message> history, Map<String, Object> context) {

public Query {
Assert.hasText(text, "text cannot be null or empty");
Assert.notNull(history, "history cannot be null");
Assert.noNullElements(history, "history elements cannot be null");
Assert.notNull(context, "context cannot be null");
Assert.noNullElements(context.keySet(), "context keys cannot be null");
}

public Query(String text) {
this(text, List.of(), Map.of());
}

public Builder mutate() {
return new Builder().text(this.text).history(this.history).context(this.context);
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private String text;

private List<Message> history = List.of();

private Map<String, Object> context = Map.of();

private Builder() {
}

public Builder text(String text) {
this.text = text;
return this;
}

public Builder history(List<Message> history) {
this.history = history;
return this;
}

public Builder history(Message... history) {
this.history = List.of(history);
return this;
}

public Builder context(Map<String, Object> context) {
this.context = context;
return this;
}

public Query build() {
return new Query(text, history, context);
}

}

}

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
* @see <a href="https://export.arxiv.org/abs/2410.20878">arXiv:2410.20878</a>
*/
@NonNullApi
@NonNullFields
Expand Down
Loading
Loading