Skip to content

Commit 5b2257b

Browse files
committed
feat: Added a ValidationMode property to RewriteQueryTransformer, allowing users to customize the behavior of template content validation when creating an instance of the object.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent 6880753 commit 5b2257b

File tree

3 files changed

+174
-2
lines changed

3 files changed

+174
-2
lines changed

spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
* irrelevant information that may affect the quality of the search results.
3636
*
3737
* @author Thomas Vitale
38+
* @author Sun Yuhan
3839
* @since 1.0.0
3940
* @see <a href="https://arxiv.org/pdf/2305.14283">arXiv:2305.14283</a>
4041
*/
@@ -60,15 +61,28 @@ public class RewriteQueryTransformer implements QueryTransformer {
6061

6162
private final String targetSearchSystem;
6263

64+
private final ValidationMode validationMode;
65+
6366
public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
6467
@Nullable String targetSearchSystem) {
6568
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
6669

6770
this.chatClient = chatClientBuilder.build();
6871
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
6972
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
73+
this.validationMode = ValidationMode.THROW;
74+
validate();
75+
}
76+
77+
public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
78+
@Nullable String targetSearchSystem, @Nullable ValidationMode validationMode) {
79+
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
7080

71-
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
81+
this.chatClient = chatClientBuilder.build();
82+
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
83+
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
84+
this.validationMode = validationMode != null ? validationMode : ValidationMode.THROW;
85+
validate();
7286
}
7387

7488
@Override
@@ -92,6 +106,23 @@ public Query transform(Query query) {
92106
return query.mutate().text(rewrittenQueryText).build();
93107
}
94108

109+
/**
110+
* Verify whether the template contains the required variables.
111+
*/
112+
private void validate() {
113+
switch (this.validationMode) {
114+
case THROW -> PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
115+
case WARN -> {
116+
try {
117+
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
118+
}
119+
catch (IllegalArgumentException e) {
120+
logger.warn(e.getMessage());
121+
}
122+
}
123+
}
124+
}
125+
95126
public static Builder builder() {
96127
return new Builder();
97128
}
@@ -106,6 +137,9 @@ public static final class Builder {
106137
@Nullable
107138
private String targetSearchSystem;
108139

140+
@Nullable
141+
private ValidationMode validationMode;
142+
109143
private Builder() {
110144
}
111145

@@ -124,8 +158,14 @@ public Builder targetSearchSystem(String targetSearchSystem) {
124158
return this;
125159
}
126160

161+
public Builder validationMode(ValidationMode validationMode) {
162+
this.validationMode = validationMode;
163+
return this;
164+
}
165+
127166
public RewriteQueryTransformer build() {
128-
return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
167+
return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem,
168+
this.validationMode);
129169
}
130170

131171
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.rag.preretrieval.query.transformation;
18+
19+
/**
20+
* Validation modes for {@link RewriteQueryTransformer}.
21+
*
22+
* @author Sun Yuhan
23+
*/
24+
public enum ValidationMode {
25+
26+
/**
27+
* If the validation fails, an exception is thrown. This is the default mode.
28+
*/
29+
THROW,
30+
31+
/**
32+
* If the validation fails, a warning is logged.
33+
*/
34+
WARN,
35+
36+
/**
37+
* No validation is performed.
38+
*/
39+
NONE
40+
41+
}

spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616

1717
package org.springframework.ai.rag.preretrieval.query.transformation;
1818

19+
import ch.qos.logback.classic.Logger;
20+
import ch.qos.logback.classic.spi.ILoggingEvent;
21+
import ch.qos.logback.core.read.ListAppender;
1922
import org.junit.jupiter.api.Test;
2023

24+
import org.slf4j.LoggerFactory;
2125
import org.springframework.ai.chat.client.ChatClient;
2226
import org.springframework.ai.chat.prompt.PromptTemplate;
2327

28+
import static org.assertj.core.api.Assertions.assertThat;
2429
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2530
import static org.mockito.Mockito.mock;
2631

@@ -71,4 +76,90 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() {
7176
.hasMessageContaining("query");
7277
}
7378

79+
@Test
80+
void shouldLoggingWithMissingTargetPlaceholderInWarnMode() {
81+
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
82+
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
83+
84+
ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
85+
listAppender.start();
86+
logger.addAppender(listAppender);
87+
88+
RewriteQueryTransformer.builder()
89+
.chatClientBuilder(mock(ChatClient.Builder.class))
90+
.targetSearchSystem("vector store")
91+
.validationMode(ValidationMode.WARN)
92+
.promptTemplate(customPromptTemplate)
93+
.build();
94+
var logsList = listAppender.list;
95+
96+
assertThat(logsList).isNotEmpty();
97+
assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
98+
assertThat(logsList.get(0).getMessage())
99+
.contains("The following placeholders must be present in the prompt template: target");
100+
}
101+
102+
@Test
103+
void shouldLoggingWithMissingQueryPlaceholderInWarnMode() {
104+
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}");
105+
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
106+
107+
ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
108+
listAppender.start();
109+
logger.addAppender(listAppender);
110+
111+
RewriteQueryTransformer.builder()
112+
.chatClientBuilder(mock(ChatClient.Builder.class))
113+
.targetSearchSystem("search engine")
114+
.validationMode(ValidationMode.WARN)
115+
.promptTemplate(customPromptTemplate)
116+
.build();
117+
var logsList = listAppender.list;
118+
119+
assertThat(logsList).isNotEmpty();
120+
assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
121+
assertThat(logsList.get(0).getMessage())
122+
.contains("The following placeholders must be present in the prompt template: query");
123+
}
124+
125+
@Test
126+
void shouldContinueWithMissingTargetPlaceholderInNoneMode() {
127+
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {target}");
128+
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
129+
130+
ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
131+
listAppender.start();
132+
logger.addAppender(listAppender);
133+
134+
RewriteQueryTransformer.builder()
135+
.chatClientBuilder(mock(ChatClient.Builder.class))
136+
.targetSearchSystem("vector store")
137+
.validationMode(ValidationMode.NONE)
138+
.promptTemplate(customPromptTemplate)
139+
.build();
140+
var logsList = listAppender.list;
141+
142+
assertThat(logsList).isEmpty();
143+
}
144+
145+
@Test
146+
void shouldContinueWithMissingQueryPlaceholderInNoneMode() {
147+
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
148+
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
149+
150+
ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
151+
listAppender.start();
152+
logger.addAppender(listAppender);
153+
154+
RewriteQueryTransformer.builder()
155+
.chatClientBuilder(mock(ChatClient.Builder.class))
156+
.targetSearchSystem("search engine")
157+
.validationMode(ValidationMode.NONE)
158+
.promptTemplate(customPromptTemplate)
159+
.build();
160+
var logsList = listAppender.list;
161+
162+
assertThat(logsList).isEmpty();
163+
}
164+
74165
}

0 commit comments

Comments
 (0)