diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java index 5955cc6ac58..97ea1c71250 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformer.java @@ -35,6 +35,7 @@ * irrelevant information that may affect the quality of the search results. * * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 * @see arXiv:2305.14283 */ @@ -60,6 +61,8 @@ public class RewriteQueryTransformer implements QueryTransformer { private final String targetSearchSystem; + private final ValidationMode validationMode; + public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, @Nullable String targetSearchSystem) { Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); @@ -67,8 +70,19 @@ public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable P this.chatClient = chatClientBuilder.build(); this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET; + this.validationMode = ValidationMode.THROW; + validate(); + } + + public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + @Nullable String targetSearchSystem, @Nullable ValidationMode validationMode) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); - PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET; + this.validationMode = validationMode != null ? validationMode : ValidationMode.THROW; + validate(); } @Override @@ -92,6 +106,23 @@ public Query transform(Query query) { return query.mutate().text(rewrittenQueryText).build(); } + /** + * Verify whether the template contains the required variables. + */ + private void validate() { + switch (this.validationMode) { + case THROW -> PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + case WARN -> { + try { + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query"); + } + catch (IllegalArgumentException e) { + logger.warn(e.getMessage()); + } + } + } + } + public static Builder builder() { return new Builder(); } @@ -106,6 +137,9 @@ public static final class Builder { @Nullable private String targetSearchSystem; + @Nullable + private ValidationMode validationMode; + private Builder() { } @@ -124,8 +158,14 @@ public Builder targetSearchSystem(String targetSearchSystem) { return this; } + public Builder validationMode(ValidationMode validationMode) { + this.validationMode = validationMode; + return this; + } + public RewriteQueryTransformer build() { - return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem); + return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem, + this.validationMode); } } diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java new file mode 100644 index 00000000000..c621d65a487 --- /dev/null +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/preretrieval/query/transformation/ValidationMode.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.preretrieval.query.transformation; + +/** + * Validation modes for {@link RewriteQueryTransformer}. + * + * @author Sun Yuhan + */ +public enum ValidationMode { + + /** + * If the validation fails, an exception is thrown. This is the default mode. + */ + THROW, + + /** + * If the validation fails, a warning is logged. + */ + WARN, + + /** + * No validation is performed. + */ + NONE + +} diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java index 08096e99588..2b795d1f816 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/transformation/RewriteQueryTransformerTests.java @@ -16,11 +16,16 @@ package org.springframework.ai.rag.preretrieval.query.transformation; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.prompt.PromptTemplate; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; @@ -71,4 +76,90 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() { .hasMessageContaining("query"); } + @Test + void shouldLoggingWithMissingTargetPlaceholderInWarnMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("vector store") + .validationMode(ValidationMode.WARN) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isNotEmpty(); + assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN); + assertThat(logsList.get(0).getMessage()) + .contains("The following placeholders must be present in the prompt template: target"); + } + + @Test + void shouldLoggingWithMissingQueryPlaceholderInWarnMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("search engine") + .validationMode(ValidationMode.WARN) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isNotEmpty(); + assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN); + assertThat(logsList.get(0).getMessage()) + .contains("The following placeholders must be present in the prompt template: query"); + } + + @Test + void shouldContinueWithMissingTargetPlaceholderInNoneMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {target}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("vector store") + .validationMode(ValidationMode.NONE) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isEmpty(); + } + + @Test + void shouldContinueWithMissingQueryPlaceholderInNoneMode() { + PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}"); + Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class); + + ListAppender listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + + RewriteQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetSearchSystem("search engine") + .validationMode(ValidationMode.NONE) + .promptTemplate(customPromptTemplate) + .build(); + var logsList = listAppender.list; + + assertThat(logsList).isEmpty(); + } + }