diff --git a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpanderTests.java b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpanderTests.java index fac30e815ca..e7e5ff81c01 100644 --- a/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpanderTests.java +++ b/spring-ai-rag/src/test/java/org/springframework/ai/rag/preretrieval/query/expansion/MultiQueryExpanderTests.java @@ -85,4 +85,66 @@ void whenPromptTemplateIsNullThenUseDefault() { assertThat(queryExpander).isNotNull(); } + @Test + void whenPromptTemplateHasBothPlaceholdersThenBuild() { + PromptTemplate validTemplate = new PromptTemplate("Generate {number} variations of: {query}"); + + MultiQueryExpander expander = MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(validTemplate) + .build(); + + assertThat(expander).isNotNull(); + } + + @Test + void whenPromptTemplateHasExtraPlaceholdersThenBuild() { + PromptTemplate templateWithExtra = new PromptTemplate( + "Generate {number} variations of: {query}. Context: {context}"); + + MultiQueryExpander expander = MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(templateWithExtra) + .build(); + + assertThat(expander).isNotNull(); + } + + @Test + void whenBuilderSetMultipleTimesThenUseLastValue() { + ChatClient.Builder firstBuilder = mock(ChatClient.Builder.class); + ChatClient.Builder secondBuilder = mock(ChatClient.Builder.class); + + MultiQueryExpander expander = MultiQueryExpander.builder() + .chatClientBuilder(firstBuilder) + .chatClientBuilder(secondBuilder) + .build(); + + assertThat(expander).isNotNull(); + } + + @Test + void whenPromptTemplateSetToNullAfterValidTemplateThenUseDefault() { + PromptTemplate validTemplate = new PromptTemplate("Config: {number} values for {query}"); + + MultiQueryExpander expander = MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(validTemplate) + .promptTemplate(null) + .build(); + + assertThat(expander).isNotNull(); + } + + @Test + void whenPromptTemplateHasPlaceholdersInDifferentCaseThenThrow() { + PromptTemplate templateWithWrongCase = new PromptTemplate("Generate {NUMBER} variations of: {QUERY}"); + + assertThatThrownBy(() -> MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(templateWithWrongCase) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template"); + } + }