From 5b2257b373a874c50004c9c886fad071e5509b1d Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Thu, 28 Aug 2025 11:22:33 +0800 Subject: [PATCH 1/3] feat: Added a `ValidationMode` property to `RewriteQueryTransformer`, allowing users to customize the behavior of template content validation when creating an instance of the object. Signed-off-by: Sun Yuhan --- .../RewriteQueryTransformer.java | 44 ++++++++- .../query/transformation/ValidationMode.java | 41 +++++++++ .../RewriteQueryTransformerTests.java | 91 +++++++++++++++++++ 3 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java index 5955cc6ac58..97ea1c71250 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java @@ -35,6 +35,7 @@ * irrelevant information that may affect the quality of the search results. * * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 * @see arXiv:2305.14283 */ @@ -60,6 +61,8 @@ public class RewriteQueryTransformer implements QueryTransformer { private final String targetSearchSystem; + private final ValidationMode validationMode; + public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, @Nullable String targetSearchSystem) { Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); @@ -67,8 +70,19 @@ public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable P this.chatClient = chatClientBuilder.build(); this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET; + this.validationMode = ValidationMode.THROW; + validate(); + } + + public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + @Nullable String targetSearchSystem, @Nullable ValidationMode validationMode) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); - PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET; + this.validationMode = validationMode != null ? validationMode : ValidationMode.THROW; + validate(); } @Override @@ -92,6 +106,23 @@ public Query transform(Query query) { return query.mutate().text(rewrittenQueryText).build(); } + /** + * Verify whether the template contains the required variables. + */ + private void validate() { + switch (this.validationMode) { + case THROW -> PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + case WARN -> { + try { + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + } + catch (IllegalArgumentException e) { + logger.warn(e.getMessage()); + } + } + } + } + public static Builder builder() { return new Builder(); } @@ -106,6 +137,9 @@ public static final class Builder { @Nullable private String targetSearchSystem; + @Nullable + private ValidationMode validationMode; + private Builder() { } @@ -124,8 +158,14 @@ public Builder targetSearchSystem(String targetSearchSystem) { return this; } + public Builder validationMode(ValidationMode validationMode) { + this.validationMode = validationMode; + return this; + } + public RewriteQueryTransformer build() { - return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem); + return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem, + this.validationMode); } } diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java new file mode 100644 index 00000000000..c621d65a487 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +/** + * Validation modes for {@link RewriteQueryTransformer}. + * + * @author Sun Yuhan + */ +public enum ValidationMode { + + /** + * If the validation fails, an exception is thrown. This is the default mode. + */ + THROW, + + /** + * If the validation fails, a warning is logged. + */ + WARN, + + /** + * No validation is performed. + */ + NONE + +} diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java index 08096e99588..ae1bb7f8ec0 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java @@ -16,11 +16,16 @@ package org.springframework.ai.rag.preretrieval.query.transformation; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.prompt.PromptTemplate; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; @@ -71,4 +76,90 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() { .hasMessageContaining("query"); } + @Test + void shouldLoggingWithMissingTargetPlaceholderInWarnMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("vector store") + .validationMode(ValidationMode.WARN) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isNotEmpty(); + assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN); + assertThat(logsList.get(0).getMessage()) + .contains("The following placeholders must be present in the prompt template: target"); + } + + @Test + void shouldLoggingWithMissingQueryPlaceholderInWarnMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("search engine") + .validationMode(ValidationMode.WARN) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isNotEmpty(); + assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN); + assertThat(logsList.get(0).getMessage()) + .contains("The following placeholders must be present in the prompt template: query"); + } + + @Test + void shouldContinueWithMissingTargetPlaceholderInNoneMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {target}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("vector store") + .validationMode(ValidationMode.NONE) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isEmpty(); + } + + @Test + void shouldContinueWithMissingQueryPlaceholderInNoneMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("search engine") + .validationMode(ValidationMode.NONE) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isEmpty(); + } + } From 7c79b2384948cd4165684a9a942454fef13788fc Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Thu, 28 Aug 2025 11:32:20 +0800 Subject: [PATCH 2/3] fix: fix the import Signed-off-by: Sun Yuhan --- .../query/transformation/RewriteQueryTransformerTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java index ae1bb7f8ec0..86789d710bc 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java @@ -20,7 +20,6 @@ import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.core.read.ListAppender; import org.junit.jupiter.api.Test; - import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.prompt.PromptTemplate; From 0b7df9e5961a1bb5bb562a2c514e2f59584c2af1 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Thu, 28 Aug 2025 11:35:36 +0800 Subject: [PATCH 3/3] fix: fix the import Signed-off-by: Sun Yuhan --- .../query/transformation/RewriteQueryTransformerTests.java | 1 + 1 file changed, 1 insertion(+) diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java index 86789d710bc..2b795d1f816 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java @@ -21,6 +21,7 @@ import ch.qos.logback.core.read.ListAppender; import org.junit.jupiter.api.Test; import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.prompt.PromptTemplate;