diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java index fa096cab404..fd78a6aa29e 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java @@ -31,6 +31,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.retry.NonTransientAiException; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; @@ -117,7 +118,7 @@ public void testRenderWithList() { assertEqualsWithNormalizedEOLs(expected, message.getText()); PromptTemplate unfilledPromptTemplate = new PromptTemplate(templateString); - assertThatExceptionOfType(IllegalStateException.class).isThrownBy(unfilledPromptTemplate::render) + assertThatExceptionOfType(NonTransientAiException.class).isThrownBy(unfilledPromptTemplate::render) .withMessage("Not all variables were replaced in the template. Missing variable names are: [items]."); } @@ -202,7 +203,7 @@ public void testRenderFailure() { PromptTemplate promptTemplate = PromptTemplate.builder().template(template).variables(model).build(); // Rendering the template with a missing key should throw an exception - assertThrows(IllegalStateException.class, promptTemplate::render); + assertThrows(NonTransientAiException.class, promptTemplate::render); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java index c3d33608f06..2905ab8ff81 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.retry.NonTransientAiException; import static org.assertj.core.api.Assertions.assertThat; @@ -43,7 +44,7 @@ void newApiPlaygroundTests() { // Try to render with missing value for template variable, expect exception Assertions.assertThatThrownBy(() -> pt.render(model)) - .isInstanceOf(IllegalStateException.class) + .isInstanceOf(NonTransientAiException.class) .hasMessage("Not all variables were replaced in the template. Missing variable names are: [lastName]."); pt.add("lastName", "Park"); // TODO investigate partial diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateBuilderTests.java index 249d980c615..a86e3800dcf 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateBuilderTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateBuilderTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.springframework.ai.retry.NonTransientAiException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -85,7 +86,7 @@ void renderWithMissingVariableShouldThrow() { // If render() doesn't throw, fail the test Assertions.fail("Expected IllegalStateException was not thrown."); } - catch (IllegalStateException e) { + catch (NonTransientAiException e) { // Assert that the message is exactly the expected string assertThat(e.getMessage()) .isEqualTo("Not all variables were replaced in the template. Missing variable names are: [name]."); diff --git a/spring-ai-template-st/pom.xml b/spring-ai-template-st/pom.xml index e23d4b36acc..cf45ffbeb76 100644 --- a/spring-ai-template-st/pom.xml +++ b/spring-ai-template-st/pom.xml @@ -48,6 +48,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.antlr ST4 @@ -74,4 +80,4 @@ test - \ No newline at end of file + diff --git a/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java b/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java index 3780b948a09..74dc6964691 100644 --- a/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java +++ b/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java @@ -24,6 +24,7 @@ import org.antlr.runtime.TokenStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.retry.NonTransientAiException; import org.stringtemplate.v4.ST; import org.stringtemplate.v4.compiler.Compiler; import org.stringtemplate.v4.compiler.STLexer; @@ -115,7 +116,7 @@ private ST createST(String template) { return new ST(template, this.startDelimiterToken, this.endDelimiterToken); } catch (Exception ex) { - throw new IllegalArgumentException("The template string is not valid.", ex); + throw new NonTransientAiException("The template string is not valid.", ex); } } @@ -137,7 +138,7 @@ private Set validate(ST st, Map templateVariables) { logger.warn(VALIDATION_MESSAGE.formatted(missingVariables)); } else if (this.validationMode == ValidationMode.THROW) { - throw new IllegalStateException(VALIDATION_MESSAGE.formatted(missingVariables)); + throw new NonTransientAiException(VALIDATION_MESSAGE.formatted(missingVariables)); } } return missingVariables; diff --git a/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java b/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java index 4d4e979e869..a2b7869ef1d 100644 --- a/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java +++ b/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.retry.NonTransientAiException; import org.springframework.ai.template.ValidationMode; import org.springframework.test.util.ReflectionTestUtils; @@ -107,7 +108,7 @@ void shouldThrowExceptionForInvalidTemplateSyntax() { Map variables = new HashMap<>(); variables.put("name", "Spring AI"); - assertThatThrownBy(() -> renderer.apply("Hello {name!", variables)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> renderer.apply("Hello {name!", variables)).isInstanceOf(NonTransientAiException.class) .hasMessageContaining("The template string is not valid."); } @@ -118,7 +119,7 @@ void shouldThrowExceptionForMissingVariablesInThrowMode() { variables.put("greeting", "Hello"); assertThatThrownBy(() -> renderer.apply("{greeting} {name}!", variables)) - .isInstanceOf(IllegalStateException.class) + .isInstanceOf(NonTransientAiException.class) .hasMessageContaining( "Not all variables were replaced in the template. Missing variable names are: [name]"); }