Skip to content

Commit ba17f20

Browse files
committed
Support userText rendering strategy in RetrievalAugmentationAdvisor
The input userText to the advisor might contain the templating special characters for purposes other than templating (e.g. when including source code). A new UserTextProcessor functional interface allows customizing how the userText is processed when taken as input by the RetrievalAugmentationAdvisor. By default, the PromptTemplateUserTextProcessor implementation is used. If you want to disable the rendering step altogether, you can use the NoOpUserTextProcessor implementation. Signed-off-by: Thomas Vitale <[email protected]>
1 parent bc375ab commit ba17f20

File tree

8 files changed

+260
-11
lines changed

8 files changed

+260
-11
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package org.springframework.ai.chat.client.advisor;
2+
3+
import java.util.Map;
4+
5+
/**
6+
* A {@link UserTextProcessor} that returns the user text as is.
* <p>
* No template rendering is performed and the user parameters are ignored, so any
* templating special characters in the user text are passed through unchanged.
7+
*
8+
* @author Thomas Vitale
9+
* @since 1.0.0
10+
*/
11+
public class NoOpUserTextProcessor implements UserTextProcessor {
12+
13+
/**
* Returns the given user text without modification.
* @param userText the user text to pass through; returned exactly as received
* @param userParams the user parameters; not read by this implementation
* @return the same {@code userText} that was passed in
*/
@Override
14+
public String process(String userText, Map<String, Object> userParams) {
15+
// Deliberately no validation and no rendering: the input is handed back untouched.
return userText;
16+
}
17+
18+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import org.springframework.ai.chat.prompt.PromptTemplate;
20+
import org.springframework.util.Assert;
21+
22+
import java.util.Map;
23+
24+
/**
25+
* Processes the advised user text with the given user parameters using a
26+
* {@link PromptTemplate}.
* <p>
* The user text is treated as a template and rendered with the user parameters.
27+
*
28+
* @author Thomas Vitale
29+
* @since 1.0.0
30+
*/
31+
public class PromptTemplateUserTextProcessor implements UserTextProcessor {
32+
33+
/**
* Renders the user text as a {@link PromptTemplate} using the given parameters.
* @param userText the template text to render; must have text
* @param userParams the parameters made available to the template; must not be
* {@code null} and must not contain a {@code null} key
* @return the rendered text
* @throws IllegalArgumentException if any of the preconditions above is violated
*/
@Override
34+
public String process(String userText, Map<String, Object> userParams) {
35+
// Validate inputs up front so callers get a descriptive message.
// NOTE(review): Assert.hasText presumably also rejects whitespace-only text,
// which the message does not mention - confirm against Spring's Assert docs.
Assert.hasText(userText, "userText cannot be null or empty");
36+
Assert.notNull(userParams, "userParams cannot be null");
37+
Assert.noNullElements(userParams.keySet(), "userParams keys cannot be null");
38+
39+
return new PromptTemplate(userText, userParams).render();
40+
}
41+
42+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
3030
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
3131
import org.springframework.ai.chat.model.ChatResponse;
32-
import org.springframework.ai.chat.prompt.PromptTemplate;
3332
import org.springframework.ai.document.Document;
3433
import org.springframework.ai.rag.Query;
3534
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
@@ -61,6 +60,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6160

6261
public static final String DOCUMENT_CONTEXT = "rag_document_context";
6362

63+
private final UserTextProcessor userTextProcessor;
64+
6465
private final List<QueryTransformer> queryTransformers;
6566

6667
@Nullable
@@ -78,10 +79,12 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
7879

7980
private final int order;
8081

81-
public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
82-
@Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
83-
@Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter,
84-
@Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) {
82+
public RetrievalAugmentationAdvisor(@Nullable UserTextProcessor userTextProcessor,
83+
@Nullable List<QueryTransformer> queryTransformers, @Nullable QueryExpander queryExpander,
84+
DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner,
85+
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
86+
@Nullable Integer order) {
87+
this.userTextProcessor = userTextProcessor != null ? userTextProcessor : new PromptTemplateUserTextProcessor();
8588
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
8689
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
8790
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
@@ -102,9 +105,11 @@ public static Builder builder() {
102105
public AdvisedRequest before(AdvisedRequest request) {
103106
Map<String, Object> context = new HashMap<>(request.adviseContext());
104107

108+
String processedUserText = this.userTextProcessor.apply(request.userText(), request.userParams());
109+
105110
// 0. Create a query from the user text, parameters, and conversation history.
106111
Query originalQuery = Query.builder()
107-
.text(new PromptTemplate(request.userText(), request.userParams()).render())
112+
.text(processedUserText)
108113
.history(request.messages())
109114
.context(context)
110115
.build();
@@ -183,6 +188,8 @@ private static TaskExecutor buildDefaultTaskExecutor() {
183188

184189
public static final class Builder {
185190

191+
private UserTextProcessor userTextProcessor;
192+
186193
private List<QueryTransformer> queryTransformers;
187194

188195
private QueryExpander queryExpander;
@@ -202,6 +209,11 @@ public static final class Builder {
202209
private Builder() {
203210
}
204211

212+
/**
* Sets the {@link UserTextProcessor} used to process the user text before the
* query is built. When this is left unset (or set to {@code null}), the advisor
* constructor falls back to a {@link PromptTemplateUserTextProcessor}.
* @param userTextProcessor the processor to use
* @return this builder
*/
public Builder userTextProcessor(UserTextProcessor userTextProcessor) {
213+
this.userTextProcessor = userTextProcessor;
214+
return this;
215+
}
216+
205217
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
206218
this.queryTransformers = queryTransformers;
207219
return this;
@@ -248,8 +260,9 @@ public Builder order(Integer order) {
248260
}
249261

250262
public RetrievalAugmentationAdvisor build() {
251-
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever,
252-
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
263+
return new RetrievalAugmentationAdvisor(this.userTextProcessor, this.queryTransformers, this.queryExpander,
264+
this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler,
265+
this.order);
253266
}
254267

255268
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import java.util.Map;
20+
import java.util.function.BiFunction;
21+
22+
/**
23+
* Processes the advised user text with the given user parameters.
* <p>
* This is a functional interface whose single abstract method is
* {@link #process(String, Map)}. It also extends {@link BiFunction}, so an instance
* can be used wherever a {@code BiFunction<String, Map<String, Object>, String>} is
* expected.
24+
*
25+
* @author Thomas Vitale
26+
* @since 1.0.0
27+
*/
28+
@FunctionalInterface
29+
public interface UserTextProcessor extends BiFunction<String, Map<String, Object>, String> {
30+
31+
/**
* Processes the user text with the given user parameters.
* @param userText the user text to process
* @param userParams the user parameters available for processing
* @return the processed user text
*/
String process(String userText, Map<String, Object> userParams);
32+
33+
/**
* {@link BiFunction} entry point; delegates to {@link #process(String, Map)}.
* @param userText the user text to process
* @param userParams the user parameters available for processing
* @return the result of {@link #process(String, Map)}
*/
@Override
34+
default String apply(String userText, Map<String, Object> userParams) {
35+
return process(userText, userParams);
36+
}
37+
38+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.springframework.ai.chat.client.advisor;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import static org.junit.jupiter.api.Assertions.assertEquals;
6+
7+
/**
8+
* Unit tests for {@link NoOpUserTextProcessor}.
9+
*
10+
* @author Thomas Vitale
11+
*/
12+
class NoOpUserTextProcessorTests {
13+
14+
// Verifies that text containing curly braces is returned unchanged.
@Test
15+
void process() {
16+
NoOpUserTextProcessor processor = new NoOpUserTextProcessor();
17+
String userText = "Hello, {World}!";
18+
// null user parameters are acceptable here: the no-op processor never reads them.
String processedText = processor.process(userText, null);
19+
assertEquals(userText, processedText);
20+
}
21+
22+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.NullAndEmptySource;
22+
23+
import java.util.HashMap;
24+
import java.util.Map;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
28+
29+
/**
30+
* Unit tests for {@link PromptTemplateUserTextProcessor}.
* <p>
* Covers the three argument preconditions and the successful rendering path.
31+
*
32+
* @author Thomas Vitale
33+
*/
34+
class PromptTemplateUserTextProcessorTests {
35+
36+
// A null or empty user text must be rejected with the documented message.
@ParameterizedTest
37+
@NullAndEmptySource
38+
void processWithNullOrEmptyUserText(String userText) {
39+
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
40+
Map<String, Object> userParams = Map.of("name", "William");
41+
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
42+
.withMessage("userText cannot be null or empty");
43+
}
44+
45+
// A null parameter map must be rejected.
@Test
46+
void processWithNullUserParams() {
47+
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
48+
String userText = "Hello, {name}!";
49+
Map<String, Object> userParams = null;
50+
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
51+
.withMessage("userParams cannot be null");
52+
}
53+
54+
// A parameter map containing a null key must be rejected.
@Test
55+
void processWithNullUserParamsKeys() {
56+
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
57+
String userText = "Hello, {name}!";
58+
// HashMap is used because it permits a null key, unlike Map.of(...).
Map<String, Object> userParams = new HashMap<>();
59+
userParams.put(null, "William");
60+
assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams))
61+
.withMessage("userParams keys cannot be null");
62+
}
63+
64+
// Happy path: the {name} placeholder is replaced with the supplied parameter value.
@Test
65+
void process() {
66+
PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor();
67+
String userText = "Hello, {name}!";
68+
Map<String, Object> userParams = Map.of("name", "William");
69+
String processedText = processor.process(userText, userParams);
70+
assertThat(processedText).isEqualTo("Hello, William!");
71+
}
72+
73+
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,24 @@ String answer = chatClient.prompt()
138138

139139
See xref:api/retrieval-augmented-generation.adoc#_vectorstoredocumentretriever for more information.
140140

141+
By default, the `RetrievalAugmentationAdvisor` processes the input user text with a `PromptTemplate`, ensuring that any template placeholder is correctly rendered before using the text for the retrieval process.
142+
If you want to customize the processing logic, you can provide a custom `UserTextProcessor` to the advisor, either as a lambda or a class.
143+
For example, in case you want to skip the rendering step, you can provide a `NoOpUserTextProcessor`. That is useful if you're planning to use the templating special characters in the user text for other purposes.
144+
145+
[source,java]
146+
----
147+
Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
148+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(vectorStore).build())
149+
.userTextProcessor(new NoOpUserTextProcessor())
150+
.build();
151+
152+
String answer = chatClient.prompt()
153+
.advisors(retrievalAugmentationAdvisor)
154+
.user(question)
155+
.call()
156+
.content();
157+
----
158+
141159
===== Advanced RAG
142160

143161
[source,java]

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616

1717
package org.springframework.ai.integration.tests.client.advisor;
1818

19-
import java.util.List;
20-
2119
import org.junit.jupiter.api.AfterEach;
2220
import org.junit.jupiter.api.BeforeEach;
2321
import org.junit.jupiter.api.Test;
2422
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
25-
2623
import org.springframework.ai.chat.client.ChatClient;
2724
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
2825
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
26+
import org.springframework.ai.chat.client.advisor.NoOpUserTextProcessor;
2927
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
3028
import org.springframework.ai.chat.memory.InMemoryChatMemory;
3129
import org.springframework.ai.chat.model.ChatResponse;
@@ -49,6 +47,8 @@
4947
import org.springframework.boot.test.context.SpringBootTest;
5048
import org.springframework.core.io.Resource;
5149

50+
import java.util.List;
51+
5252
import static org.assertj.core.api.Assertions.assertThat;
5353

5454
/**
@@ -131,6 +131,31 @@ void ragWithRequestFilter() {
131131
.isNull();
132132
}
133133

134+
// Exercises the advisor with a NoOpUserTextProcessor so that the user text is used
// as is, without template rendering.
@Test
135+
void ragWithCustomUserTextProcessor() {
136+
// The question contains curly braces, which the default PromptTemplate-based
// processor treats as template placeholders.
String question = "Where does the adventure of {Anacletus} and {Birba} take place?";
137+
138+
RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
139+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
140+
.userTextProcessor(new NoOpUserTextProcessor())
141+
.build();
142+
143+
ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel)
144+
.build()
145+
.prompt(question)
146+
.advisors(ragAdvisor)
147+
.call()
148+
.chatResponse();
149+
150+
assertThat(chatResponse).isNotNull();
151+
152+
String response = chatResponse.getResult().getOutput().getText();
153+
System.out.println(response);
154+
// The expected answer is asserted by keyword, case-insensitively.
assertThat(response).containsIgnoringCase("Highlands");
155+
156+
evaluateRelevancy(question, chatResponse);
157+
}
158+
134159
@Test
135160
void ragWithCompression() {
136161
MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();

0 commit comments

Comments
 (0)