Skip to content

Commit 3b4fac7

Browse files
committed
fix: Throw NonTransientAiException on syntax errors in prompt templates
IAE/ISE are too generic, it's hard to use libraries that throw them. Using NonTransientAiException clearly marks the exception as coming from Spring AI, plus will integrate nicely with code that already avoids re-sending requests when faced with this exception. Signed-off-by: Piotr Kubowicz <[email protected]>
1 parent e23bf1a commit 3b4fac7

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTemplateTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.chat.prompt.ChatOptions;
3232
import org.springframework.ai.chat.prompt.Prompt;
3333
import org.springframework.ai.chat.prompt.PromptTemplate;
34+
import org.springframework.ai.retry.NonTransientAiException;
3435
import org.springframework.core.io.InputStreamResource;
3536
import org.springframework.core.io.Resource;
3637

@@ -117,7 +118,7 @@ public void testRenderWithList() {
117118
assertEqualsWithNormalizedEOLs(expected, message.getText());
118119

119120
PromptTemplate unfilledPromptTemplate = new PromptTemplate(templateString);
120-
assertThatExceptionOfType(IllegalStateException.class).isThrownBy(unfilledPromptTemplate::render)
121+
assertThatExceptionOfType(NonTransientAiException.class).isThrownBy(unfilledPromptTemplate::render)
121122
.withMessage("Not all variables were replaced in the template. Missing variable names are: [items].");
122123
}
123124

@@ -202,7 +203,7 @@ public void testRenderFailure() {
202203
PromptTemplate promptTemplate = PromptTemplate.builder().template(template).variables(model).build();
203204

204205
// Rendering the template with a missing key should throw an exception
205-
assertThrows(IllegalStateException.class, promptTemplate::render);
206+
assertThrows(NonTransientAiException.class, promptTemplate::render);
206207
}
207208

208209
}

spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/PromptTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.ai.chat.prompt.Prompt;
2727
import org.springframework.ai.chat.prompt.PromptTemplate;
2828
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
29+
import org.springframework.ai.retry.NonTransientAiException;
2930

3031
import static org.assertj.core.api.Assertions.assertThat;
3132

@@ -43,7 +44,7 @@ void newApiPlaygroundTests() {
4344

4445
// Try to render with missing value for template variable, expect exception
4546
Assertions.assertThatThrownBy(() -> pt.render(model))
46-
.isInstanceOf(IllegalStateException.class)
47+
.isInstanceOf(NonTransientAiException.class)
4748
.hasMessage("Not all variables were replaced in the template. Missing variable names are: [lastName].");
4849

4950
pt.add("lastName", "Park"); // TODO investigate partial

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTemplateBuilderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.junit.jupiter.api.Assertions;
2323
import org.junit.jupiter.api.Test;
24+
import org.springframework.ai.retry.NonTransientAiException;
2425

2526
import static org.assertj.core.api.Assertions.assertThat;
2627
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -85,7 +86,7 @@ void renderWithMissingVariableShouldThrow() {
8586
// If render() doesn't throw, fail the test
8687
Assertions.fail("Expected IllegalStateException was not thrown.");
8788
}
88-
catch (IllegalStateException e) {
89+
catch (NonTransientAiException e) {
8990
// Assert that the message is exactly the expected string
9091
assertThat(e.getMessage())
9192
.isEqualTo("Not all variables were replaced in the template. Missing variable names are: [name].");

spring-ai-template-st/pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
<version>${project.parent.version}</version>
4949
</dependency>
5050

51+
<dependency>
52+
<groupId>org.springframework.ai</groupId>
53+
<artifactId>spring-ai-retry</artifactId>
54+
<version>${project.parent.version}</version>
55+
</dependency>
56+
5157
<dependency>
5258
<groupId>org.antlr</groupId>
5359
<artifactId>ST4</artifactId>
@@ -74,4 +80,4 @@
7480
<scope>test</scope>
7581
</dependency>
7682
</dependencies>
77-
</project>
83+
</project>

spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.antlr.runtime.TokenStream;
2525
import org.slf4j.Logger;
2626
import org.slf4j.LoggerFactory;
27+
import org.springframework.ai.retry.NonTransientAiException;
2728
import org.stringtemplate.v4.ST;
2829
import org.stringtemplate.v4.compiler.Compiler;
2930
import org.stringtemplate.v4.compiler.STLexer;
@@ -115,7 +116,7 @@ private ST createST(String template) {
115116
return new ST(template, this.startDelimiterToken, this.endDelimiterToken);
116117
}
117118
catch (Exception ex) {
118-
throw new IllegalArgumentException("The template string is not valid.", ex);
119+
throw new NonTransientAiException("The template string is not valid.", ex);
119120
}
120121
}
121122

@@ -137,7 +138,7 @@ private Set<String> validate(ST st, Map<String, Object> templateVariables) {
137138
logger.warn(VALIDATION_MESSAGE.formatted(missingVariables));
138139
}
139140
else if (this.validationMode == ValidationMode.THROW) {
140-
throw new IllegalStateException(VALIDATION_MESSAGE.formatted(missingVariables));
141+
throw new NonTransientAiException(VALIDATION_MESSAGE.formatted(missingVariables));
141142
}
142143
}
143144
return missingVariables;

spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.junit.jupiter.api.Test;
2323

24+
import org.springframework.ai.retry.NonTransientAiException;
2425
import org.springframework.ai.template.ValidationMode;
2526
import org.springframework.test.util.ReflectionTestUtils;
2627

@@ -107,7 +108,7 @@ void shouldThrowExceptionForInvalidTemplateSyntax() {
107108
Map<String, Object> variables = new HashMap<>();
108109
variables.put("name", "Spring AI");
109110

110-
assertThatThrownBy(() -> renderer.apply("Hello {name!", variables)).isInstanceOf(IllegalArgumentException.class)
111+
assertThatThrownBy(() -> renderer.apply("Hello {name!", variables)).isInstanceOf(NonTransientAiException.class)
111112
.hasMessageContaining("The template string is not valid.");
112113
}
113114

@@ -118,7 +119,7 @@ void shouldThrowExceptionForMissingVariablesInThrowMode() {
118119
variables.put("greeting", "Hello");
119120

120121
assertThatThrownBy(() -> renderer.apply("{greeting} {name}!", variables))
121-
.isInstanceOf(IllegalStateException.class)
122+
.isInstanceOf(NonTransientAiException.class)
122123
.hasMessageContaining(
123124
"Not all variables were replaced in the template. Missing variable names are: [name]");
124125
}

0 commit comments

Comments
 (0)