diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java
index 5955cc6ac58..97ea1c71250 100644
--- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java
+++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java
@@ -35,6 +35,7 @@
* irrelevant information that may affect the quality of the search results.
*
* @author Thomas Vitale
+ * @author Sun Yuhan
* @since 1.0.0
* @see arXiv:2305.14283
*/
@@ -60,6 +61,8 @@ public class RewriteQueryTransformer implements QueryTransformer {
private final String targetSearchSystem;
+ private final ValidationMode validationMode;
+
public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
@Nullable String targetSearchSystem) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
@@ -67,8 +70,19 @@ public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable P
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
+ this.validationMode = ValidationMode.THROW;
+ validate();
+ }
+
+ public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
+ @Nullable String targetSearchSystem, @Nullable ValidationMode validationMode) {
+ Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");
- PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
+ this.chatClient = chatClientBuilder.build();
+ this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
+ this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
+ this.validationMode = validationMode != null ? validationMode : ValidationMode.THROW;
+ validate();
}
@Override
@@ -92,6 +106,23 @@ public Query transform(Query query) {
return query.mutate().text(rewrittenQueryText).build();
}
+ /**
+ * Verify whether the template contains the required variables.
+ */
+ private void validate() {
+ switch (this.validationMode) {
+ case THROW -> PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
+ case WARN -> {
+ try {
+ PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
+ }
+ catch (IllegalArgumentException e) {
+ logger.warn(e.getMessage());
+ }
+ }
+ }
+ }
+
public static Builder builder() {
return new Builder();
}
@@ -106,6 +137,9 @@ public static final class Builder {
@Nullable
private String targetSearchSystem;
+ @Nullable
+ private ValidationMode validationMode;
+
private Builder() {
}
@@ -124,8 +158,14 @@ public Builder targetSearchSystem(String targetSearchSystem) {
return this;
}
+ public Builder validationMode(ValidationMode validationMode) {
+ this.validationMode = validationMode;
+ return this;
+ }
+
public RewriteQueryTransformer build() {
- return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
+ return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem,
+ this.validationMode);
}
}
diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java
new file mode 100644
index 00000000000..c621d65a487
--- /dev/null
+++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2025-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.rag.preretrieval.query.transformation;
+
+/**
+ * Validation modes for {@link RewriteQueryTransformer}.
+ *
+ * @author Sun Yuhan
+ */
+public enum ValidationMode {
+
+ /**
+ * If the validation fails, an exception is thrown. This is the default mode.
+ */
+ THROW,
+
+ /**
+ * If the validation fails, a warning is logged.
+ */
+ WARN,
+
+ /**
+ * No validation is performed.
+ */
+ NONE
+
+}
diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java
index 08096e99588..2b795d1f816 100644
--- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java
+++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java
@@ -16,11 +16,16 @@
package org.springframework.ai.rag.preretrieval.query.transformation;
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.ILoggingEvent;
+import ch.qos.logback.core.read.ListAppender;
import org.junit.jupiter.api.Test;
+import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
@@ -71,4 +76,90 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() {
.hasMessageContaining("query");
}
+ @Test
+ void shouldLoggingWithMissingTargetPlaceholderInWarnMode() {
+ PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
+ Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
+
+ ListAppender listAppender = new ListAppender<>();
+ listAppender.start();
+ logger.addAppender(listAppender);
+
+ RewriteQueryTransformer.builder()
+ .chatClientBuilder(mock(ChatClient.Builder.class))
+ .targetSearchSystem("vector store")
+ .validationMode(ValidationMode.WARN)
+ .promptTemplate(customPromptTemplate)
+ .build();
+ var logsList = listAppender.list;
+
+ assertThat(logsList).isNotEmpty();
+ assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
+ assertThat(logsList.get(0).getMessage())
+ .contains("The following placeholders must be present in the prompt template: target");
+ }
+
+ @Test
+ void shouldLoggingWithMissingQueryPlaceholderInWarnMode() {
+ PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}");
+ Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
+
+ ListAppender listAppender = new ListAppender<>();
+ listAppender.start();
+ logger.addAppender(listAppender);
+
+ RewriteQueryTransformer.builder()
+ .chatClientBuilder(mock(ChatClient.Builder.class))
+ .targetSearchSystem("search engine")
+ .validationMode(ValidationMode.WARN)
+ .promptTemplate(customPromptTemplate)
+ .build();
+ var logsList = listAppender.list;
+
+ assertThat(logsList).isNotEmpty();
+ assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
+ assertThat(logsList.get(0).getMessage())
+ .contains("The following placeholders must be present in the prompt template: query");
+ }
+
+ @Test
+ void shouldContinueWithMissingTargetPlaceholderInNoneMode() {
+ PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {target}");
+ Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
+
+ ListAppender listAppender = new ListAppender<>();
+ listAppender.start();
+ logger.addAppender(listAppender);
+
+ RewriteQueryTransformer.builder()
+ .chatClientBuilder(mock(ChatClient.Builder.class))
+ .targetSearchSystem("vector store")
+ .validationMode(ValidationMode.NONE)
+ .promptTemplate(customPromptTemplate)
+ .build();
+ var logsList = listAppender.list;
+
+ assertThat(logsList).isEmpty();
+ }
+
+ @Test
+ void shouldContinueWithMissingQueryPlaceholderInNoneMode() {
+ PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
+ Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);
+
+ ListAppender listAppender = new ListAppender<>();
+ listAppender.start();
+ logger.addAppender(listAppender);
+
+ RewriteQueryTransformer.builder()
+ .chatClientBuilder(mock(ChatClient.Builder.class))
+ .targetSearchSystem("search engine")
+ .validationMode(ValidationMode.NONE)
+ .promptTemplate(customPromptTemplate)
+ .build();
+ var logsList = listAppender.list;
+
+ assertThat(logsList).isEmpty();
+ }
+
}