Skip to content

Commit 9df3fbe

Browse files
committed
test: Add tests for MistralAI retry logic and AOT native image support
Adds essential test coverage for production reliability and native image compatibility Co-authored-by: Oleksandr Klymenko <[email protected]> Signed-off-by: Oleksandr Klymenko <[email protected]>
1 parent 27b09fe commit 9df3fbe

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ public void mistralAiEmbeddingNonTransientError() {
178178
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
179179
}
180180

181+
@Test
182+
public void mistralAiChatMixedTransientAndNonTransientErrors() {
183+
given(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
184+
.willThrow(new TransientAiException("Transient Error"))
185+
.willThrow(new RuntimeException("Non Transient Error"));
186+
187+
// Should fail immediately on non-transient error, no further retries
188+
assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text")));
189+
190+
// Should have 1 retry attempt before hitting non-transient error
191+
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
192+
}
193+
181194
private static class TestRetryListener implements RetryListener {
182195

183196
int onErrorRetryCount = 0;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,60 @@ void verifyPackageScanningWorks() {
118118
assertThat(jsonAnnotatedClasses.size()).isGreaterThan(0);
119119
}
120120

121+
@Test
122+
void verifyAllCriticalApiClassesAreRegistered() {
123+
RuntimeHints runtimeHints = new RuntimeHints();
124+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
125+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
126+
127+
Set<TypeReference> registeredTypes = new HashSet<>();
128+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
129+
130+
// Critical API classes that must be registered for runtime
131+
String[] criticalClasses = { "MistralAiApi$ChatCompletionRequest", "MistralAiApi$ChatCompletionMessage",
132+
"MistralAiApi$EmbeddingRequest", "MistralAiApi$EmbeddingList", "MistralAiApi$Usage" };
133+
134+
for (String className : criticalClasses) {
135+
assertThat(registeredTypes.stream()
136+
.anyMatch(tr -> tr.getName().contains(className.replace("$", "."))
137+
|| tr.getName().contains(className.replace("$", "$"))))
138+
.as("Critical class %s should be registered", className)
139+
.isTrue();
140+
}
141+
}
142+
143+
@Test
144+
void verifyEnumTypesAreRegistered() {
145+
RuntimeHints runtimeHints = new RuntimeHints();
146+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
147+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
148+
149+
Set<TypeReference> registeredTypes = new HashSet<>();
150+
runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
151+
152+
// Enums are critical for JSON deserialization in native images
153+
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.ChatModel.class)))
154+
.as("ChatModel enum should be registered")
155+
.isTrue();
156+
157+
assertThat(registeredTypes.contains(TypeReference.of(MistralAiApi.EmbeddingModel.class)))
158+
.as("EmbeddingModel enum should be registered")
159+
.isTrue();
160+
}
161+
162+
@Test
163+
void verifyReflectionHintsIncludeConstructors() {
164+
RuntimeHints runtimeHints = new RuntimeHints();
165+
MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints();
166+
mistralAiRuntimeHints.registerHints(runtimeHints, null);
167+
168+
// Verify that reflection hints include constructor access
169+
boolean hasConstructorHints = runtimeHints.reflection()
170+
.typeHints()
171+
.anyMatch(typeHint -> typeHint.constructors().findAny().isPresent() || typeHint.getMemberCategories()
172+
.contains(org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS));
173+
174+
assertThat(hasConstructorHints).as("Should register constructor hints for JSON deserialization").isTrue();
175+
}
176+
121177
}

0 commit comments

Comments
 (0)